diff --git a/.codex/skills/generate-vpto-release-doc/SKILL.md b/.codex/skills/generate-vpto-release-doc/SKILL.md new file mode 100644 index 000000000..9a88c9847 --- /dev/null +++ b/.codex/skills/generate-vpto-release-doc/SKILL.md @@ -0,0 +1,141 @@ +--- +name: generate-vpto-release-doc +description: Generate or refresh bundled PTO ISA reference docs. Packages `docs/vpto-spec.md` together with `docs/isa/micro-isa/*.md` into the main repo doc `docs/PTO-micro-Instruction-SPEC.md`, packages `docs/isa/tile-op/*.md` into the companion repo doc `docs/PTO-tile-Instruction-SPEC.md`, and still supports emitting a legacy versioned snapshot under `docs/release/`. +--- + +# Generate VPTO Release Doc + +Use this skill when the task is specifically about: +- creating or refreshing the main bundled micro-spec `docs/PTO-micro-Instruction-SPEC.md` +- creating or refreshing the standalone Tile Instruction bundle `docs/PTO-tile-Instruction-SPEC.md` +- emitting a legacy versioned snapshot under `docs/release/` +- regenerating downstream bundled docs from PTOAS sources through the same script + +The main bundled micro-spec starts from `docs/vpto-spec.md`, strips draft-only metadata and appendix content, then inlines every chapter from `docs/isa/micro-isa/` under a dedicated detailed-reference section. `docs/PTO-tile-Instruction-SPEC.md` stays separate and all cross-doc links are rewritten to point at the bundled filenames. + +Do not hand-edit bundled outputs; regenerate them through the script so link rewriting stays reproducible. + +## Canonical Workflow + +1. Pick the target version. The repo-doc bundle filenames do not carry a version suffix; the version is recorded in the version-history bullets inside each file. + +Default output paths: + +```bash +docs/PTO-micro-Instruction-SPEC.md # main bundled micro-spec +docs/PTO-tile-Instruction-SPEC.md # standalone Tile Instruction bundle +docs/release/vpto-spec-v.md # optional legacy versioned snapshot +``` + +2. Run the bundled generator script. + +Generate both repo-doc bundles (default): + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py --version 0.4 +``` + +Generate just one repo-doc bundle: + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py --version 0.4 --target micro +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py --version 0.4 --target tileop +``` + +Generate only the legacy versioned snapshot: + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py --version 0.4 --target merged +``` + +Generate downstream docs into another directory while preserving old PTO-Gym filenames: + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py \ + --version 0.4 \ + --output-dir 3rdparty/PTO-Gym/docs \ + --micro-output-name PTO-micro-Instruction-SPEC.md \ + --tileop-output-name PTO-tile-Instruction-SPEC.md + +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py \ + --version 0.4 \ + --target merged \ + --output-dir 3rdparty/PTO-Gym/docs \ + --micro-output-name PTO-micro-Instruction-SPEC.md \ + --tileop-output-name PTO-tile-Instruction-SPEC.md \ + --merged-output-name vpto-spec.md +``` + +Custom version-bullet text: + +```bash +python3 .codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py \ + --version 0.4 \ + --version-note 'Custom release note for this run' +``` + +3. Review each generated file. + +Invariants for the main bundle (`docs/PTO-micro-Instruction-SPEC.md`): +- exactly one `#` level title at the top +- `[toc]` is present near the top +- the requested-version bullet is at the top of the version-history list +- beginning-of-file draft metadata (`Status`, `Base`, `Updated`) is removed +- appendix content is removed +- `docs/vpto-spec.md` contributes the overview / notation / summary sections +- `## Detailed ISA Group Reference` exists and inlines every chapter from `docs/isa/micro-isa/` in sorted order +- chapter headings are demoted by two levels so each `# N. ...` becomes `### N. ...` +- source-tree links into `isa/micro-isa/...` are rewritten to in-document `#micro-...` anchors +- source-tree links into `isa/tile-op/...` are rewritten to `PTO-tile-Instruction-SPEC.md#tile-...` + +Invariants for the Tile Instruction bundle (`docs/PTO-tile-Instruction-SPEC.md`): +- exactly one `#` level title at the top +- `[toc]` is present near the top +- every chapter file from `docs/isa/tile-op/` is inlined in sorted order +- intra-bundle links resolve to `` anchors +- cross-bundle links are rewritten to `PTO-micro-Instruction-SPEC.md#micro-XX-name` + +Invariants for the legacy versioned snapshot: +- it carries the same bundled micro-spec content as the main repo doc +- it is emitted under `docs/release/` by default unless `--output-dir` overrides the destination +- Tile Instruction links are rewritten relative to the snapshot location + +4. If the user wants extra release-note wording, patch only the version bullets or other small wording around the generated content. Prefer rerunning the script over hand-merging large sections. + +## Source Mapping + +| Source | Target | +|--------|--------| +| `docs/vpto-spec.md` + `docs/isa/micro-isa/*.md` | `docs/PTO-micro-Instruction-SPEC.md` | +| `docs/isa/tile-op/*.md` | `docs/PTO-tile-Instruction-SPEC.md` | +| `docs/vpto-spec.md` + `docs/isa/micro-isa/*.md` | `docs/release/vpto-spec-v.md` (legacy versioned snapshot) | + +## Merge Rules + +For the main bundled micro-spec the script: +- emits a single top-level title +- prepends a target-specific version-bullet list +- inserts a `[toc]` marker +- starts from `docs/vpto-spec.md` +- strips draft metadata, appendix content, and now-misleading "see individual files" prose +- preserves the high-level overview and summary sections +- rewrites `isa/micro-isa/*.md` links to in-document anchors +- rewrites `isa/tile-op/*.md` links to the companion bundle filename + anchor +- appends a `## Detailed ISA Group Reference` section that inlines all `docs/isa/micro-isa/*.md` chapters in sorted order + +For the standalone Tile Instruction bundle the script: +- emits a single top-level title +- prepends a target-specific version-bullet list +- inserts a `[toc]` marker +- inlines all `docs/isa/tile-op/*.md` files in sorted order +- demotes chapter headings by one level +- emits stable HTML anchors like `` +- rewrites intra-bundle relative links to those anchors +- rewrites cross-bundle relative links to the bundled micro-spec filename + anchor + +## Notes + +- Repo-doc bundle filenames intentionally drop the version suffix; the version is recorded only in the per-file version-history bullets. +- Default version notes for known versions live inside the script; pass `--version-note` to add or override the note for the requested target. +- When chapter filenames or numbering change in `docs/isa/micro-isa/` or `docs/isa/tile-op/`, regenerate both repo-doc bundles and any legacy snapshots so links stay synchronized. +- If downstream consumers still require older filenames such as `pto-micro-instruction.md`, use `--micro-output-name` / `--tileop-output-name` instead of hand-renaming the generated files. diff --git a/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py new file mode 100644 index 000000000..3cb6e52d2 --- /dev/null +++ b/.codex/skills/generate-vpto-release-doc/scripts/generate_release_vpto_spec.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Generate bundled PTO ISA reference docs. + +Default repo docs: +- ``docs/vpto-spec.md`` + ``docs/isa/micro-isa/*.md`` + -> ``docs/PTO-micro-Instruction-SPEC.md`` +- ``docs/isa/tile-op/*.md`` -> ``docs/PTO-tile-Instruction-SPEC.md`` + +Legacy versioned snapshot: +- ``docs/vpto-spec.md`` + ``docs/isa/micro-isa/*.md`` + -> ``docs/release/vpto-spec-v.md`` +""" + +from __future__ import annotations + +import argparse +import re +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[4] +DOCS_DIR = ROOT / "docs" +ISA_DIR = DOCS_DIR / "isa" +RELEASE_DIR = DOCS_DIR / "release" +SOURCE_SPEC = DOCS_DIR / "vpto-spec.md" + +MICRO_TARGET = "micro" +TILEOP_TARGET = "tileop" +MERGED_TARGET = "merged" + +SPLIT_BUNDLES: dict[str, dict[str, str | Path]] = { + MICRO_TARGET: { + "source_dir": ISA_DIR / "micro-isa", + "output_name": "PTO-micro-Instruction-SPEC.md", + "title": "# PTO micro Instruction Spec — Draft (A5)", + "anchor_prefix": "micro", + "source_dir_name": "micro-isa", + "peer_target": TILEOP_TARGET, + "peer_dir": "tile-op", + }, + TILEOP_TARGET: { + "source_dir": ISA_DIR / "tile-op", + "output_name": "PTO-tile-Instruction-SPEC.md", + "title": "# PTO Tile Instruction SPEC (A5)", + "anchor_prefix": "tile", + "source_dir_name": "tile-op", + "peer_target": MICRO_TARGET, + "peer_dir": "micro-isa", + }, +} + +MICRO_DOC_TITLE = "# PTO micro Instruction Spec — Draft (A5)" +MERGED_OUTPUT_NAME = "vpto-spec-v{version}.md" +PART_IV_HEADING = "## Part IV: PTO Tile Instruction" + +MICRO_REQUIRED_SECTIONS = [ + "## Part I: Architecture Overview", + "## Part II: Notation Convention", + "## Part III: ISA Instruction Reference", + "## Instruction Groups", + "## Supported Data Types", + "## Common Patterns", + "## Quick Reference by Category", +] + +VERSION_NOTES = { + MICRO_TARGET: { + "0.1": "Doc Init", + "0.2": "Update micro Instruction latency and throughput", + "0.3": "Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details", + "0.4": "Update DMA instruction docs and add PTO Tile Instruction SPEC", + }, + TILEOP_TARGET: { + "0.4": "Initial PTO Tile Instruction SPEC covering core TileOps", + }, + MERGED_TARGET: { + "0.1": "Doc Init", + "0.2": "Update micro Instruction latency and throughput", + "0.3": "Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details", + "0.4": "Update DMA instruction docs and add PTO Tile Instruction SPEC", + }, +} + +# In-directory chapter link inside one ISA folder, for example +# ``[Foo](02-types-and-attributes.md)`` or ``[Foo](02-types-and-attributes.md#anchor)``. +INTRA_LINK_RE = re.compile( + r"\]\((?P[0-9]{2}-[A-Za-z0-9-]+)\.md(?P#[A-Za-z0-9_-]+)?\)" +) +# Cross-bundle link from one ISA folder to its sibling, for example +# ``[Foo](../micro-isa/01-pipeline-sync.md)``. +PEER_LINK_RE_TEMPLATE = ( + r"\]\(\.\./{peer_dir}/(?P[0-9]{{2}}-[A-Za-z0-9-]+)\.md" + r"(?P#[A-Za-z0-9_-]+)?\)" +) +# Link from the top-level vpto-spec.md to a chapter inside an ISA folder, for example +# ``[Foo](isa/tile-op/01-tile-overview.md)``. +SPEC_LINK_RE_TEMPLATE = ( + r"\]\((?:\.\./)?(?:docs/)?isa/{source_dir_name}/(?P[0-9]{{2}}-[A-Za-z0-9-]+)\.md" + r"(?P#[A-Za-z0-9_-]+)?\)" +) + + +def render_version_bullets(target: str, version: str, version_note: str | None) -> str: + notes = dict(VERSION_NOTES[target]) + if version_note: + notes[version] = version_note + elif version not in notes: + notes[version] = "Release refresh" + + def key_fn(item: str) -> tuple[int, ...]: + return tuple(int(part) for part in item.split(".")) + + return "\n".join( + f"- v{ver}: {notes[ver]}" for ver in sorted(notes, key=key_fn, reverse=True) + ) + + +def extract_sections(markdown: str) -> dict[str, str]: + headings = list(re.finditer(r"^## .*$", markdown, flags=re.MULTILINE)) + sections: dict[str, str] = {} + for index, match in enumerate(headings): + heading = match.group(0).strip() + start = match.start() + end = headings[index + 1].start() if index + 1 < len(headings) else len(markdown) + sections[heading] = markdown[start:end].strip() + "\n" + return sections + + +def trim_trailing_rule(text: str) -> str: + return re.sub(r"\n---\s*\Z", "\n", text.strip() + "\n").rstrip() + + +def demote_headings(text: str, levels: int = 1) -> str: + """Increase ATX heading level by ``levels``, capped at H6.""" + + def replace(match: re.Match[str]) -> str: + hashes = match.group(1) + heading = match.group(2) + new_level = min(6, len(hashes) + levels) + return f"{'#' * new_level} {heading}" + + return re.sub(r"^(#{1,6})\s+(.*)$", replace, text, flags=re.MULTILINE) + + +def strip_spec_unwanted_lines(markdown: str) -> str: + lines = markdown.splitlines() + kept: list[str] = [] + skip_correspondence = False + + for line in lines: + if re.match(r"^## Correspondence Categories\b", line): + skip_correspondence = True + continue + if skip_correspondence: + if re.match(r"^## ", line): + skip_correspondence = False + else: + continue + if line.startswith("> **Status:**"): + continue + if line.startswith("> **Base:**"): + continue + if line.startswith("> **Additions from:**"): + continue + if line.startswith("> **Updated:**"): + continue + if "For detailed semantics, C-style pseudocode, and CCE mappings" in line: + continue + if "CCE correspondence" in line or "builtin mapping" in line.lower(): + continue + kept.append(line) + + text = "\n".join(kept).strip() + "\n" + text = re.sub(r"\n## Appendix\b.*\Z", "\n", text, flags=re.DOTALL) + return re.sub(r"\n{3,}", "\n\n", text).strip() + "\n" + + +def normalize_part_three_heading(text: str) -> str: + lines = text.splitlines() + if len(lines) >= 2 and lines[0].startswith("## Part III: ISA Instruction Reference"): + if lines[1].startswith("# Part III: ISA Instruction Reference"): + lines = ["## Part III: ISA Instruction Reference — Summary"] + lines[2:] + return "\n".join(lines).strip() + "\n" + + +def resolve_bundle_ref( + from_output_target: str, + to_target: str, + split_output_names: dict[str, str], + output_dir_overridden: bool, +) -> str: + bundle_name = split_output_names[to_target] + if from_output_target == MERGED_TARGET and not output_dir_overridden: + return f"../{bundle_name}" + return bundle_name + + +def rewrite_intra_links(text: str, anchor_prefix: str) -> str: + """Rewrite same-bundle chapter links to in-document anchors.""" + + def repl(match: re.Match[str]) -> str: + chapter = match.group("chapter").lower() + anchor_suffix = match.group("anchor") or "" + if anchor_suffix: + return f"]({anchor_suffix})" + return f"](#{anchor_prefix}-{chapter})" + + return INTRA_LINK_RE.sub(repl, text) + + +def rewrite_cross_bundle_links( + text: str, + source_target: str, + from_output_target: str, + split_output_names: dict[str, str], + output_dir_overridden: bool, +) -> str: + """Rewrite cross-bundle chapter links to the companion bundle.""" + + cfg = SPLIT_BUNDLES[source_target] + peer_target: str = cfg["peer_target"] # type: ignore[assignment] + peer_dir: str = cfg["peer_dir"] # type: ignore[assignment] + peer_cfg = SPLIT_BUNDLES[peer_target] + peer_anchor_prefix: str = peer_cfg["anchor_prefix"] # type: ignore[assignment] + peer_bundle_ref = resolve_bundle_ref( + from_output_target, + peer_target, + split_output_names, + output_dir_overridden, + ) + peer_link_re = re.compile(PEER_LINK_RE_TEMPLATE.format(peer_dir=re.escape(peer_dir))) + + def repl(match: re.Match[str]) -> str: + chapter = match.group("chapter").lower() + anchor_suffix = match.group("anchor") or "" + if anchor_suffix: + return f"]({peer_bundle_ref}{anchor_suffix})" + return f"]({peer_bundle_ref}#{peer_anchor_prefix}-{chapter})" + + return peer_link_re.sub(repl, text) + + +def rewrite_spec_chapter_links( + text: str, + same_doc_targets: set[str], + from_output_target: str, + split_output_names: dict[str, str], + output_dir_overridden: bool, +) -> str: + """Rewrite ``docs/vpto-spec.md`` chapter links to bundled outputs.""" + + for target, cfg in SPLIT_BUNDLES.items(): + source_dir_name: str = cfg["source_dir_name"] # type: ignore[assignment] + anchor_prefix: str = cfg["anchor_prefix"] # type: ignore[assignment] + link_re = re.compile( + SPEC_LINK_RE_TEMPLATE.format(source_dir_name=re.escape(source_dir_name)) + ) + + if target in same_doc_targets: + + def repl(match: re.Match[str]) -> str: + chapter = match.group("chapter").lower() + anchor_suffix = match.group("anchor") or "" + if anchor_suffix: + return f"]({anchor_suffix})" + return f"](#{anchor_prefix}-{chapter})" + + else: + bundle_ref = resolve_bundle_ref( + from_output_target, + target, + split_output_names, + output_dir_overridden, + ) + + def repl(match: re.Match[str]) -> str: + chapter = match.group("chapter").lower() + anchor_suffix = match.group("anchor") or "" + if anchor_suffix: + return f"]({bundle_ref}{anchor_suffix})" + return f"]({bundle_ref}#{anchor_prefix}-{chapter})" + + text = link_re.sub(repl, text) + return text + + +def build_chapter_blocks( + source_target: str, + heading_levels: int, + from_output_target: str, + split_output_names: dict[str, str], + output_dir_overridden: bool, +) -> tuple[list[Path], list[str]]: + cfg = SPLIT_BUNDLES[source_target] + source_dir: Path = cfg["source_dir"] # type: ignore[assignment] + if not source_dir.is_dir(): + raise SystemExit(f"source directory not found: {source_dir}") + + chapter_files = sorted(source_dir.glob("*.md")) + if not chapter_files: + raise SystemExit(f"no .md files found in {source_dir}") + + anchor_prefix: str = cfg["anchor_prefix"] # type: ignore[assignment] + blocks: list[str] = [] + + for path in chapter_files: + chapter_id = f"{anchor_prefix}-{path.stem.lower()}" + text = path.read_text().strip() + "\n" + text = rewrite_intra_links(text, anchor_prefix) + text = rewrite_cross_bundle_links( + text, + source_target, + from_output_target, + split_output_names, + output_dir_overridden, + ) + text = demote_headings(text, levels=heading_levels) + text = f'\n\n{text}' + blocks.append(trim_trailing_rule(text)) + + return chapter_files, blocks + + +def build_micro_bundle( + version_target: str, + version: str, + version_note: str | None, + from_output_target: str, + split_output_names: dict[str, str], + output_dir_overridden: bool, +) -> str: + source_text = strip_spec_unwanted_lines(SOURCE_SPEC.read_text()) + sections = extract_sections(source_text) + + missing = [name for name in MICRO_REQUIRED_SECTIONS if name not in sections] + if missing: + raise SystemExit(f"missing expected headings in {SOURCE_SPEC}: {missing}") + + rendered_sections: list[str] = [] + for name in MICRO_REQUIRED_SECTIONS[:4]: + text = sections[name] + if name == "## Part III: ISA Instruction Reference": + text = normalize_part_three_heading(text) + text = rewrite_spec_chapter_links( + text, + same_doc_targets={MICRO_TARGET}, + from_output_target=from_output_target, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + rendered_sections.append(trim_trailing_rule(text)) + + chapter_files, chapter_blocks = build_chapter_blocks( + MICRO_TARGET, + heading_levels=2, + from_output_target=from_output_target, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + detailed_section = "\n".join( + [ + "## Detailed ISA Group Reference", + "", + ( + f"This section inlines the {len(chapter_files)} ISA group documents so the " + "architectural overview, notation, summary table, and per-group semantics can " + "be read in a single file." + ), + "", + "\n\n".join(chapter_blocks), + ] + ).strip() + + for name in MICRO_REQUIRED_SECTIONS[4:]: + text = rewrite_spec_chapter_links( + sections[name], + same_doc_targets={MICRO_TARGET}, + from_output_target=from_output_target, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + rendered_sections.append(trim_trailing_rule(text)) + + if PART_IV_HEADING in sections: + part_four = rewrite_spec_chapter_links( + sections[PART_IV_HEADING], + same_doc_targets={MICRO_TARGET}, + from_output_target=from_output_target, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + rendered_sections.append(trim_trailing_rule(part_four)) + + body = rendered_sections[:4] + [detailed_section] + rendered_sections[4:] + parts = [ + MICRO_DOC_TITLE, + "", + render_version_bullets(version_target, version, version_note), + "", + "[toc]", + "", + "---", + "", + "\n\n".join(body), + "", + ] + return "\n".join(parts) + + +def build_tileop_bundle( + version: str, + version_note: str | None, + split_output_names: dict[str, str], + output_dir_overridden: bool, +) -> str: + cfg = SPLIT_BUNDLES[TILEOP_TARGET] + _, blocks = build_chapter_blocks( + TILEOP_TARGET, + heading_levels=1, + from_output_target=TILEOP_TARGET, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + + parts = [ + cfg["title"], + "", + render_version_bullets(TILEOP_TARGET, version, version_note), + "", + "[toc]", + "", + "---", + "", + "\n\n".join(blocks), + "", + ] + return "\n".join(parts) + + +def validate_micro_bundle(text: str) -> None: + if text.count(MICRO_DOC_TITLE) != 1: + raise SystemExit(f"expected exactly one top-level title: {MICRO_DOC_TITLE!r}") + if len(re.findall(r"^# ", text, flags=re.MULTILINE)) != 1: + raise SystemExit("expected exactly one top-level heading in micro bundle") + if "\n[toc]\n" not in text: + raise SystemExit("missing [toc] near top") + if "## Detailed ISA Group Reference" not in text: + raise SystemExit("missing Detailed ISA Group Reference section") + if re.search(r"^> \*\*(Status|Base|Updated):", text, flags=re.MULTILINE): + raise SystemExit("beginning metadata must not remain in micro bundle") + if re.search(r"^## Appendix\b", text, flags=re.MULTILINE): + raise SystemExit("appendix content must not remain in micro bundle") + if "../micro-isa/" in text or "../tile-op/" in text: + raise SystemExit("stale relative ISA directory links remain in micro bundle") + if "isa/micro-isa/" in text or "isa/tile-op/" in text or "docs/isa/" in text: + raise SystemExit("stale source-tree ISA links remain in micro bundle") + + +def validate_tileop_bundle(text: str) -> None: + title: str = SPLIT_BUNDLES[TILEOP_TARGET]["title"] # type: ignore[assignment] + if text.count(title) != 1: + raise SystemExit(f"expected exactly one top-level title: {title!r}") + if len(re.findall(r"^# ", text, flags=re.MULTILINE)) != 1: + raise SystemExit("expected exactly one top-level heading in tileop bundle") + if "\n[toc]\n" not in text: + raise SystemExit("missing [toc] near top") + if "../micro-isa/" in text or "../tile-op/" in text: + raise SystemExit("stale relative ISA directory links remain in tileop bundle") + + +def resolve_split_output_names(args: argparse.Namespace) -> dict[str, str]: + return { + MICRO_TARGET: args.micro_output_name + or SPLIT_BUNDLES[MICRO_TARGET]["output_name"], # type: ignore[index] + TILEOP_TARGET: args.tileop_output_name + or SPLIT_BUNDLES[TILEOP_TARGET]["output_name"], # type: ignore[index] + } + + +def resolve_output_name( + target: str, + version: str, + split_output_names: dict[str, str], + merged_output_name: str | None, +) -> str: + if target == MERGED_TARGET: + if merged_output_name: + return merged_output_name + return MERGED_OUTPUT_NAME.format(version=version) + return split_output_names[target] + + +def resolve_output_dir(target: str, output_dir: Path | None) -> Path: + if output_dir is not None: + return output_dir + if target == MERGED_TARGET: + return RELEASE_DIR + return DOCS_DIR + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--version", required=True, help="Release version, for example 0.4") + parser.add_argument( + "--version-note", + help="Version bullet text for the requested target and version", + ) + parser.add_argument( + "--target", + choices=sorted(SPLIT_BUNDLES.keys()) + [MERGED_TARGET, "all"], + default="all", + help="Which target to generate. 'all' keeps the repo-doc default.", + ) + parser.add_argument( + "--output-dir", + type=Path, + help=( + "Override the output directory. By default, repo-doc bundles go to docs/ " + "and the legacy versioned snapshot goes to docs/release/." + ), + ) + parser.add_argument( + "--micro-output-name", + help="Override the main bundled micro-spec filename.", + ) + parser.add_argument( + "--tileop-output-name", + help="Override the PTO Tile Instruction bundle filename.", + ) + parser.add_argument( + "--merged-output-name", + help="Override the legacy merged-doc output filename, for example vpto-spec.md", + ) + args = parser.parse_args() + + if args.target == "all": + targets = [MICRO_TARGET, TILEOP_TARGET] + else: + targets = [args.target] + + output_dir_overridden = args.output_dir is not None + split_output_names = resolve_split_output_names(args) + + for target in targets: + if target == MICRO_TARGET: + text = build_micro_bundle( + version_target=MICRO_TARGET, + version=args.version, + version_note=args.version_note, + from_output_target=MICRO_TARGET, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + validate_micro_bundle(text) + elif target == TILEOP_TARGET: + text = build_tileop_bundle( + version=args.version, + version_note=args.version_note, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + validate_tileop_bundle(text) + else: + text = build_micro_bundle( + version_target=MERGED_TARGET, + version=args.version, + version_note=args.version_note, + from_output_target=MERGED_TARGET, + split_output_names=split_output_names, + output_dir_overridden=output_dir_overridden, + ) + validate_micro_bundle(text) + + output_dir = resolve_output_dir(target, args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_name = resolve_output_name( + target, + args.version, + split_output_names, + args.merged_output_name, + ) + output = output_dir / output_name + output.write_text(text) + try: + shown = output.relative_to(ROOT) + except ValueError: + shown = output + print(f"wrote {shown}") + + +if __name__ == "__main__": + main() diff --git a/.codex/skills/llvm-test-tool-fallback/SKILL.md b/.codex/skills/llvm-test-tool-fallback/SKILL.md new file mode 100644 index 000000000..e71bf178d --- /dev/null +++ b/.codex/skills/llvm-test-tool-fallback/SKILL.md @@ -0,0 +1,16 @@ +--- +name: llvm-test-tool-fallback +description: When `lit` or `FileCheck` is missing from the current shell, look for the corresponding LLVM test tools in the environment or existing LLVM workspace before treating it as a repo issue. +--- + +# LLVM Test Tool Fallback + +Use this skill when: +- `python3 -m lit` fails because `lit` is missing +- `FileCheck` is not in `PATH` +- a test command fails only because LLVM test tools are not available in the current shell + +Rule: +- do not stop at `command not found` +- first try to find `lit` / `FileCheck` from the environment's LLVM toolchain or an existing LLVM workspace +- treat missing `lit` / `FileCheck` as an environment-tool issue, not as a PTOAS regression diff --git a/.codex/skills/pto-a5-installed-impl-trace/SKILL.md b/.codex/skills/pto-a5-installed-impl-trace/SKILL.md new file mode 100644 index 000000000..43190c015 --- /dev/null +++ b/.codex/skills/pto-a5-installed-impl-trace/SKILL.md @@ -0,0 +1,210 @@ +--- +name: pto-a5-installed-impl-trace +description: Guide LLVM IR discovery for A5 VPTO lowering from the installed CANN/PTO implementation under ASCEND_HOME_PATH. Use when the user does not yet know which `llvm.hivm.*` intrinsic, builtin wrapper, or operand contract a VPTO/A5 op should lower to. +--- + +# PTO A5 Installed Implementation Trace + +Use this skill when the task is specifically about: +- checking what an A5 PTO op really does on the installed machine +- mapping PTO/A5 behavior to builtins or LLVM/HIVM intrinsics +- tracing PTO wrappers down to CCE builtin wrappers such as `__builtin_cce_*` +- deciding whether repo-local lowering is correct or only a guess +- resolving conflicts between generated repo IR and installed PTO headers +- tracing `Cmp`, `Cmps`, predicate, pack, store, or typed vector behavior + +This skill answers: +- what LLVM IR a VPTO op should lower to +- what the authoritative intrinsic name is +- what operand list or mask form the installed toolchain expects +- whether repo-local lowering or emission diverges from installed behavior + +This skill does not answer: +- how to build or link a finished LLVM-path artifact end to end +- how to package `.o`, `fatobj`, or `.so` +- how to run board validation + +## Strong Rule + +If you are about to change repo code for an A5 op, stop and inspect the +installed PTO implementation first. Treat the installed PTO library under +`ASCEND_HOME_PATH` as the semantic source of truth. + +Only make a repo-local substitution after you have confirmed one of: +- the installed PTO headers already express that replacement relationship +- the frontend/compiler intrinsic contract proves two forms are equivalent at + the intrinsic layer + +Do not guess behavior from repo-local lowering, emitter code, or from what +"seems plausible" for an intrinsic sequence. + +Do not start from repo-local lowering when the question is about real A5 +behavior. The installed PTO implementation under `ASCEND_HOME_PATH` is the +first source of truth. + +## Required Search Order + +Always follow this order: + +1. `source /usr/local/Ascend/cann/set_env.sh` +2. confirm `ASCEND_HOME_PATH` +3. inspect installed PTO dispatch headers: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/common/pto_instr_impl.hpp` +4. inspect the matching A5 implementation: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/npu/a5/T*.hpp` +5. inspect typed helpers: + - `$ASCEND_HOME_PATH/aarch64-linux/include/pto/npu/a5/utils.hpp` +6. inspect builtin wrapper headers when the question is about the real compiler-facing builtin: + - `$ASCEND_HOME_PATH/tools/bisheng_compiler/lib/clang/*/include/__clang_cce_vector_intrinsics.h` + - `$ASCEND_HOME_PATH/tools/bisheng_compiler/lib/clang/*/include/npu_arch_*/__clang_cce_vector_intrinsics.h` +7. inspect intrinsic name availability directly from the installed compiler binary before guessing LLVM/HIVM spellings: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.'` + - narrow to the op under investigation, for example: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.(vneg|vrsqrt|vnot|vmov)'` +8. only then compare against repo-local code under `lib/PTO/Transforms/` + +## Practical Fast Path + +For VPTO LLVM emission work, prefer this concrete order instead of jumping +straight to ad hoc compiler probes: + +1. confirm the op exists in installed PTO/A5 headers +2. confirm the builtin wrapper shape in installed Clang headers +3. confirm the intrinsic name family with: + - `strings $ASCEND_HOME_PATH/bin/bisheng | rg 'llvm\\.hivm\\.'` +4. patch repo-local emitter/lowering as little as possible +5. generate real repo-driven LLVM IR through the existing VPTO validation path: + - `source scripts/ptoas_env.sh` + - `WORK_SPACE=/tmp/ CASE_NAME= DEVICE=SIM COMPILE_ONLY=1 test/vpto/scripts/run_host_vpto_validation.sh` +6. inspect: + - `//*.ll` + - `//validation.log` +7. if you only have an AICore `.bc` from `-save-temps`, convert it back to + textual LLVM IR with: + - `source scripts/ptoas_env.sh` + - `bisheng --target=hiipu64-hisilicon-cce -Xclang -cce-bitcode-is-aicore -S -emit-llvm -c .bc -o .ll` + - this is useful for installed PTO / `pto-isa` traces where `*.tmp.bc` + exists but no `.ll` was saved + - do not use bare `bisheng -S -emit-llvm .bc`; on this machine that + falls back to the host target and can crash in the backend +8. only after seeing the real generated `.ll` and Bisheng failure should you + refine the call shape + +This route is preferred because it preserves the real PTOAS lowering context, +the real case structure, and the exact driver invocation used by the repo. + +## Probe Strategy + +Use probes in this order: + +1. installed headers +2. `strings bisheng` +3. repo-generated VPTO LLVM IR from `run_host_vpto_validation.sh` +4. if needed, recover textual LLVM IR from saved AICore bitcode with: + - `bisheng --target=hiipu64-hisilicon-cce -Xclang -cce-bitcode-is-aicore -S -emit-llvm -c .bc -o .ll` +5. only then minimal handwritten `.ll` probes +6. handwritten `.cce` frontend probes are last resort + +Handwritten `.ll` probes are acceptable for quick ABI sanity checks such as: +- whether Bisheng recognizes a specific `llvm.hivm.*` name +- whether a guessed argument count immediately crashes or verifies + +But they are not the primary source of truth for semantic or frontend wrapper +behavior. + +## Avoid These Traps + +Do not default to handwritten `.cce` probes when repo-driven IR is available. +On this machine, bare `.cce` probes often fail before reaching the real +question because they are missing the exact frontend driver mode, target +features, wrapper setup, or host/device compilation context used by the repo. + +In particular, treat these as warning signs that you have started too low in +the stack: +- errors around `[aicore]` +- errors around `__cce_half` +- builtin alias attribute failures +- missing target feature or wrapper environment failures + +When these happen, step back to the repo-driven compile-only flow instead of +trying to repair the ad hoc frontend invocation from scratch. + +## Trace By The Real Type Split + +Do not infer the active implementation from the final storage type alone. +Follow the source element type and the installed dispatch branch. + +Example: +- for `Cmp` with `f32 -> ui8`, inspect the `sizeof(src) == 4` branch, not the + `ui8` destination branch +- for scalar or packed outputs, treat pack/store ops separately from compare + predicate generation + +Typical A5 compare split: +- 32-bit source elements -> `TCmp_32B` / `TCmps_32B` +- 16-bit source elements -> 16-bit branch +- 8-bit source elements -> 8-bit branch + +## What To Extract + +When tracing an op, capture: +- the installed PTO entrypoint that handles it +- the exact typed branch that matches the user case +- the builtins used in order +- any typed helper that explains `pset/plt` or store packing selection +- the compiler builtin wrapper if it is visible in installed Clang headers + +For compare-family questions, separate: +- predicate generation +- compare builtin +- predicate pack/interleave +- predicate store + +Stop at the builtin wrapper layer if the lower compiler implementation is not +available. That is still enough to answer questions such as: +- `pset_b32 -> __builtin_cce_pset_b32` +- `plt_b32 -> __builtin_cce_plt_b32_v300` + +## When The Builtin Name Is Still Not Enough + +If the installed PTO headers tell you the wrapper builtin but that still does +not answer the LLVM/HIVM operand contract, do not guess from repo-local +lowering. Extend the trace using the generated repo testcase first, and only +after that the real compiler frontend: + +1. run an existing repo case with: + - `WORK_SPACE=/tmp/ CASE_NAME= DEVICE=SIM COMPILE_ONLY=1 test/vpto/scripts/run_host_vpto_validation.sh` +2. inspect the generated `.ll` and `validation.log` +3. if the repo-generated LLVM IR still leaves the contract ambiguous, inspect + the testcase build flags from: + - `/build/CMakeFiles/.dir/flags.make` + - `/build/CMakeFiles/.dir/build.make` +4. rerun the same `bisheng` compile with `-v` and `-save-temps` +5. inspect: + - `*.ccei` for the exact installed PTO wrapper call sequence + - `strings *.bc | rg 'llvm.hivm\\.'` to see which HIVM intrinsics survived +6. if needed, recover textual IR from the saved AICore bitcode: + - `bisheng --target=hiipu64-hisilicon-cce -Xclang -cce-bitcode-is-aicore -S -emit-llvm -c .bc -o .ll` +7. if needed, rerun the same frontend compile with `-S`, `-emit-llvm`, or the + equivalent `cc1` invocation from `-v` to inspect the real LLVM IR emitted by + the compiler frontend before instruction selection + +This is the required fallback when the question is really: +- what exact `llvm.hivm.*` intrinsic shape the compiler expects +- whether a hand-written LLVM IR call shape is valid +- whether a selector failure is caused by a guessed mask/value form + +Prefer this real-frontend route over inventing mask constants or argument +shapes from memory. + +## Reporting Back + +When you use this skill, report: +- the exact installed header paths inspected +- whether `strings $ASCEND_HOME_PATH/bin/bisheng` confirmed the intrinsic name +- which typed branch was the authoritative one +- the builtin sequence observed there +- the builtin wrapper name if you found one in the installed Clang headers +- whether repo-generated `.ll` matched the guessed call shape +- whether repo-local lowering matches or diverges +- the first concrete mismatch, if any diff --git a/.codex/skills/pto-gym-vpto-validation/SKILL.md b/.codex/skills/pto-gym-vpto-validation/SKILL.md new file mode 100644 index 000000000..0e1451a61 --- /dev/null +++ b/.codex/skills/pto-gym-vpto-validation/SKILL.md @@ -0,0 +1,85 @@ +--- +name: pto-gym-vpto-validation +description: Run PTO-Gym validation from this PTOAS repo. Use when the user asks to run PTO-Gym SIM or board validation from the current source tree. Always force PTOAS onto the VPTO LLVM path instead of relying on the repo default backend. +--- + +# PTO-Gym VPTO Validation + +Use this skill when the task is specifically about: +- running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh` +- running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh` +- validating PTO-Gym cases from this PTOAS source tree + +## Required Rule + +When PTO-Gym is run from this repo, do not rely on the default PTOAS backend. + +Always pass PTOAS flags that force the VPTO LLVM path. +The current `ptoas` CLI spellings in this repo are `--pto-backend=vpto` and +`--vpto-emit-hivm-llvm`; do not shorten `--pto-backend` to `--backend`. + +Use: + +```bash +PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +``` + +If the caller already provides `PTOAS_FLAGS`, make sure these options are still +present. Do not silently fall back to the repo default backend. + +## Canonical Environment + +Use `.work/` under the repo for all scratch output and temp files: + +```bash +mkdir -p .work/tmp .work/runs +export TMPDIR=$PWD/.work/tmp +export TMP=$TMPDIR +export TEMP=$TMPDIR +``` + +Typical simulator environment: + +```bash +source /home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2/set_env.sh +export ASCEND_HOME_PATH=/home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2 +export PTOAS_BIN=$PWD/build/tools/ptoas/ptoas +export PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +``` + +## Canonical Commands + +Single case: + +```bash +WORK_SPACE=$PWD/.work/runs/pto-gym-single \ +ASCEND_HOME_PATH=$ASCEND_HOME_PATH \ +PTOAS_BIN=$PTOAS_BIN \ +PTOAS_FLAGS="$PTOAS_FLAGS" \ +CASE_NAME=micro-op/binary-vector/vadd \ +DEVICE=SIM \ +bash 3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh +``` + +Parallel micro-op sweep: + +```bash +WORK_SPACE=$PWD/.work/runs/pto-gym-microop \ +ASCEND_HOME_PATH=$ASCEND_HOME_PATH \ +PTOAS_BIN=$PTOAS_BIN \ +PTOAS_FLAGS="$PTOAS_FLAGS" \ +CASE_PREFIX=micro-op \ +DEVICE=SIM \ +JOBS=64 \ +bash 3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh +``` + +## Reporting Back + +Report: +- the exact `PTOAS_FLAGS` used +- the final `PASS/FAIL` counts +- the summary file path under `.work/runs/...` + +If a run fails, identify the first failing case from `parallel-summary.tsv` and +then inspect that case directory under `WORK_SPACE`. diff --git a/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md b/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md new file mode 100644 index 000000000..b50bb36fc --- /dev/null +++ b/.codex/skills/ptoas-bisheng-asm-from-object-cmd/SKILL.md @@ -0,0 +1,37 @@ +--- +name: ptoas-bisheng-asm-from-object-cmd +description: Use when you need assembly for a PTOAS VPTO case that already compiles to a device object. First find the exact command that produced the `.o`, then derive the `.s` command by replacing `-c` with `-S`. Do not guess a fresh Bisheng command line. +metadata: + short-description: Derive `.s` from real `.o` command +--- + +# PTOAS Bisheng ASM From Object Command + +Use this skill when the task is to inspect generated assembly for a VPTO case and the case already has a known `.o` build path. + +## Rule + +- Do not invent a new `bisheng` command. +- First find the exact command that built the `.o`. +- Then derive the `.s` command from that exact command by changing `-c` to `-S`. +- Keep the rest of the arguments unchanged unless the original command already wrote to a conflicting output path. + +## Preferred Sources + +- Validation script logs +- Build scripts such as `test/vpto/scripts/run_host_vpto_validation.sh` +- Saved shell history or generated compile traces in the case workspace + +## Procedure + +1. Locate the real `.o` compile command for the target case. +2. Copy that command exactly. +3. Replace `-c` with `-S`. +4. Point `-o` to a `.s` path. +5. Run the derived command. +6. Inspect the generated assembly instead of guessing from LLVM IR. + +## Anti-Pattern + +- Do not hand-write a new `bisheng -S ...` command from memory. +- Do not drop flags such as `--target`, `-march`, `--cce-aicore-arch`, `--cce-aicore-only`, `-O2`, include paths, or wrapper options that were present in the real `.o` command. diff --git a/.codex/skills/ptoas-build-and-abs/SKILL.md b/.codex/skills/ptoas-build-and-abs/SKILL.md new file mode 100644 index 000000000..bbfd993e3 --- /dev/null +++ b/.codex/skills/ptoas-build-and-abs/SKILL.md @@ -0,0 +1,101 @@ +--- +name: ptoas-build-and-abs +description: Rebuild PTOAS in the repo build directory and compile the Abs sample to inspect generated VPTO output. Use when the user asks to build ptoas, rebuild the current build tree, or run/check the Abs sample output. +--- + +# PTOAS Build And Abs + +Use this skill when the task is specifically about: +- rebuilding `ptoas` in this repo +- doing a full repo build in the repo-local `build/` directory +- compiling `test/samples/Abs/abs.py` +- inspecting the generated VPTO text for `Abs` + +## Canonical Commands + +### 1. Configure the repo-local build directory + +`do_cmake.sh` is the canonical entrypoint. It always targets `./build`. + +```bash +./do_cmake.sh --llvm /data/mouliangyu/projects/github.com/llvm/llvm-project/install +``` + +If `do_cmake.sh` fails because `build/` has a generator mismatch between old Makefiles/Ninja metadata, do not guess. State that `build/` is inconsistent and ask before cleaning the generated build metadata in `build/`. + +### 2. Build + +For just the CLI: + +```bash +CCACHE_DISABLE=1 ninja -C build ptoas +``` + +For a full repo build: + +```bash +CCACHE_DISABLE=1 ninja -C build +``` + +If the user asked for "full build", prefer the full command above. If they only want to run `Abs`, building `ptoas` is usually enough. + +### 3. Prepare runtime environment + +Before running `runop.sh`, always: + +```bash +source env.sh +``` + +This sets `PYTHONPATH`, `LD_LIBRARY_PATH`, and the MLIR/PTO python roots needed by the samples. + +### 4. Compile `Abs` to VPTO text + +Use `runop.sh` with explicit `PTOAS_BIN`, explicit output directory, and A5 backend flags: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-abs-vpto \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-print-ir' \ +./test/samples/runop.sh -t Abs +``` + +Expected outputs: +- `/tmp/ptoas-abs-vpto/Abs/abs-pto-ir.pto` +- `/tmp/ptoas-abs-vpto/Abs/abs-pto.cpp` + +Despite the `.cpp` suffix, on the VPTO backend this file contains the emitted VPTO textual IR. + +## Inspection + +The main file to show the user is: + +```bash +sed -n '1,260p' /tmp/ptoas-abs-vpto/Abs/abs-pto.cpp +``` + +For quick sanity checks, look for: +- `vpto.copy_gm_to_ubuf` +- `src_strides = [32, 1]` +- `trace_offsets = [0, 0]` +- `trace_sizes = [32, 32]` +- `cce_aiv_loop_hint` +- `llvm.loop.aivector_scope` +- `vpto.vlds` +- `vpto.vabs` +- `vpto.vsts` +- `vpto.copy_ubuf_to_gm` + +## Reporting Back + +When you ran `Abs`, report: +- whether `ptoas` had to be rebuilt +- the exact generated file path for the VPTO text +- whether the output contains the expected copy-family metadata and vec-scope carrier attrs + +If the build fails, include the first concrete blocker: +- generator mismatch in `build/` +- link failure in `ptoas` +- missing runtime env because `env.sh` was not sourced +- missing sample output file diff --git a/.codex/skills/ptoas-manual-authoring/SKILL.md b/.codex/skills/ptoas-manual-authoring/SKILL.md new file mode 100644 index 000000000..7e80c3155 --- /dev/null +++ b/.codex/skills/ptoas-manual-authoring/SKILL.md @@ -0,0 +1,109 @@ +--- +name: ptoas-manual-authoring +description: Write or revise PTOAS user-facing manuals and ISA/spec docs. Use when updating docs under `docs/`, especially op manuals, syntax references, semantic descriptions, constraints, examples, or bundled release docs, and when the user asks for clear semantic documentation without exposing lowering, raw ops, registers, intrinsics, or other implementation details. +--- + +# PTOAS Manual Authoring + +Use this skill when writing user-facing PTOAS manuals, ISA chapters, op +references, release specs, and examples. The goal is an accurate current manual +for users of the IR, not an implementation note. + +## Boundary + +User-facing manuals describe the operation contract. Do not expose lowering +details unless the user explicitly asks for an implementation design document. + +Do not mention: + +- raw op layers, bridge ops, or wrapper-to-raw expansion +- legacy or low-level operation names in high-level manuals or in + `docs/vpto-spec.md` high-level summaries, inventories, and navigation tables +- hardware registers, control-bit numbers, packed instruction fields, or + intrinsic names +- pass names, lowering helper names, emitter internals, or source file paths +- historical alternatives, removed syntax, migration notes, or stale design + states that are no longer valid + +Implementation details may live in `docs/designs/` or implementation plans, but +not in the stable user manual. + +Legacy and low-level operation references may remain only in explicitly legacy +release snapshots or dedicated low-level reference material. Do not promote them +into current high-level op manuals, examples, or `docs/vpto-spec.md` indexes. + +## Required Content + +For each op or op family, document these items when applicable: + +- Purpose: what logical operation the op represents. +- Syntax: complete assembly form, including optional clauses and their order. +- Operands and attributes: names, types, address spaces, units, defaults, and + whether values are element counts, bytes, strides, flags, modes, or pointers. +- Legal expressions: all accepted keywords, enum values, flags, and mutually + exclusive forms. +- Constraints: type combinations, address-space requirements, target-profile + availability, alignment, shape/layout restrictions, and cross-operand rules. +- Detailed semantics: exact logical meaning of each operand and clause. +- Semantic pseudocode: reference-style pseudocode for the observable result. +- Hardware execution logic: describe the user-visible execution behavior, data + movement, pipeline ordering, synchronization, layout transformation, numeric + mode, saturation, rounding, accumulation, or broadcasting behavior without + naming underlying instructions or register fields. +- Examples: minimal but meaningful examples using non-trivial values or + realistic shapes. Avoid examples that only prove parsing. + +## Writing Rules + +- Prefer semantic names over hardware-field names. +- State units explicitly. For strides and lengths, say whether they are bytes, + elements, tiles, blocks, or C0 units. +- State defaults explicitly. If omitting a clause inherits surrounding state, + say that; if it means disabled, say disabled. +- State invalid combinations as constraints instead of implying them through + examples. +- Keep the manual canonical. Remove obsolete plans and superseded forms instead + of preserving them for history. +- When updating `docs/vpto-spec.md`, keep high-level summaries, inventories, + and chapter op lists aligned with the current semantic surface. Do not list + legacy implementation ops just because they exist in ODS or old manuals. +- Avoid vague phrases such as "sets parameters", "configures the pipeline", or + "does the conversion" unless followed by the concrete values, organization, + and observable effect. +- If behavior is inferred from simulator or hardware validation, write the + semantic result and note uncertainty only when it still affects users. + +## Pseudocode Guidance + +Pseudocode should model the observable operation, not the lowering sequence. + +Use logical buffers and indices: + +```text +for m in 0 .. M: + for n in 0 .. N: + dst[m, n] = ... +``` + +When layout transforms are involved, describe source indexing, destination +indexing, shape interpretation, stride units, and padding or invalid-lane +behavior. When numeric modes are involved, show when rounding, saturation, +conversion, or exceptional-value handling occurs relative to the main +calculation. + +## Review Checklist + +Before finishing a manual edit: + +- The doc answers "what values can I write" and "what do they mean". +- Every optional clause has legal values, default behavior, and constraints. +- The semantic pseudocode matches the prose. +- User-visible hardware behavior is described without leaking instruction or + register implementation. +- No removed syntax, old方案, TODO design fragments, or dead alternatives remain + in the stable manual. +- No legacy or low-level op names are introduced into high-level manuals or + `docs/vpto-spec.md` high-level summaries/inventories. +- Examples are meaningful and consistent with verifier/lowering behavior. +- Related generated or bundled docs are refreshed when this repo expects them + to be kept in sync. diff --git a/.codex/skills/ptoas-npu-validation-a5/SKILL.md b/.codex/skills/ptoas-npu-validation-a5/SKILL.md new file mode 100644 index 000000000..735cde327 --- /dev/null +++ b/.codex/skills/ptoas-npu-validation-a5/SKILL.md @@ -0,0 +1,335 @@ +--- +name: ptoas-npu-validation-a5 +description: Generate and run PTOAS-based A5 test/npu_validation or test/vpto validations, build the testcase binaries, and validate runtime output on NPU or simulator. Use when the user wants NPU run validation, golden/compare checks, or runtime troubleshooting for A5. +--- + +# PTOAS NPU Validation A5 + +Use this skill when the task is specifically about: +- generating `test/npu_validation` projects from PTOAS output +- running `test/vpto/scripts/run_host_vpto_validation.sh` +- running `test/vpto` board validation or simulator validation +- building testcase binaries for A5 +- running NPU or simulator validation +- generating golden inputs and checking results with `compare.py` +- diagnosing runtime blockers such as missing device access or `aclrtSetDevice` + +This skill is the main entry point for runtime validation. + +Do not use this skill as the primary entry point when the task is only about: +- exporting LLVM IR or LLVM bitcode +- validating the `bisheng` handoff +- assembling a fat object or replacement kernel library from the LLVM path + +When this validation flow needs a custom LLVM IR or LLVM BC artifact, use +`ptoas-vpto-llvm-artifacts` first to build that artifact, then return here to +run the testcase. + +## Important Constraint + +The `npu_validation` flow still depends on an EmitC-generated sample export to +materialize the host-side testcase skeleton. + +For the existing automation, this EmitC export step is not something the user +must run manually first. The provided host-validation scripts already do it for +you. + +Specifically: +- `run_host_npu_validation.sh` automatically invokes `test/samples/runop.sh` + first +- that export is written under `WORK_SPACE/emitc/...` +- `run_host_npu_validation_case.sh` then uses that generated EmitC `*-pto.cpp` + as the input to `generate_testcase.py` + +Even when the final kernel under validation comes from the VPTO/LLVM path, the +current scripts do not generate a standalone host runner from VPTO MLIR or +LLVM IR directly. The canonical automated flow is: + +1. `run_host_npu_validation.sh` automatically exports the sample through the + default EmitC path to get `*-pto.cpp` +2. `run_host_npu_validation_case.sh` runs `generate_testcase.py` on that + generated EmitC kernel to create the testcase directory, host `main.cpp`, + kernel wrapper source, `launch.cpp`, and build system +3. if LLVM/VPTO validation is desired, `run_host_npu_validation_case.sh` + optionally calls `build_llvm_ir_kernel_so.sh` to rebuild and replace only + the final `lib_kernel.so` +4. the generated testcase binary is then run against that replacement kernel + library + +In other words: +- the scripts automatically do the EmitC export step before testcase + generation +- EmitC is still required to produce the host/testcase scaffolding +- LLVM/VPTO replaces the device kernel library, not the host testcase +- feeding raw VPTO textual MLIR directly into `generate_testcase.py` is not a + supported path + +## Automation Entry Points + +Use these scripts as the default automation entry points instead of rebuilding +the flow by hand: + +- `test/vpto/scripts/run_host_vpto_validation.sh` + - top-level driver for curated VPTO `kernel.pto` board/simulator validation + - consumes hand-authored VPTO cases under `test/vpto/cases/...` + - handles lowering, LLVM-path device object build, host build, golden, and compare + - is the default entry point when the user asks to run VPTO board validation directly + - when it fails at runtime, follow this skill's troubleshooting guidance instead of treating the first `aclrtSetDevice` failure as a final product regression + +- `test/npu_validation/scripts/run_host_npu_validation.sh` + - top-level driver for host/NPU validation + - automatically runs `test/samples/runop.sh` first + - automatically writes the EmitC export under `WORK_SPACE/emitc/...` + - discovers testcase names from `test/samples//npu_validation/...` + - dispatches each testcase to `run_host_npu_validation_case.sh` + +- `test/npu_validation/scripts/run_host_npu_validation_case.sh` + - per-testcase execution driver + - consumes the already-generated EmitC kernel from `WORK_SPACE/emitc/...` + - runs `generate_testcase.py` + - configures and builds the testcase + - when `KERNEL_MODE=llvm`, calls `build_llvm_ir_kernel_so.sh` to replace the + device kernel shared library + - runs the testcase binary and then `compare.py` + +- `test/npu_validation/scripts/build_llvm_ir_kernel_so.sh` + - helper used by the case runner for LLVM/VPTO validation + - assumes the EmitC-derived testcase and host wrapper already exist + - rebuilds only the replacement `lib_kernel.so` + - its internal `runop.sh` export may return non-zero because another sample + in the same family failed, but the script intentionally continues if the + requested testcase's LLVM IR artifact was still produced + +## Preconditions + +Before running `npu_validation` or `test/vpto`, make sure: +- `ptoas` is already built in `./build` +- `bisheng` is in `PATH` or available through CANN `set_env.sh` +- `PTO_ISA_ROOT` points to a `pto-isa` checkout with: + - `include/` + - `tests/common/` +- the shell can read `/dev/davinci*` if you intend to execute on real hardware + +Example: + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +``` + +Useful runtime check: + +```bash +source /usr/local/Ascend/cann/set_env.sh +python3 - <<'PY' +import ctypes +lib = ctypes.cdll.LoadLibrary('libascendcl.so') +aclInit = lib.aclInit; aclInit.argtypes=[ctypes.c_char_p]; aclInit.restype=ctypes.c_int +aclrtGetDeviceCount = lib.aclrtGetDeviceCount; aclrtGetDeviceCount.argtypes=[ctypes.c_void_p]; aclrtGetDeviceCount.restype=ctypes.c_int +aclrtSetDevice = lib.aclrtSetDevice; aclrtSetDevice.argtypes=[ctypes.c_int]; aclrtSetDevice.restype=ctypes.c_int +cnt = ctypes.c_uint(0) +print('aclInit', aclInit(None)) +print('aclrtGetDeviceCount', aclrtGetDeviceCount(ctypes.byref(cnt)), cnt.value) +print('aclrtSetDevice', aclrtSetDevice(0)) +PY +``` + +Interpretation: +- `aclInit` succeeds +- `aclrtGetDeviceCount` should report at least one device if the runtime can enumerate hardware +- if `aclrtSetDevice(0)` fails with `507033` (`ACL_ERROR_RT_DEV_SETUP_ERROR`), the user context can see a device but cannot open a usable runtime context + +This interpretation applies equally to: + +- `test/npu_validation` +- `test/vpto` + +When `test/vpto/scripts/run_host_vpto_validation.sh` hits `aclrtSetDevice`, do not immediately report a testcase regression. First treat it as a runtime-environment blocker and follow the checks in this skill. + +## Canonical Flow + +### 1. Generate the PTOAS kernel + +Use the default EmitC-style output, because `npu_validation` consumes `*-pto.cpp`. + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-abs-emitc \ +./test/samples/runop.sh -t Abs +``` + +Expected output: +- `/tmp/ptoas-abs-emitc/Abs/abs-pto.cpp` +- this EmitC kernel is also the required host/testcase input for the later + LLVM/VPTO replacement flow + +### 2. Generate the `npu_validation` testcase + +```bash +python3 test/npu_validation/scripts/generate_testcase.py \ + --input /tmp/ptoas-abs-emitc/Abs/abs-pto.cpp \ + --testcase abs \ + --output-root /tmp/ptoas-npu-validation-run \ + --run-mode sim \ + --soc-version dav_3102 \ + --aicore-arch dav-c310-vec +``` + +Expected output directory: +- `/tmp/ptoas-npu-validation-run/Abs/abs` + +### 3. Configure and build + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cmake -S /tmp/ptoas-npu-validation-run/Abs/abs \ + -B /tmp/ptoas-npu-validation-run/Abs/abs/build \ + -DSOC_VERSION=dav_3102 \ + -DENABLE_SIM_GOLDEN=ON +cmake --build /tmp/ptoas-npu-validation-run/Abs/abs/build --parallel +``` + +Typical build expectations: +- `libabs_kernel.so` builds +- `abs` builds +- `abs_sim` may also build if the simulator runtime is available + +If you need to replace the default `libabs_kernel.so` with one assembled from +an LLVM IR or LLVM BC path, build that artifact with +`ptoas-vpto-llvm-artifacts` and place it first in `LD_LIBRARY_PATH` when +running `./build/abs`. + +Important: +- the LLVM/VPTO path does not bypass EmitC testcase generation +- `build_llvm_ir_kernel_so.sh` assumes the testcase was already generated from + the EmitC export and reuses its host wrapper/build artifacts + +### 4. Generate golden inputs + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 ./golden.py +``` + +Expected files: +- `v1.bin` +- `v2.bin` + +For the generated `Abs` testcase, `golden.py` does not emit `golden_v2.bin`, +but `compare.py` expects it. Build the oracle explicitly from the input: + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 - <<'PY' +import numpy as np +v1 = np.fromfile('v1.bin', dtype=np.float32) +np.abs(v1).astype(np.float32).tofile('golden_v2.bin') +PY +``` + +Expected additional file: +- `golden_v2.bin` + +## Running + +### NPU run + +Only attempt this on a shell that can actually see `/dev/davinci*`. + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cd /tmp/ptoas-npu-validation-run/Abs/abs +LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs +``` + +For the repo's automated host-validation flow, prefer the script's default +remote runner: + +```bash +HOST_RUNNER='ssh root@localhost' +``` + +This is already the default in `run_host_npu_validation.sh` / +`run_host_npu_validation_case.sh`, and it is the preferred way to reach a root +context on the local machine when passwordless root SSH is already configured. + +Use that path first instead of assuming `sudo` is available or passwordless. + +If you are not using the repo scripts and your environment explicitly supports +`sudo`, you may still retry manually with: + +```bash +sudo bash -lc ' + cd /tmp/ptoas-npu-validation-run/Abs/abs + source /usr/local/Ascend/cann/set_env.sh >/dev/null 2>&1 + LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs +' +``` + +Observed runtime result on this machine for the `Abs` testcase: +- normal user run failed at `aclrtSetDevice(0)` with `507033` +- root-context execution is expected to go through the script default + `ssh root@localhost` path when available +- `python3 ./compare.py` then reported `[INFO] compare passed` + +Observed runtime result on this machine for the VPTO LLVM-path host validation +of `PyPTOIRParser/paged_attention_example_kernel_online_update`: +- `test/npu_validation/scripts/run_host_npu_validation.sh` passed end-to-end +- the replacement kernel library from `build_llvm_ir_kernel_so.sh` was loaded + successfully +- `compare.py` reported `[INFO] compare passed` +- during the LLVM artifact export step, `runop.sh` returned non-zero because + `paged_attention_example_kernel_softmax_prepare` failed in the same sample + batch, but the requested `online_update` LLVM IR was still generated and the + validation flow remained valid + +### Simulator run + +If `abs_sim` links successfully, run it with simulator libraries in `LD_LIBRARY_PATH`. + +```bash +export PTO_ISA_ROOT=/path/to/pto-isa +source /usr/local/Ascend/cann/set_env.sh +cd /tmp/ptoas-npu-validation-run/Abs/abs +LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/aarch64-linux/simulator/dav_3510/lib:${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" \ + ./build/abs_sim +``` + +Treat simulator execution as optional. Depending on the local CANN install, the +simulator binary may link successfully but still fail at runtime due to missing +simulator services or runtime symbols. + +## Compare + +After generating `golden_v2.bin` and running the NPU binary, compare with: + +```bash +cd /tmp/ptoas-npu-validation-run/Abs/abs +python3 ./compare.py +``` + +Expected success output: +- `[INFO] compare passed` + +## Known Failure Modes + +- `generate_testcase.py` fails because the input is not a PTOAS EmitC `*-pto.cpp` kernel +- configure fails because `PTO_ISA_ROOT` is unset or points to the wrong checkout +- `abs_sim` fails to link because simulator runtime symbols are missing +- `./build/abs` fails at `aclInit(nullptr)` because the shell does not have usable Ascend runtime access +- non-`sudo` `./build/abs` fails at `aclrtSetDevice(0)` with `507033`, meaning the user context sees the device but cannot open a usable runtime context +- `compare.py` reports `golden_v2.bin` missing because the testcase generation did not create it automatically + +## Reporting Back + +When you use this skill, report: +- the generated testcase directory +- whether `libabs_kernel.so`, `abs`, and `abs_sim` built +- whether `golden.py` generated input bins and whether `golden_v2.bin` had to be created explicitly +- whether NPU execution worked directly or required elevated privileges +- whether `compare.py` passed +- the first concrete blocker for NPU or simulator execution diff --git a/.codex/skills/ptoas-test-framework-guidance/SKILL.md b/.codex/skills/ptoas-test-framework-guidance/SKILL.md new file mode 100644 index 000000000..130f0b880 --- /dev/null +++ b/.codex/skills/ptoas-test-framework-guidance/SKILL.md @@ -0,0 +1,100 @@ +--- +name: ptoas-test-framework-guidance +description: Guidance for adding moving reviewing or validating PTOAS tests across lit VPTO runtime and TileLang ST frameworks. +--- + +# PTOAS Test Framework Guidance + +Use this skill before adding or relocating PTOAS tests. The goal is to put each test in the framework that actually runs it and validates the intended behavior. + +## Framework Selection + +- Use `test/lit` for compiler regression tests: parser/printing, verifier diagnostics, pass output, IR rewrites, CLI behavior, and `FileCheck`-style checks. +- Use `test/lit/pto` for generic PTO/EmitC/DSL-lowering lit tests that do not run the VPTO backend. +- Use `test/lit/vpto` for lit tests whose `RUN:` line uses `--pto-backend=vpto`, `--emit-vpto`, VPTO pass dumps, or VPTO-specific diagnostics. +- Use `test/lit/vpto/cube` for focused cube VPTO lit tests when grouping with existing cube tests is clearer. +- Use `test/vpto` for runtime validation of VPTO-generated fatobj/kernel behavior on SIM or NPU. These tests must compile and run, generate data, and compare results. +- Use `test/tilelang_st` for TileLang DSL system tests that validate DSL rendering plus build/run/compare behavior through the TileLang ST harness. + +## Do Not + +- Do not add new `.pto` tests under `test/basic`; lit does not discover that directory. +- Do not place VPTO backend lit tests under `test/lit/pto`; put them under `test/lit/vpto`. +- Do not add runtime/simulator expectations to lit tests. If correctness depends on executing a kernel and comparing data, use `test/vpto` or `test/tilelang_st`. +- Do not add test path or coverage notes to ISA/spec manuals. `docs/isa` and `docs/vpto-spec.md` should describe op semantics and interfaces, not test locations. +- Do not fix unrelated compiler behavior while only moving tests. If relocation exposes stale tests, report the stale category separately before broad rewrites. + +## `test/lit` Rules + +- Every lit `.pto` file needs at least one `// RUN:` line. +- Prefer small single-purpose checks with `FileCheck`. +- Use `not ptoas ... 2>&1 | FileCheck %s` for negative verifier/parser tests. +- Use `--emit-pto-ir` when checking PTO IR output. +- Use `--pto-backend=vpto --emit-vpto` or `--mlir-print-ir-*` when checking VPTO IR/pass output. +- Keep output checks meaningful; do not reduce a test to only a function-name check if the old test verified richer behavior. + +Validation commands: + +```bash +lit --show-tests build/test/lit +lit -v build/test/lit --filter '' +cmake --build build -j64 --target check-pto +``` + +If `lit` or `FileCheck` is missing, use the `llvm-test-tool-fallback` skill before treating it as a test failure. + +## `test/vpto` Runtime Rules + +Use `test/vpto/cases` when the test must prove generated VPTO code executes correctly. + +A case directory is discovered only when it contains: + +- `kernel.pto` +- `launch.cpp` +- `main.cpp` +- `golden.py` +- `compare.py` + +Current VPTO runtime cases should use the unified fatobj flow emitted by `ptoas`; do not add per-case `stub.cpp` or split `cube.pto`/`kernel.pto` unless the framework explicitly changes. For mixed cube/vector kernels, keep the code in `kernel.pto` using the current module/section form expected by PTOAS. + +Validation commands: + +```bash +WORK_SPACE=/tmp/pto-vpto CASE_NAME='' \ + test/vpto/scripts/run_host_vpto_validation.sh + +WORK_SPACE=/tmp/pto-vpto JOBS=64 \ + test/vpto/scripts/run_host_vpto_validation_parallel.sh +``` + +Required environment normally includes `ASCEND_HOME_PATH`; SIM runs may need `SIM_LIB_DIR` if auto-detection fails. + +## `test/tilelang_st` Rules + +Use `test/tilelang_st` when the behavior starts from TileLang DSL and must be verified through DSL-generated `.pto`, build, run, data generation, and comparison. + +Testcases live under: + +```text +test/tilelang_st/npu//src/st/testcase// +``` + +The batch runner discovers a testcase when `.pto` exists in that testcase directory. Keep testcase data generation and case definitions with the existing ST structure, usually including `cases.py` when multiple parameterized cases are needed. + +Validation commands: + +```bash +python3 test/tilelang_st/script/run_all_st.py --list +python3 test/tilelang_st/script/run_all_st.py -r sim -v a5 -t '' --smoke -j 1 +python3 test/tilelang_st/script/run_all_st.py -r sim -v a5 --smoke -j 64 +``` + +Use the ST harness rather than ad-hoc scripts, so CI and local validation exercise the same path. + +## Finishing Checklist + +- The test is in the framework that matches its assertion: compile/IR, VPTO runtime, or TileLang ST runtime. +- New lit tests are visible in `lit --show-tests build/test/lit`. +- VPTO backend lit tests live under `test/lit/vpto`, not `test/lit/pto`. +- Runtime tests are discoverable by their framework without special-case script logic. +- The smallest relevant validation command was run, and failures are reported by framework and category. diff --git a/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md b/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md new file mode 100644 index 000000000..d24d19c85 --- /dev/null +++ b/.codex/skills/ptoas-vpto-llvm-artifacts/SKILL.md @@ -0,0 +1,319 @@ +--- +name: "ptoas-vpto-llvm-artifacts" +description: "Guide the PTOAS VPTO compile-and-link workflow: inspect VPTO MLIR, export LLVM IR or LLVM bitcode, validate the Bisheng handoff, and assemble device objects, fat objects, or shared kernel libraries. Use when the user asks how to build, export, compile, or link VPTO LLVM-path artifacts for A5." +--- + +# PTOAS VPTO LLVM Artifacts + +Use this skill when the task is specifically about: +- printing or inspecting VPTO intermediate MLIR +- exporting PTOAS A5 kernels as LLVM IR or LLVM bitcode through the VPTO backend +- checking whether the export is textual LLVM IR or real LLVM bitcode +- compiling the exported artifact with `bisheng` +- assembling a device object, fat relocatable object, or shared kernel library from the LLVM path +- helping with an "LLVM IR path build", "LLVM IR path compile", or "VPTO MLIR" request + +This skill answers: +- how to build or export the artifact +- how to hand the artifact to Bisheng +- how to continue from `.ll` / `.bc` to `.o` / `fatobj` / `.so` +- where each stage output is written + +This skill does not answer: +- which `llvm.hivm.*` intrinsic a VPTO op should lower to +- what the authoritative intrinsic name or operand contract is +- whether the repo-local emitter guessed the wrong LLVM IR form + +Those questions belong to `pto-a5-installed-impl-trace`. + +## Strong Rule + +Treat this skill as a compile-and-link workflow guide, not as the authority for +discovering intrinsic mappings. If the task turns into "what should this VPTO +op lower to" or "is this `llvm.hivm.*` form correct", switch to +`pto-a5-installed-impl-trace`. + +This is not the primary entry point for: +- generating `test/npu_validation` testcases +- running on hardware, handling `aclrtSetDevice`, or deciding whether `sudo` is needed +- `golden.py` / `compare.py` result checks +- discovering the authoritative LLVM IR shape for a VPTO op + +If the end goal is runtime validation, use `ptoas-npu-validation-a5` as the main +skill and call this skill only when that flow needs a custom LLVM IR or LLVM BC +kernel artifact. + +## Preconditions + +Before using this path, make sure: +- `ptoas` is already built in `./build` +- `bisheng` is available through CANN `set_env.sh` +- `env.sh` can be sourced from the repo root +- for the fatobj path, you already have a generated testcase directory that + contains a wrapper source such as `abs_kernel.cpp` and a built `launch.cpp.o` + +Load the repo environment before running examples: + +```bash +set +u +source env.sh +set -u +``` + +Use the `set +u` form when the caller shell has `set -u`, because `env.sh` +appends to variables such as `PYTHONPATH` and `LD_LIBRARY_PATH`. + +## Inspect VPTO MLIR + +Use this when you need to look at the VPTO-stage IR before deciding whether to +continue to textual LLVM IR, LLVM bitcode, or the full artifact assembly flow. + +Canonical flag: + +```bash +--vpto-print-ir +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-ir \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-print-ir' \ +./test/samples/runop.sh -t Abs +``` + +Use this output to: +- confirm the lowering has reached the VPTO dialect you expect +- inspect whether a transformation issue appears before LLVM export +- compare the VPTO MLIR path against the later LLVM IR or bitcode output + +## Export Paths + +### LLVM bitcode export + +Use: + +```bash +--pto-backend=vpto --vpto-emit-hivm-bc +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-hivm-bc \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-emit-hivm-bc' \ +./test/samples/runop.sh -t Abs +``` + +Typical outputs: +- `/tmp/ptoas-vpto-hivm-bc/Abs/abs-pto-ir.pto` +- `/tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp` + +Important: +- the payload is written to `*-pto.cpp` even in bitcode mode +- that file is LLVM bitcode, not C++ source + +Bitcode checks: + +```bash +file /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp +xxd -l 16 /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp +"$LLVM_ROOT/bin/llvm-dis" /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp -o - | sed -n '1,80p' +``` + +Expected signs: +- `file` reports `LLVM IR bitcode` +- the header starts with `42 43 c0 de` +- `llvm-dis` shows HiVM/LLVM content + +### Textual LLVM IR export + +Use: + +```bash +--pto-backend=vpto --vpto-emit-hivm-llvm +``` + +Example: + +```bash +source env.sh +PTOAS_BIN="$PWD/build/tools/ptoas/ptoas" \ +PTOAS_OUT_DIR=/tmp/ptoas-vpto-hivm-llvm \ +PTOAS_FLAGS='--pto-arch a5 --pto-backend=vpto --vpto-emit-hivm-llvm' \ +./test/samples/runop.sh -t Abs +``` + +Typical output: +- `/tmp/ptoas-vpto-hivm-llvm/Abs/abs-pto.cpp` + +Important: +- despite the `.cpp` suffix, this file is textual LLVM IR +- compile it with `-x ir` + +Suggested progression: +- start with `--vpto-print-ir` when the user wants the intermediate VPTO form +- use `--vpto-emit-hivm-llvm` when the user wants textual LLVM IR +- use `--vpto-emit-hivm-bc` when the user wants real LLVM bitcode + +## Compile The Export With Bisheng + +Load the CANN environment first: + +```bash +source /usr/local/Ascend/cann/set_env.sh +``` + +### Compile bitcode to a device object + +Preferred: + +```bash +bisheng \ + --target=hiipu64-hisilicon-cce \ + -march=dav-c310-vec \ + --cce-aicore-arch=dav-c310-vec \ + --cce-aicore-only \ + -O2 \ + -c -x ir /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.cpp \ + -o /tmp/ptoas-vpto-hivm-bc/Abs/abs-pto.o +``` + +Alternative: +- copy or rename the payload to `.bc` +- compile without relying on the misleading `.cpp` suffix + +### Compile textual LLVM IR to a device object + +```bash +bisheng \ + --target=hiipu64-hisilicon-cce \ + -march=dav-c310-vec \ + --cce-aicore-arch=dav-c310-vec \ + --cce-aicore-only \ + -O2 \ + -c -x ir /tmp/ptoas-vpto-hivm-llvm/Abs/abs-pto.cpp \ + -o /tmp/abs_ir_path_artifacts/kernel_from_llvm_ir.o +``` + +Checks: +- keep `-march` and `--cce-aicore-arch` aligned with the intended testcase arch +- for the LLVM IR path, the resulting object should not retain unresolved + `llvm.hivm.*` symbols + +## If You Need The Real Compiler-Expected Intrinsic Shape + +This is outside the main purpose of this skill. + +When a hand-written LLVM IR path fails in instruction selection or appears to +miscompile, use this trace order: + +1. confirm the installed PTO wrapper path first with `pto-a5-installed-impl-trace` +2. generate the normal testcase kernel source through the working emitc path +3. inspect testcase compile flags from: + - `/build/CMakeFiles/.dir/flags.make` + - `/build/CMakeFiles/.dir/build.make` +4. rerun that same `bisheng` compile with `-v` and `-save-temps` +5. inspect: + - `*.ccei` to confirm the wrapper builtin sequence + - `strings *.bc | rg 'llvm.hivm\\.'` to see which HIVM intrinsics survive +6. if builtin names still are not enough, extract the exact frontend-produced + LLVM IR by replaying the `cc1` invocation from `-v` with `-emit-llvm -S` + +Use this when you need to answer questions such as: +- is the intrinsic name correct but the mask form wrong +- did the compiler expect a `plt/pset` result instead of a literal mask +- is the LLVM IR path missing hidden frontend-generated structure or attrs + +This is the preferred way to align repo-local LLVM emission with the real +compiler contract. + +## Assemble Fat Objects And Shared Libraries + +Use this only when the validation flow needs a replacement kernel library built +from the LLVM path. The canonical example below uses the generated `Abs` +testcase, but the pattern is the same for other testcases: take the testcase +wrapper source, embed the device object, pack it with `cce-ld`, then link the +shared kernel library. + +Required testcase artifacts: +- a wrapper source such as `/tmp/ptoas-npu-validation-run/Abs/abs/abs_kernel.cpp` +- a built launch object such as + `/tmp/ptoas-npu-validation-run/Abs/abs/build/CMakeFiles/abs_kernel.dir/launch.cpp.o` + +### 1. Build the host stub object + +```bash +/usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/bin/bisheng -cc1 \ + -triple aarch64-unknown-linux-gnu \ + -fcce-is-host \ + -fcce-fatobj-compile \ + -fcce-include-aibinary /tmp/abs_ir_path_artifacts/kernel_from_llvm_ir.o \ + -fcce-device-module-id a55ab1efc0defeed \ + -fcce-aicore-arch dav-c310-vec \ + -x cce /tmp/ptoas-npu-validation-run/Abs/abs/abs_kernel.cpp \ + -o /tmp/abs_ir_path_artifacts/kernel_host_stub.o +``` + +### 2. Pack the fat relocatable object + +```bash +/usr/local/Ascend/cann-9.0.0/bin/cce-ld \ + /usr/local/Ascend/cann-9.0.0/bin/ld.lld \ + -x \ + -cce-lite-bin-module-id a55ab1efc0defeed \ + -cce-aicore-arch=dav-c310-vec \ + -r \ + -o /tmp/abs_ir_path_artifacts/kernel_fat.o \ + -cce-stub-dir /usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/lib/clang/15.0.5/include/cce_stub \ + -cce-install-dir /usr/local/Ascend/cann-9.0.0/tools/bisheng_compiler/bin \ + -cce-inputs-number 1 \ + /tmp/abs_ir_path_artifacts/kernel_host_stub.o +``` + +The module id must match between: +- `-fcce-device-module-id` +- `-cce-lite-bin-module-id` + +### 3. Link the shared kernel library + +```bash +mkdir -p /tmp/abs_ir_path_artifacts/link_try +cd /tmp/abs_ir_path_artifacts/link_try +/usr/local/Ascend/cann-9.0.0/bin/bisheng \ + -fPIC -s -Wl,-z,relro -Wl,-z,now --cce-fatobj-link \ + -shared -Wl,-soname,libabs_kernel.so \ + -o libabs_kernel.so \ + /tmp/abs_ir_path_artifacts/kernel_fat.o \ + /tmp/ptoas-npu-validation-run/Abs/abs/build/CMakeFiles/abs_kernel.dir/launch.cpp.o +``` + +This skill stops at producing the replacement artifact. To run the testcase +with that library and validate outputs, switch back to `ptoas-npu-validation-a5`. + +## Failure Modes + +Report the first concrete blocker: +- `--vpto-print-ir`, `--vpto-emit-hivm-bc`, or `--vpto-emit-hivm-llvm` used without `--pto-backend=vpto` +- `--vpto-emit-hivm-bc` or `--vpto-emit-hivm-llvm` used without `--pto-backend=vpto` +- `env.sh` was not sourced, or failed under `set -u` +- `bisheng` was not found or CANN environment was not loaded +- a bitcode payload was treated as source because it kept a misleading suffix +- the testcase wrapper or `launch.cpp.o` is missing for the fatobj path +- the module ids used for stub creation and `cce-ld` packing do not match + +## Reporting Back + +When you use this skill, report: +- whether the user-facing artifact of interest was VPTO MLIR, textual LLVM IR, or LLVM bitcode +- the exact `ptoas` flags used +- whether the export was VPTO MLIR, LLVM bitcode, or textual LLVM IR +- the exact output path that contains the exported payload +- whether `llvm-dis`, `file`, or direct inspection confirmed the payload type +- whether `bisheng` produced a device object +- whether the flow also produced a fat relocatable object or shared kernel library +- which step was the first blocker, if the full artifact chain did not complete diff --git a/.codex/skills/resolve-dsl-issue/SKILL.md b/.codex/skills/resolve-dsl-issue/SKILL.md new file mode 100644 index 000000000..52752c0db --- /dev/null +++ b/.codex/skills/resolve-dsl-issue/SKILL.md @@ -0,0 +1,263 @@ +--- +name: resolve-dsl-issue +description: 根据用户提供的 issue 链接,提取 DSL 与 PTO IR 复现最小用例,运行 PTOAS 复现并分析日志,在用户指导下完成修复、提交并自动创建关联 issue 的 PR。 +--- + +# Resolve DSL Issue + +当任务满足以下任一条件时使用本 skill: +- 用户明确提供了要处理的 issue 链接 +- 用户希望“按 issue 内容复现 DSL 问题并定位根因” + +不建议作为主入口的场景: +- 仅做编译/构建,不涉及 issue 复现 +- 仅做 NPU 运行验证,不涉及 DSL/PTO IR 复现 + +## 目标 + +从 issue 中抽取可执行复现输入(DSL + PTO IR),在仓库内构造最小复现并定位根因;在用户确认修复方向后完成代码修复、验证、提交,并自动创建关联原始 issue 的 PR。 + +## 前置条件 + +- 当前目录是 PTOAS 仓库根目录 +- `build/` 目录可写 +- `ptoas` 可执行(已在 PATH 或有明确绝对路径) +- 能访问 issue 内容(网页、API、或用户粘贴) +- 若需要自动创建 PR:`gh` CLI 已安装并登录(`gh auth status` 成功) + +## 标准流程 + +### 1. 解析 issue,提取两个代码片段 + +必须提取到两类片段: +- DSL 代码片段(`.py`) +- PTO IR 代码片段(`.pto`)(如果是纯DSL前端问题,可以不需要 PTO IR) + +推荐提取顺序: +1. issue 正文 +2. issue 评论 +3. issue 附件/粘贴内容 + +如果任一片段缺失,停止后续复现,直接在 issue 请求补充(模板见“评论模板”)。 + +### 2. 在仓库中落盘复现文件 + +文件位置固定为: +- DSL: `lib/TileOps/` +- PTO IR: `test/dsl/` + +命名建议使用 issue 编号,避免冲突,例如: +- `lib/TileOps/issue__repro.py` +- `test/dsl/issue__repro.pto` + +要求: +- 原样写入,避免“自动修复”代码导致偏离用户输入 +- 保留 issue 中的关键注释和输入形状信息 + +### 3. 执行编译并保存日志 + +标准命令: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --vpto-emit-hivm-llvm &> +``` + +推荐日志路径: +- `build/issue__repro.log` + +示例: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --vpto-emit-hivm-llvm \ + test/dsl/issue_1234_repro.pto \ + &> build/issue_1234_repro.log +``` + +### 4. 分析日志,判断是否复现 + +先定位关键错误信息(error/fatal/assert/traceback),再判断是否与 issue 描述一致。 + +推荐快速检索: + +```bash +rg -n "error|fatal|assert|traceback|failed" build/issue__repro.log +``` + +分支处理: +- 未复现:在 issue 中请求更完整的复现信息(环境、命令、输入、预期/实际) +- 已复现:进入根因定位 + +### 5. 根因定位与方案建议 + +定位输出应至少包含: +- 触发错误的阶段(前端解析/TileOp 展开/Lowering/LLVM 发射等) +- 直接触发点(具体报错行、pass、或输入约束不满足) +- 根因判断(1-2 条最可能原因,标注置信度) +- 修复建议(最小改动优先) + +如果无法在当前上下文完成修复实现,也需要给出: +- 建议修改文件范围 +- 建议新增/补充的测试用例 + +### 6. 与用户确认修复方向(必须) + +在进入代码修改前,先向用户同步: +- 复现文件路径 +- 复现命令 +- 关键报错摘要 +- 根因与建议 +- 待确认项(如环境差异) + +只有在用户明确同意修复方向后,才进入第 7 步。 + +### 7. 实施修复并本地验证 + +修复要求: +- 仅改动与该 issue 直接相关的最小文件集合 +- 优先补充或更新回归测试(如 `test/dsl` 相关用例) +- 保留复现输入,避免把“复现文件”误删 + +验证要求: +- 至少重新执行一次复现命令,确认错误消失或行为符合预期 +- 将关键验证日志保存到 `build/issue__fix_verify.log` +- 跑一次完整的dsl测试集,确认无其他回归 + +### 8. 提交代码(在用户确认后执行) + +分支命名建议: +- `fix/issue--dsl` + +提交信息建议(至少包含 issue 编号): +- `fix(dsl): <简要修复描述> (#)` + +示例命令: + +```bash +git checkout -b fix/issue_1234_dsl +git add +git commit -m "fix(dsl): handle in (#1234)" +git push -u origin fix/issue_1234_dsl +``` + +### 9. 自动创建 PR 并关联原始 issue + +目标仓库:https://github.com/mouliangyu/PTOAS/ +目标分支:feature-vpto-backend + +关联规则(GitHub): +- 在 PR 描述中包含 `Closes #` 或 `Fixes #` +- 若是跨仓库 issue,使用 `Closes /#` +- 合并后删除分支 + +推荐使用 `gh pr create`: + +```bash +gh pr create \ + --base main \ + --head fix/issue_1234_dsl \ + --title "fix(dsl): <简要修复标题>" \ + --body "$(cat <<'EOF' +## Summary +- <修改点1> +- <修改点2> + +## Repro +- issue: #1234 +- repro cmd: `ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --vpto-emit-hivm-llvm test/dsl/issue_1234_repro.pto` + +## Validation +- <验证命令/结果> + +Closes #1234 +EOF +)" +``` + +创建 PR 后需要回填: +- PR 链接 +- 关联 issue 语句是否生效(是否显示 “linked issues”) + +### 10. 结果同步并等待 review + +向用户同步: +- 修复文件列表 +- 提交 hash +- PR 链接 +- 关联 issue 状态 +- 后续待办(例如 reviewer 关注点) + +## 评论模板 + +### 模板 A:缺少 DSL 或 PTO IR 片段 + +```text +为了准确复现该问题,还需要完整的最小复现输入。请补充以下两段代码: +1) DSL Python 片段(可直接运行到生成该 PTO 的部分) +2) 对应的 PTO IR 片段(完整函数/入口,不要省略关键上下文) + +建议同时提供:执行命令、实际报错、期望行为。 +``` + +### 模板 B:当前未复现 + +```text +我已按当前 issue 信息完成复现尝试,但暂未在本地复现相同报错。 +请补充以下信息以便继续定位: +1) 完整执行命令(含所有 flags) +2) 运行环境(分支/commit、CANN 版本、是否自定义环境变量) +3) 实际报错全文(建议粘贴日志片段) +4) 期望结果与当前结果差异 +``` + +### 模板 C:已复现并给出建议 + +```text +已使用 issue 中输入复现成功,关键报错位于:<阶段/文件/日志行号>。 +初步根因:<根因描述>。 +建议修复:<最小修复方案>。 + +如果你同意该方向,我会继续补充对应测试并提交修复实现供 review。 +``` + +### 模板 D:修复完成,准备提交与开 PR + +```text +修复已完成并通过本地验证。 +计划执行: +1) 提交分支:fix/issue--dsl +2) 创建 PR 并在描述中添加 `Closes #` 自动关联 issue + +请确认是否按该方案提交并创建 PR。 +``` + +### 模板 E:PR 已创建并关联 issue + +```text +PR 已创建: +已在 PR 描述中添加 `Closes #`,原始 issue 已自动关联。 + +本次提交: +- Commit: +- 关键修改: +- 验证结果: +``` + +## 执行注意事项 + +- 不要在未确认复现之前改动用户原始输入语义 +- 优先保留最小复现,不做无关重构 +- 若 issue 信息不完整,先补信息再继续,不要猜测输入 +- 日志分析时优先使用首次错误点,不要只看最后一行报错 +- 未经用户确认,不要直接执行 `git commit`、`git push`、`gh pr create` +- PR 关联语句建议统一放在 PR body 末尾,避免被模板覆盖 +- 若 `gh` 未登录或无权限,输出完整 PR 标题/body 草稿供用户手动创建 + +## 最终输出格式(给用户) + +建议按以下顺序输出: +1. 是否复现成功 +2. 复现文件路径与命令 +3. 日志关键错误(1-3 条) +4. 根因判断 +5. 修复建议与下一步计划 +6.(若完成修复)提交信息、PR 链接、issue 关联状态 diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 3087c69bd..36c7079e4 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -24,17 +24,20 @@ permissions: contents: write env: - LLVM_TAG: llvmorg-19.1.7 + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git + LLVM_REF: feature-vpto LLVM_CACHE_FLAVOR: release-hardening-v1 jobs: build_wheel: - name: Build wheel (Python ${{ matrix.python }}, ${{ matrix.arch }}) + # Wheel publication is paused, but this job still builds and uploads the + # packaged ptoas binary distribution. + name: Build ptoas-bin (${{ matrix.arch }}) runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-latest' || 'ubuntu-24.04-arm' }} strategy: fail-fast: false matrix: - python: ["3.10", "3.11", "3.12"] + python: ["3.11"] arch: ["x86_64", "aarch64"] container: @@ -92,12 +95,22 @@ jobs: echo "PTO_BUILD_DIR=$GITHUB_WORKSPACE/build-release" >> $GITHUB_ENV echo "PTO_INSTALL_DIR=$GITHUB_WORKSPACE/install-release" >> $GITHUB_ENV + - name: Resolve LLVM source SHA + id: llvm-source + run: | + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" + if [ -z "${LLVM_SOURCE_SHA}" ]; then + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 + exit 1 + fi + echo "sha=${LLVM_SOURCE_SHA}" >> "$GITHUB_OUTPUT" + - name: Restore LLVM build cache (exact key) id: cache-llvm uses: actions/cache/restore@v4 with: path: /llvm-workspace/llvm-project/build-release - key: llvm-${{ env.LLVM_TAG }}-manylinux_2_34-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} + key: llvm-${{ steps.llvm-source.outputs.sha }}-manylinux_2_34-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} - name: Prepare LLVM source run: | @@ -105,10 +118,11 @@ jobs: cd $LLVM_SOURCE_DIR if [ ! -d .git ]; then git init - git remote add origin https://github.com/llvm/llvm-project.git + git remote add origin "${LLVM_REPO}" fi - git fetch --depth 1 origin tag ${{ env.LLVM_TAG }} - git checkout ${{ env.LLVM_TAG }} + git remote set-url origin "${LLVM_REPO}" + git fetch --depth 1 origin "${LLVM_REF}" + git checkout --force FETCH_HEAD - name: Warn when default-branch cache is missing for PR/release runs if: steps.cache-llvm.outputs.cache-hit != 'true' && (github.event_name == 'pull_request' || github.event_name == 'release') @@ -116,7 +130,7 @@ jobs: echo "LLVM cache miss while running on PR/release." echo "PR/release runs are cache-consumers only and should reuse cache generated on the default branch." echo "Please run this workflow on the default branch first to populate cache key:" - echo "llvm-${{ env.LLVM_TAG }}-manylinux_2_34-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }}" + echo "llvm-${{ steps.llvm-source.outputs.sha }}-manylinux_2_34-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }}" - name: Build LLVM/MLIR if: steps.cache-llvm.outputs.cache-hit != 'true' @@ -140,7 +154,7 @@ jobs: uses: actions/cache/save@v4 with: path: /llvm-workspace/llvm-project/build-release - key: llvm-${{ env.LLVM_TAG }}-manylinux_2_34-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} + key: llvm-${{ steps.llvm-source.outputs.sha }}-manylinux_2_34-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} - name: Build PTOAS run: | @@ -163,12 +177,14 @@ jobs: ninja -C $PTO_BUILD_DIR install - name: Create Python wheel + if: false run: | export PATH="${PY_PATH}/bin:$PATH" export PTOAS_PYTHON_PACKAGE_VERSION="${PTOAS_VERSION}" bash $PTO_SOURCE_DIR/docker/create_wheel.sh - name: Repair wheel with auditwheel + if: false run: | export PATH="${PY_PATH}/bin:$PATH" export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core @@ -177,6 +193,7 @@ jobs: auditwheel repair --plat manylinux_2_34_${{ matrix.arch }} dist/ptoas*.whl -w wheelhouse - name: Test wheel installation + if: false run: | export PATH="${PY_PATH}/bin:$PATH" export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core @@ -189,32 +206,33 @@ jobs: bash $PTO_SOURCE_DIR/docker/test_ptoas_cli.sh - name: Copy wheel to workspace + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core mkdir -p $GITHUB_WORKSPACE/wheelhouse cp $PY_PACKAGE_DIR/wheelhouse/ptoas*.whl $GITHUB_WORKSPACE/wheelhouse/ - name: Upload wheel artifact + if: false uses: actions/upload-artifact@v4 with: name: ptoas-wheel-py${{ matrix.python }}-${{ matrix.arch }} path: wheelhouse/*.whl - name: Collect ptoas binary and dependencies - if: matrix.python == '3.11' run: | bash $PTO_SOURCE_DIR/docker/collect_ptoas_dist.sh $GITHUB_WORKSPACE/ptoas-dist - name: Upload ptoas binary artifact - if: matrix.python == '3.11' uses: actions/upload-artifact@v4 with: name: ptoas-bin-${{ matrix.arch }} path: ptoas-dist/ upload_release_assets: + # Disabled together with build_wheel because wheel publication is paused. + if: false name: Upload release assets - if: github.event_name == 'release' || github.event_name == 'schedule' needs: build_wheel runs-on: ubuntu-latest diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index 7477aeee9..31e3bcb20 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -23,22 +23,21 @@ permissions: contents: write env: - LLVM_TAG: llvmorg-19.1.7 + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git + LLVM_REF: feature-vpto LLVM_CACHE_FLAVOR: release-v2 jobs: build_wheel: - name: Build wheel (Python ${{ matrix.python }}, ${{ matrix.arch }}) + # Wheel publication is paused, but this job still builds and uploads the + # packaged ptoas binary distribution. + name: Build ptoas-bin (macOS ${{ matrix.arch }}) runs-on: ${{ matrix.arch == 'x86_64' && 'macos-15-intel' || 'macos-26' }} strategy: fail-fast: false matrix: - python: ["3.10", "3.11", "3.12"] + python: ["3.11"] arch: ["x86_64", "aarch64"] - # Keep macOS matrix jobs at 5 to stay within GitHub Actions macOS job limits. - exclude: - - python: "3.10" - arch: "x86_64" steps: - name: Checkout repository @@ -96,12 +95,22 @@ jobs: echo "PTO_BUILD_DIR=$GITHUB_WORKSPACE/build-release" >> $GITHUB_ENV echo "PTO_INSTALL_DIR=$GITHUB_WORKSPACE/install-release" >> $GITHUB_ENV + - name: Resolve LLVM source SHA + id: llvm-source + run: | + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" + if [ -z "${LLVM_SOURCE_SHA}" ]; then + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 + exit 1 + fi + echo "sha=${LLVM_SOURCE_SHA}" >> "$GITHUB_OUTPUT" + - name: Restore LLVM build cache (exact key) id: cache-llvm uses: actions/cache/restore@v4 with: path: ${{ runner.temp }}/llvm-workspace/llvm-project/build-release - key: llvm-${{ env.LLVM_TAG }}-${{ steps.runner-meta.outputs.label }}-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} + key: llvm-${{ steps.llvm-source.outputs.sha }}-${{ steps.runner-meta.outputs.label }}-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} - name: Prepare LLVM source run: | @@ -109,10 +118,11 @@ jobs: cd $LLVM_SOURCE_DIR if [ ! -d .git ]; then git init - git remote add origin https://github.com/llvm/llvm-project.git + git remote add origin "${LLVM_REPO}" fi - git fetch --depth 1 origin tag ${{ env.LLVM_TAG }} - git checkout ${{ env.LLVM_TAG }} + git remote set-url origin "${LLVM_REPO}" + git fetch --depth 1 origin "${LLVM_REF}" + git checkout --force FETCH_HEAD - name: Warn when default-branch cache is missing for PR/release runs if: steps.cache-llvm.outputs.cache-hit != 'true' && (github.event_name == 'pull_request' || github.event_name == 'release') @@ -120,7 +130,7 @@ jobs: echo "LLVM cache miss while running on PR/release." echo "PR/release runs are cache-consumers only and should reuse cache generated on the default branch." echo "Please run this workflow on the default branch first to populate cache key:" - echo "llvm-${{ env.LLVM_TAG }}-${{ steps.runner-meta.outputs.label }}-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }}" + echo "llvm-${{ steps.llvm-source.outputs.sha }}-${{ steps.runner-meta.outputs.label }}-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }}" - name: Build LLVM/MLIR if: steps.cache-llvm.outputs.cache-hit != 'true' @@ -143,7 +153,7 @@ jobs: uses: actions/cache/save@v4 with: path: ${{ runner.temp }}/llvm-workspace/llvm-project/build-release - key: llvm-${{ env.LLVM_TAG }}-${{ steps.runner-meta.outputs.label }}-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} + key: llvm-${{ steps.llvm-source.outputs.sha }}-${{ steps.runner-meta.outputs.label }}-${{ matrix.arch }}-${{ matrix.python }}-${{ env.LLVM_CACHE_FLAVOR }} - name: Build PTOAS run: | @@ -165,6 +175,7 @@ jobs: ninja -C $PTO_BUILD_DIR install - name: Create Python wheel + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core export PTOAS_PYTHON_PACKAGE_VERSION="${PTOAS_VERSION}" @@ -184,6 +195,7 @@ jobs: fi - name: Repair wheel with delocate + if: false run: | set -euo pipefail export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core @@ -258,6 +270,7 @@ jobs: ls -lh wheelhouse - name: Test wheel installation + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core pip install $PY_PACKAGE_DIR/wheelhouse/ptoas*.whl @@ -270,31 +283,30 @@ jobs: bash $PTO_SOURCE_DIR/docker/test_ptoas_cli.sh - name: Copy wheel to workspace + if: false run: | export PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core mkdir -p $GITHUB_WORKSPACE/wheelhouse cp $PY_PACKAGE_DIR/wheelhouse/ptoas*.whl $GITHUB_WORKSPACE/wheelhouse/ - name: Upload wheel artifact + if: false uses: actions/upload-artifact@v4 with: name: ptoas-wheel-macos-py${{ matrix.python }}-${{ matrix.arch }} path: wheelhouse/*.whl - name: Collect ptoas binary and dependencies - if: matrix.python == '3.11' run: | bash "$PTO_SOURCE_DIR/docker/collect_ptoas_dist_mac.sh" "$GITHUB_WORKSPACE/ptoas-dist" - name: Archive ptoas binary artifact - if: matrix.python == '3.11' run: | chmod +x "$GITHUB_WORKSPACE/ptoas-dist/ptoas" "$GITHUB_WORKSPACE/ptoas-dist/bin/ptoas" tar -czf "$GITHUB_WORKSPACE/ptoas-bin-macos-${{ matrix.arch }}.tar.gz" \ -C "$GITHUB_WORKSPACE/ptoas-dist" . - name: Smoke test archived ptoas binary artifact - if: matrix.python == '3.11' run: | TEST_DIR="$RUNNER_TEMP/ptoas-dist-smoke-${{ matrix.arch }}" rm -rf "$TEST_DIR" @@ -312,7 +324,7 @@ jobs: >/dev/null - name: Smoke test wheel imports after collecting artifacts - if: matrix.python == '3.11' + if: false run: | # Test the copied wheel artifact from an isolated env with no build-tree # paths, so post-collect/release regressions (e.g. ARM cs_invalid_page) @@ -348,15 +360,15 @@ jobs: fi - name: Upload ptoas binary artifact - if: matrix.python == '3.11' uses: actions/upload-artifact@v4 with: name: ptoas-bin-macos-${{ matrix.arch }} path: ptoas-bin-macos-${{ matrix.arch }}.tar.gz upload_release_assets: + # Disabled together with build_wheel because wheel publication is paused. + if: false name: Upload release assets - if: github.event_name == 'release' || github.event_name == 'schedule' needs: build_wheel runs-on: ubuntu-latest diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c0ef0625b..c45a80929 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,7 +1,10 @@ name: CI +concurrency: + group: ci-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + on: - push: pull_request: # Nightly remote-board validation (GitHub cron is UTC). # 22:00 CST (UTC+8) == 14:00 UTC. @@ -67,7 +70,7 @@ permissions: jobs: license-header-check: - if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }} + if: ${{ github.event_name == 'pull_request' }} runs-on: ubuntu-22.04 steps: - name: Checkout @@ -93,7 +96,8 @@ jobs: build-and-test: runs-on: ubuntu-22.04 env: - LLVM_COMMIT: cd708029e0b2869e80abe31ddb175f7c35361f90 + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git + LLVM_REF: feature-vpto LLVM_BUILD_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert LLVM_DIR: ${{ github.workspace }}/llvm-project/llvm/build-assert PTO_BUILD_DIR: ${{ github.workspace }}/build-assert @@ -104,8 +108,6 @@ jobs: - name: Checkout uses: actions/checkout@v4 with: - repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }} - ref: ${{ github.event.pull_request.head.sha || github.sha }} fetch-depth: 1 persist-credentials: false @@ -140,6 +142,17 @@ jobs: mkdir -p "${PAYLOAD_DIR}/test/npu_validation/scripts" mkdir -p "${PAYLOAD_DIR}/test/npu_validation/templates" + - name: Resolve LLVM source SHA + shell: bash + run: | + set -euo pipefail + LLVM_SOURCE_SHA="$(git ls-remote "${LLVM_REPO}" "refs/heads/${LLVM_REF}" | awk '{print $1}')" + [[ -n "${LLVM_SOURCE_SHA}" ]] || { + echo "ERROR: failed to resolve ${LLVM_REPO} ${LLVM_REF}" >&2 + exit 1 + } + echo "LLVM_SOURCE_SHA=${LLVM_SOURCE_SHA}" >> "${GITHUB_ENV}" + # 先恢复 LLVM build 缓存 - name: Restore LLVM build cache id: cache-llvm @@ -147,7 +160,7 @@ jobs: with: path: | llvm-project/llvm/build-assert - key: llvm-${{ runner.os }}-${{ env.LLVM_COMMIT }}-${{ env.LLVM_CACHE_FLAVOR }} + key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-${{ env.LLVM_CACHE_FLAVOR }} - name: Prepare LLVM source (no rebuild) run: | @@ -157,11 +170,12 @@ jobs: # cache 只保存 build 目录,这里补一个最小 git repo 供 cmake/ninja 使用 if [ ! -d .git ]; then git init - git remote add origin https://github.com/llvm/llvm-project.git + git remote add origin "${LLVM_REPO}" fi - git fetch --depth 1 origin tag llvmorg-19.1.7 - git checkout "${LLVM_COMMIT}" + git remote set-url origin "${LLVM_REPO}" + git fetch --depth 1 origin "${LLVM_REF}" + git checkout --force FETCH_HEAD - name: Build LLVM/MLIR (only if cache miss) if: steps.cache-llvm.outputs.cache-hit != 'true' @@ -186,7 +200,7 @@ jobs: with: path: | llvm-project/llvm/build-assert - key: llvm-${{ runner.os }}-${{ env.LLVM_COMMIT }}-${{ env.LLVM_CACHE_FLAVOR }} + key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-${{ env.LLVM_CACHE_FLAVOR }} - name: Build PTOAS run: | @@ -271,6 +285,233 @@ jobs: path: ${{ env.PAYLOAD_TGZ }} if-no-files-found: error + vpto-sim-validation: + runs-on: [self-hosted, Linux, X64, label-1] + timeout-minutes: 120 + concurrency: + group: vpto-sim-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + if: >- + ${{ + github.event_name == 'workflow_dispatch' || + github.event_name == 'schedule' || + github.event_name == 'pull_request' + }} + env: + LLVM_REPO: https://github.com/vpto-dev/llvm-project.git + LLVM_REF: feature-vpto + VPTO_SIM_WORKSPACE: ${{ github.workspace }}/.work/vpto-sim-ci + TILELANG_DSL_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ci + TILELANG_DSL_UT_WORKSPACE: ${{ github.workspace }}/.work/tilelang-dsl-ut-ci + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 1 + persist-credentials: false + + - name: Resolve LLVM directories + shell: bash + env: + TOOL_CACHE: ${{ runner.tool_cache }} + run: | + set -euo pipefail + echo "LLVM_ROOT=${TOOL_CACHE}/llvm-project" >> "${GITHUB_ENV}" + echo "LLVM_DIR=${TOOL_CACHE}/llvm-project/llvm/build-shared" >> "${GITHUB_ENV}" + echo "MLIR_PYTHONPATH=${TOOL_CACHE}/llvm-project/llvm/build-shared/tools/mlir/python_packages/mlir_core" >> "${GITHUB_ENV}" + + - name: Ensure runner dependencies + shell: bash + run: | + set -euo pipefail + missing_tools=() + for tool in python3 git cmake ninja make; do + if ! command -v "${tool}" >/dev/null 2>&1; then + missing_tools+=("${tool}") + fi + done + + if [[ "${#missing_tools[@]}" -gt 0 ]]; then + if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y python3 python3-pip git cmake ninja-build make + else + echo "ERROR: missing required tools on self-hosted runner: ${missing_tools[*]}" >&2 + echo "ERROR: automatic installation requires sudo + apt-get" >&2 + exit 1 + fi + fi + + python3 -m pip --version >/dev/null 2>&1 || { + if command -v sudo >/dev/null 2>&1 && command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y python3-pip + else + echo "ERROR: python3-pip is required on self-hosted runner" >&2 + exit 1 + fi + } + + need_pip_install=0 + python3 -c "import numpy" >/dev/null 2>&1 || need_pip_install=1 + python3 -m pybind11 --cmakedir >/dev/null 2>&1 || need_pip_install=1 + python3 -c "import ml_dtypes" >/dev/null 2>&1 || need_pip_install=1 + + if [[ "${need_pip_install}" -eq 1 ]]; then + python3 -m pip install --upgrade pip + python3 -m pip install 'pybind11<3' numpy ml-dtypes + fi + + - name: Clean CI work dirs + shell: bash + run: | + set -euo pipefail + rm -rf "${GITHUB_WORKSPACE}/build" + rm -rf "${VPTO_SIM_WORKSPACE}" + rm -rf "${TILELANG_DSL_WORKSPACE}" + + - name: Prepare LLVM source (no rebuild) + shell: bash + run: | + set -euo pipefail + mkdir -p "${LLVM_ROOT}" + cd "${LLVM_ROOT}" + + if [ ! -d .git ]; then + git init + git remote add origin "${LLVM_REPO}" + fi + + git remote set-url origin "${LLVM_REPO}" + git fetch --depth 1 origin "${LLVM_REF}" + git checkout --force FETCH_HEAD + + - name: Build LLVM/MLIR + shell: bash + run: | + set -euo pipefail + cd "${LLVM_ROOT}" + cmake -G Ninja -S llvm -B llvm/build-shared \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DBUILD_SHARED_LIBS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=python3 \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_TARGETS_TO_BUILD="host" + + ninja -C llvm/build-shared + + - name: Build PTOAS + shell: bash + run: | + set -euo pipefail + export PYBIND11_CMAKE_DIR="$(python3 -m pybind11 --cmakedir)" + cmake -G Ninja -S . -B build \ + -DLLVM_DIR="${LLVM_DIR}/lib/cmake/llvm" \ + -DMLIR_DIR="${LLVM_DIR}/lib/cmake/mlir" \ + -DPython3_EXECUTABLE=python3 \ + -DPython3_FIND_STRATEGY=LOCATION \ + -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ + -DCMAKE_BUILD_TYPE=Release + ninja -C build ptoas + + - name: Resolve simulator environment + shell: bash + run: | + set -euo pipefail + + detect_ascend_home() { + for d in \ + "${ASCEND_HOME_PATH:-}" \ + /usr/local/Ascend/cann \ + /usr/local/Ascend/cann-* \ + /usr/local/Ascend/ascend-toolkit/latest + do + [[ -n "${d}" && -d "${d}" ]] || continue + printf '%s\n' "${d}" + return 0 + done + return 1 + } + + ASCEND_HOME_PATH_DETECTED="$(detect_ascend_home || true)" + if [[ -z "${ASCEND_HOME_PATH_DETECTED}" ]]; then + echo "ERROR: failed to detect ASCEND_HOME_PATH on self-hosted runner" >&2 + exit 1 + fi + + echo "ASCEND_HOME_PATH=${ASCEND_HOME_PATH_DETECTED}" >> "${GITHUB_ENV}" + echo "PTOAS_BIN=${GITHUB_WORKSPACE}/build/tools/ptoas/ptoas" >> "${GITHUB_ENV}" + + - name: Run VPTO SIM validation + if: ${{ true }} + shell: bash + run: | + set -euo pipefail + mkdir -p "${VPTO_SIM_WORKSPACE}" + WORK_SPACE="${VPTO_SIM_WORKSPACE}" \ + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + PTOAS_BIN="${PTOAS_BIN}" \ + DEVICE=SIM \ + JOBS="${JOBS:-32}" \ + bash test/vpto/scripts/run_host_vpto_validation_parallel.sh + + - name: Run TileLang DSL CI + shell: bash + run: | + set -euo pipefail + mkdir -p "${TILELANG_DSL_WORKSPACE}" + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + PTOAS_BIN="${PTOAS_BIN}" \ + bash test/tilelang_st/script/run_ci.sh -r sim -v a5 --jobs 64 --smoke \ + 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/run_ci.log" + else + ASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \ + PTOAS_BIN="${PTOAS_BIN}" \ + bash test/tilelang_st/script/run_ci.sh -r sim -v a5 --jobs 64 \ + 2>&1 | tee "${TILELANG_DSL_WORKSPACE}/run_ci.log" + fi + + - name: Upload TileLang DSL logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: tilelang-dsl-ci-${{ github.run_id }} + path: | + ${{ env.TILELANG_DSL_WORKSPACE }}/run_ci.log + if-no-files-found: warn + + - name: Run TileLang DSL unit tests + shell: bash + run: | + set -euo pipefail + mkdir -p "${TILELANG_DSL_UT_WORKSPACE}" + cd tilelang-dsl + PYTHONPATH=python python3 -m unittest discover -s tests -p 'test_*.py' \ + 2>&1 | tee "${TILELANG_DSL_UT_WORKSPACE}/unittest.log" + + - name: Upload TileLang DSL unit test logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: tilelang-dsl-ut-ci-${{ github.run_id }} + path: | + ${{ env.TILELANG_DSL_UT_WORKSPACE }}/unittest.log + if-no-files-found: warn + + - name: Upload VPTO SIM logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: vpto-sim-validation-${{ github.run_id }} + path: | + ${{ env.VPTO_SIM_WORKSPACE }}/parallel-runner.log + ${{ env.VPTO_SIM_WORKSPACE }}/parallel-summary.tsv + if-no-files-found: warn + remote-npu-validation: needs: build-and-test runs-on: ubuntu-22.04 diff --git a/.gitignore b/.gitignore index 44c61b02a..61f15f6b0 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,11 @@ dist/ /remote_npu_validation_results*.tsv /npu_validation/ test/samples/**/npu_validation/ +!test/samples/**/npu_validation/ +test/samples/**/npu_validation/* +!test/samples/**/npu_validation/golden.py +!test/samples/**/npu_validation/*/ +!test/samples/**/npu_validation/*/golden.py /tmp_gen* # IDE/editor @@ -73,3 +78,7 @@ test/samples/**/npu_validation/ .DS_Store .ipynb_checkpoints/ *.orig + +# Local workspace +.work/ +.local/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..9ae183956 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "3rdparty/PTO-Gym"] + path = 3rdparty/PTO-Gym + url = git@github.com:PTO-ISA/PTO-Gym.git diff --git a/3rdparty/PTO-Gym b/3rdparty/PTO-Gym new file mode 160000 index 000000000..a68542fdb --- /dev/null +++ b/3rdparty/PTO-Gym @@ -0,0 +1 @@ +Subproject commit a68542fdb84149d3c8be6b1be507ace625e04a90 diff --git a/CMakeLists.txt b/CMakeLists.txt index e9dbcf96e..d82cc2c75 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,6 +101,11 @@ include_directories(${PROJECT_BINARY_DIR}/include) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) +# ========================================================= +# 3.1 Testing option setup +# ========================================================= +include(CTest) + # 开启 Python 绑定选项 option(PTO_ENABLE_PYTHON_BINDING "Enable Python bindings" ON) @@ -112,6 +117,7 @@ if(PTO_ENABLE_PYTHON_BINDING) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/python/pto) add_subdirectory(python) + add_subdirectory(tilelang-dsl) endif() # ========================================================= @@ -124,9 +130,27 @@ add_subdirectory(tools) # ========================================================= # 4.1 Tests (ctest) # ========================================================= -include(CTest) if(BUILD_TESTING) enable_testing() + if(PTO_ENABLE_PYTHON_BINDING) + add_test( + NAME tilelang_dsl_import + COMMAND "${Python3_EXECUTABLE}" + "${CMAKE_CURRENT_SOURCE_DIR}/tilelang-dsl/tests/import_tilelang_dsl.py" + ) + set_tests_properties(tilelang_dsl_import PROPERTIES + ENVIRONMENT "PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}" + ) + add_test( + NAME tilelang_dsl_unittest + COMMAND "${Python3_EXECUTABLE}" -m unittest discover + -s "${CMAKE_CURRENT_SOURCE_DIR}/tilelang-dsl/tests" + -p "test_*.py" + ) + set_tests_properties(tilelang_dsl_unittest PROPERTIES + ENVIRONMENT "PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH}" + ) + endif() add_subdirectory(tools/ptobc/tests) endif() diff --git a/docker/Dockerfile b/docker/Dockerfile index df8636d7e..358794e62 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -74,6 +74,7 @@ ENV PY_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core/ # copy pto.py, _pto_ops_gen.py RUN cp $PTO_INSTALL_DIR/mlir/dialects/*.py $PY_PACKAGE_DIR/mlir/dialects/ +RUN rm -rf $PY_PACKAGE_DIR/tilelang_dsl $PY_PACKAGE_DIR/TileOps && cp -R $PTO_INSTALL_DIR/tilelang_dsl $PY_PACKAGE_DIR/tilelang_dsl && cp -R $PTO_INSTALL_DIR/share/ptoas/TileOps $PY_PACKAGE_DIR/TileOps COPY ./setup.py $PY_PACKAGE_DIR/ diff --git a/docker/collect_ptoas_dist.sh b/docker/collect_ptoas_dist.sh index 4cb3c39de..ffd8fb679 100755 --- a/docker/collect_ptoas_dist.sh +++ b/docker/collect_ptoas_dist.sh @@ -22,6 +22,8 @@ # ptoas - Wrapper script that sets up LD_LIBRARY_PATH # bin/ptoas - The actual ptoas binary # lib/*.so* - Required shared library dependencies +# share/ptoas/TileOps - TileLang template library +# tilelang_dsl/ - TileLang DSL Python package set -euo pipefail @@ -45,6 +47,10 @@ export LD_LIBRARY_PATH="${LLVM_BUILD_DIR}/lib:${PTO_INSTALL_DIR}/lib:${LD_LIBRAR PTO_BUILD_DIR="${PTO_BUILD_DIR:-${PTO_SOURCE_DIR}/build}" PTOAS_BIN="${PTO_BUILD_DIR}/tools/ptoas/ptoas" PTOAS_DEPS_DIR="${PTOAS_DIST_DIR}/lib" +PTOAS_TILEOPS_SRC_DIR="${PTO_INSTALL_DIR}/share/ptoas/TileOps" +PTOAS_TILEOPS_DIST_DIR="${PTOAS_DIST_DIR}/share/ptoas/TileOps" +PTOAS_TILELANG_DSL_SRC_DIR="${PTO_INSTALL_DIR}/tilelang_dsl" +PTOAS_TILELANG_DSL_DIST_DIR="${PTOAS_DIST_DIR}/tilelang_dsl" if [ ! -f "$PTOAS_BIN" ]; then echo "Error: ptoas binary not found at $PTOAS_BIN" >&2 @@ -121,7 +127,10 @@ harden_elf() { } # Create output directories -mkdir -p "${PTOAS_DIST_DIR}/bin" "${PTOAS_DEPS_DIR}" +mkdir -p \ + "${PTOAS_DIST_DIR}/bin" \ + "${PTOAS_DEPS_DIR}" \ + "$(dirname "${PTOAS_TILEOPS_DIST_DIR}")" # Copy ptoas binary echo "Copying ptoas binary..." @@ -151,6 +160,19 @@ while read -r packaged; do harden_elf "$packaged" done < <(find "${PTOAS_DIST_DIR}/bin" "${PTOAS_DEPS_DIR}" -type f | sort) +echo "Copying TileLang runtime resources..." +if [ ! -d "${PTOAS_TILEOPS_SRC_DIR}" ]; then + echo "Error: TileOps resource directory not found at ${PTOAS_TILEOPS_SRC_DIR}" >&2 + exit 1 +fi +if [ ! -d "${PTOAS_TILELANG_DSL_SRC_DIR}" ]; then + echo "Error: tilelang_dsl package directory not found at ${PTOAS_TILELANG_DSL_SRC_DIR}" >&2 + exit 1 +fi +rm -rf "${PTOAS_TILEOPS_DIST_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" +cp -R "${PTOAS_TILEOPS_SRC_DIR}" "${PTOAS_TILEOPS_DIST_DIR}" +cp -R "${PTOAS_TILELANG_DSL_SRC_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" + # Create wrapper script echo "Creating wrapper script..." cat > "${PTOAS_DIST_DIR}/ptoas" << 'WRAPPER_EOF' @@ -175,11 +197,16 @@ else echo "$VERSION_OUTPUT" | grep -Eq '^ptoas [0-9]+\.[0-9]+$' fi +test -d "${PTOAS_TILEOPS_DIST_DIR}" +test -f "${PTOAS_TILELANG_DSL_DIST_DIR}/__init__.py" + # Show collected files echo "" echo "=== ptoas distribution contents ===" ls -la "${PTOAS_DIST_DIR}/" ls -la "${PTOAS_DIST_DIR}/bin/" +ls -la "${PTOAS_DIST_DIR}/share/ptoas/" +ls -la "${PTOAS_TILELANG_DSL_DIST_DIR}" SO_COUNT=$(find "${PTOAS_DEPS_DIR}" -name "*.so*" 2>/dev/null | wc -l) echo "=== Collected .so dependencies (${SO_COUNT} files) ===" du -sh "${PTOAS_DEPS_DIR}/" diff --git a/docker/collect_ptoas_dist_mac.sh b/docker/collect_ptoas_dist_mac.sh index b5d5338d6..3eb95f4e3 100644 --- a/docker/collect_ptoas_dist_mac.sh +++ b/docker/collect_ptoas_dist_mac.sh @@ -22,6 +22,8 @@ # ptoas - Wrapper script that sets up DYLD_LIBRARY_PATH # bin/ptoas - The actual ptoas binary # lib/*.dylib - Required shared library dependencies +# share/ptoas/TileOps - TileLang template library +# tilelang_dsl/ - TileLang DSL Python package set -euo pipefail @@ -43,6 +45,10 @@ done PTO_BUILD_DIR="${PTO_BUILD_DIR:-${PTO_SOURCE_DIR}/build}" PTOAS_BIN="${PTO_BUILD_DIR}/tools/ptoas/ptoas" PTOAS_DEPS_DIR="${PTOAS_DIST_DIR}/lib" +PTOAS_TILEOPS_SRC_DIR="${PTO_INSTALL_DIR}/share/ptoas/TileOps" +PTOAS_TILEOPS_DIST_DIR="${PTOAS_DIST_DIR}/share/ptoas/TileOps" +PTOAS_TILELANG_DSL_SRC_DIR="${PTO_INSTALL_DIR}/tilelang_dsl" +PTOAS_TILELANG_DSL_DIST_DIR="${PTOAS_DIST_DIR}/tilelang_dsl" UNRESOLVED_NON_SYSTEM_COUNT=0 if [ ! -f "$PTOAS_BIN" ]; then @@ -50,7 +56,10 @@ if [ ! -f "$PTOAS_BIN" ]; then exit 1 fi -mkdir -p "${PTOAS_DIST_DIR}/bin" "${PTOAS_DEPS_DIR}" +mkdir -p \ + "${PTOAS_DIST_DIR}/bin" \ + "${PTOAS_DEPS_DIR}" \ + "$(dirname "${PTOAS_TILEOPS_DIST_DIR}")" cp -fL "$PTOAS_BIN" "${PTOAS_DIST_DIR}/bin/" chmod +x "${PTOAS_DIST_DIR}/bin/ptoas" @@ -238,6 +247,19 @@ PY echo "Collecting dylib dependencies..." collect_dylibs "${PTOAS_DIST_DIR}/bin/ptoas" +echo "Copying TileLang runtime resources..." +if [[ ! -d "${PTOAS_TILEOPS_SRC_DIR}" ]]; then + echo "Error: TileOps resource directory not found at ${PTOAS_TILEOPS_SRC_DIR}" >&2 + exit 1 +fi +if [[ ! -d "${PTOAS_TILELANG_DSL_SRC_DIR}" ]]; then + echo "Error: tilelang_dsl package directory not found at ${PTOAS_TILELANG_DSL_SRC_DIR}" >&2 + exit 1 +fi +rm -rf "${PTOAS_TILEOPS_DIST_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" +cp -R "${PTOAS_TILEOPS_SRC_DIR}" "${PTOAS_TILEOPS_DIST_DIR}" +cp -R "${PTOAS_TILELANG_DSL_SRC_DIR}" "${PTOAS_TILELANG_DSL_DIST_DIR}" + echo "Rewriting packaged install names..." rewrite_packaged_install_names @@ -330,6 +352,8 @@ if [ -n "${PTOAS_VERSION:-}" ]; then else echo "$VERSION_OUTPUT" | grep -Eq '^ptoas [0-9]+\.[0-9]+$' fi +test -d "${PTOAS_TILEOPS_DIST_DIR}" +test -f "${PTOAS_TILELANG_DSL_DIST_DIR}/__init__.py" env -u DYLD_LIBRARY_PATH -u LD_LIBRARY_PATH \ "${PTOAS_DIST_DIR}/ptoas" \ "${PTO_SOURCE_DIR}/test/lit/pto/kernel_kind_vector_scf_while_emitc.pto" \ @@ -339,6 +363,8 @@ echo "" echo "=== ptoas distribution contents ===" ls -la "${PTOAS_DIST_DIR}/" ls -la "${PTOAS_DIST_DIR}/bin/" +ls -la "${PTOAS_DIST_DIR}/share/ptoas/" +ls -la "${PTOAS_TILELANG_DSL_DIST_DIR}" DYLIB_COUNT=$(find "${PTOAS_DEPS_DIR}" -name "*.dylib" 2>/dev/null | wc -l) echo "=== Collected .dylib dependencies (${DYLIB_COUNT} files) ===" du -sh "${PTOAS_DEPS_DIR}/" diff --git a/docker/create_wheel.sh b/docker/create_wheel.sh index 4762e1abc..2145fb9e7 100755 --- a/docker/create_wheel.sh +++ b/docker/create_wheel.sh @@ -44,6 +44,13 @@ echo "Wheel package version: ${PTOAS_PYTHON_PACKAGE_VERSION}" echo "Copying PTO dialect files..." cp "${PTO_INSTALL_DIR}/mlir/dialects/"*.py "${PY_PACKAGE_DIR}/mlir/dialects/" +# Copy TileLang resources into the wheel staging tree so wheel installs keep +# the template library and Python DSL available. +echo "Copying TileLang resources..." +rm -rf "${PY_PACKAGE_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/TileOps" +cp -R "${PTO_INSTALL_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/tilelang_dsl" +cp -R "${PTO_INSTALL_DIR}/share/ptoas/TileOps" "${PY_PACKAGE_DIR}/TileOps" + # Copy platform-specific setup.py to package directory. # On macOS, use setup_mac.py and rename it to setup.py in the build dir. SETUP_TEMPLATE="${PTO_SOURCE_DIR}/docker/setup.py" diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index e5bb7e615..767cece24 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -99,15 +99,20 @@ number of scalar FP4 elements. Operation support is still opt-in. Defining the type in PTO IR does not by itself imply that any particular operation accepts it. -### 2.2 `!pto.ptr` +### 2.2 `!pto.ptr` -A pointer to global memory. +A typed pointer. `memorySpace` is optional and defaults to `gm`. | Parameter | Type | Description | |-----------|------|-------------| | `elementType` | `element-type(i1/i8/i16/i32/f16/f32/bf16...)` | Element type pointed to | +| `memorySpace` | `gm` or `ub` | Pointer address space alias (`gm` -> global memory, `ub` -> vector/UB memory) | -**Syntax:** `!pto.ptr` +**Syntax:** `!pto.ptr` or `!pto.ptr` + +Pointer conversions are modeled explicitly with [`pto.castptr`](#ptocastptr). +Between two `!pto.ptr` types, casts are only legal when both pointers stay in +the same PTO memory space. --- @@ -448,6 +453,39 @@ result = ptr + offset // offset is in elements, not bytes %ptr_off = pto.addptr %base, %offset : !pto.ptr -> !pto.ptr ``` +##### `pto.castptr` - Explicit Pointer Cast + +**Summary:** Performs an explicit cast between integer addresses and `!pto.ptr`, +or between two `!pto.ptr` types. + +**Semantics:** + +```mlir +%p0 = pto.castptr %addr : i64 -> !pto.ptr +%p1 = pto.castptr %p0 : !pto.ptr -> !pto.ptr +%addr2 = pto.castptr %p1 : !pto.ptr -> i64 +``` + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `input` | `integer` or `!pto.ptr<...>` | Source value to cast | + +**Results:** `integer` or `!pto.ptr<...>` + +**Constraints & Verification:** + +- Integer-to-integer casts are rejected; use normal integer cast ops instead +- Descriptor values such as `!pto.tensor_view<...>` and `!pto.partition_tensor_view<...>` are not legal direct inputs; extract a memref address first +- Pointer-to-pointer casts are only legal when source and destination stay in + the same PTO memory space (`gm` or `ub`) +- The operation is pure (no side effects) + +**Hardware Mapping:** + +- No hardware pipeline (representation conversion only) + ##### `pto.make_tensor_view` - Create Tensor View **Summary:** Constructs a global tensor view from a pointer, declaring the physical base and strides (no allocation, no data movement). @@ -534,6 +572,43 @@ This op is primarily defined on `!pto.tensor_view`. --- +##### `pto.get_tensor_view_stride` - Get Tensor View Dimension Stride + +**Summary:** Returns the logical stride of a given dimension of a tensor view. + +**Semantics:** + +```mlir +stride = get_tensor_view_stride(tv_or_mr, dim_index) +``` + +This op is defined on `!pto.tensor_view`. During internal lowering, the same +query may temporarily appear on the memref form lowered from the tensor view. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `tensor_view` | `!pto.tensor_view<...>` or `memref<...>` | Logical tensor view or its lowered memref form | +| `dim_index` | `index` | Dimension index (0-based) | + +**Results:** `index` — the logical stride of the requested dimension, measured +in elements rather than bytes. + +**Notes:** + +- This op is the IR counterpart of the DSL-side `TensorView.strides` metadata access. +- After lowering to memref, static strides may be folded into constants, while dynamic strides are derived from memref metadata. + +**Basic Example:** + +```mlir +%s0 = pto.get_tensor_view_stride %tv, %c0 : !pto.tensor_view -> index +%s1 = pto.get_tensor_view_stride %tv, %c1 : !pto.tensor_view -> index +``` + +--- + ##### `pto.partition_view` - Partition Tensor View **Summary:** Creates a logical window on a tensor_view using offsets and sizes, producing a `partition_tensor_view`. diff --git a/docs/build_with_installed_llvm.md b/docs/build_with_installed_llvm.md new file mode 100644 index 000000000..1167e8df2 --- /dev/null +++ b/docs/build_with_installed_llvm.md @@ -0,0 +1,153 @@ +# 基于已安装 LLVM 的 PTOAS 构建说明 + +本文档按 [README.md](../README.md) 第 3 章的逻辑整理,适用于: + +- LLVM/MLIR `19.1.7` 已经构建并安装完成。 +- LLVM 安装路径固定为 `/opt/llvm`。 +- `/opt/llvm` 是共享目录,不希望 `ptoas` 的安装步骤写入其中。 + +## 3.0 环境变量配置 + +先按 README 第 3.0 节的思路把变量定好。区别是这里不再使用 LLVM 源码目录和 LLVM build tree,而是直接使用 LLVM install tree。 + +```bash +# ================= 配置区域 (请按实际环境调整) ================= +export WORKSPACE_DIR=$HOME/llvm-workspace + +# LLVM 已安装完成,直接指向 install 根目录 +export LLVM_INSTALL_DIR=/opt/llvm + +# 为兼容仓库内部分脚本 / lit 变量命名,这里额外保留 LLVM_BUILD_DIR +export LLVM_BUILD_DIR=$LLVM_INSTALL_DIR + +# ptoas 源码与安装路径 +export PTO_SOURCE_DIR=$WORKSPACE_DIR/PTOAS +export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install-optllvm +# ============================================================ + +mkdir -p "$WORKSPACE_DIR" +``` + +说明: + +- 这里的 `LLVM_BUILD_DIR` 只是为了兼容仓库内已有变量名,实际指向的是 LLVM install 根目录 `/opt/llvm`。 +- `PTO_INSTALL_DIR` 建议单独放到 PTOAS 自己目录下,避免与共享 LLVM 安装混用。 + +## 3.1 环境准备 + +沿用 README 第 3.1 节即可,重点确认这些依赖已经满足: + +- Linux +- GCC >= 9 或 Clang +- CMake >= 3.20 +- Ninja +- Python 3.8+ +- `pybind11` +- `numpy` + +```bash +pip3 install pybind11 numpy +``` + +## 跳过 3.2 + +README 第 3.2 节是 LLVM/MLIR 的下载和编译步骤。当前场景下 LLVM 已经安装在 `/opt/llvm`,这一节可以直接跳过。 + +已验证: + +```bash +/opt/llvm/bin/llvm-config --version +``` + +输出为: + +```text +19.1.7 +``` + +## 3.3 第二步:构建 ptoas + +这里沿用 README 第 3.3 节的流程,但有两处需要改动: + +1. `LLVM_DIR` 和 `MLIR_DIR` 改为 `/opt/llvm/lib/cmake/...` +2. `MLIR_PYTHON_PACKAGE_DIR` 不再指向共享的 `/opt/llvm/python_packages/mlir_core`,而是指向 `PTO_INSTALL_DIR` + +如果继续沿用 README 里的 `MLIR_PYTHON_PACKAGE_DIR=$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core`,在 `/opt/llvm` 场景下会把 `_pto.cpython-*.so` 安装到共享 LLVM 目录,不适合多人共用。 + +```bash +cd "$PTO_SOURCE_DIR" + +# 1. 获取 pybind11 的 CMake 路径 +export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) + +# 2. 配置 CMake +cmake -G Ninja \ + -S . \ + -B build \ + -DLLVM_DIR=$LLVM_INSTALL_DIR/lib/cmake/llvm \ + -DMLIR_DIR=$LLVM_INSTALL_DIR/lib/cmake/mlir \ + -DPython3_EXECUTABLE=$(which python3) \ + -DPython3_FIND_STRATEGY=LOCATION \ + -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DMLIR_PYTHON_PACKAGE_DIR="$PTO_INSTALL_DIR" \ + -DCMAKE_INSTALL_PREFIX="$PTO_INSTALL_DIR" + +# 3. 编译并安装 +ninja -C build +cmake --install build +``` + +## 构建后关键产物 + +按上面的配置,关键产物位置如下: + +- build 目录: + - `$PTO_SOURCE_DIR/build/tools/ptoas/ptoas` + - `$PTO_SOURCE_DIR/build/tools/ptobc/ptobc` + - `$PTO_SOURCE_DIR/build/python/mlir/_mlir_libs/_pto.cpython-*.so` + - `$PTO_SOURCE_DIR/build/python/mlir/dialects/pto.py` +- install 目录: + - `$PTO_INSTALL_DIR/bin/ptoas` + - `$PTO_INSTALL_DIR/mlir/_mlir_libs/_pto.cpython-*.so` + - `$PTO_INSTALL_DIR/mlir/dialects/pto.py` + - `$PTO_INSTALL_DIR/share/ptoas/oplib/level3` + +## 补充:运行环境 + +### 使用 build 目录中的 `ptoas` + +```bash +export PATH=$PTO_SOURCE_DIR/build/tools/ptoas:$PATH +export PYTHONPATH=$LLVM_INSTALL_DIR/python_packages/mlir_core:$PTO_SOURCE_DIR/build/python:$PYTHONPATH +export LD_LIBRARY_PATH=$LLVM_INSTALL_DIR/lib:$PTO_SOURCE_DIR/build/lib:$LD_LIBRARY_PATH +``` + +### 使用 install 目录中的 `ptoas` + +```bash +export PATH=$PTO_INSTALL_DIR/bin:$PATH +export PYTHONPATH=$LLVM_INSTALL_DIR/python_packages/mlir_core:$PTO_INSTALL_DIR:$PYTHONPATH +export LD_LIBRARY_PATH=$LLVM_INSTALL_DIR/lib:$PTO_INSTALL_DIR/lib:$LD_LIBRARY_PATH +``` + +注意: + +- install 版 `ptoas` 仍然需要从 `/opt/llvm/lib` 加载 LLVM/MLIR 共享库。 +- 如果直接运行 `$PTO_INSTALL_DIR/bin/ptoas` 而没有设置 `LD_LIBRARY_PATH=$LLVM_INSTALL_DIR/lib:...`,会报缺少 `libMLIR*.so`。 + +## 本地验证结果 + +当前仓库已验证通过以下组合: + +- `LLVM_DIR=/opt/llvm/lib/cmake/llvm` +- `MLIR_DIR=/opt/llvm/lib/cmake/mlir` +- `MLIR_PYTHON_PACKAGE_DIR=$PTO_INSTALL_DIR` +- `CMAKE_INSTALL_PREFIX=$PTO_INSTALL_DIR` + +最小验证结果: + +- build 版 `ptoas --version` 输出 `ptoas 0.22` +- build 版 `ptoas` 可成功处理 `test/lit/pto/empty_func.pto` +- install 版 Python 绑定可在 `PYTHONPATH=/opt/llvm/python_packages/mlir_core:$PTO_INSTALL_DIR` 下正常导入 +- 若 install 版 `ptoas` 配合 `LD_LIBRARY_PATH=/opt/llvm/lib:$PTO_INSTALL_DIR/lib`,可正常执行 diff --git a/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md new file mode 100644 index 000000000..da1034887 --- /dev/null +++ b/docs/designs/2026-04-15-tcvt-tilelib-sample-and-work-items.md @@ -0,0 +1,366 @@ +# `pto.tcvt` TileLib 模板库设计与工作项 + +## 1. 目标 + +参考 `pto-isa` 在 A5 上现有的 `TCVT_IMPL` 实现,用 TileLang DSL 在 TileLib 中补齐 `pto.tcvt` 模板库。 + +当前 PTOAS 的 `pto.tcvt` 只显式携带 `rmode`,没有单独暴露 `sat_mode`。因此模板库不能只做一层简单透传,而是要在内部按 `(src_dtype, dst_dtype)` 复现 A5 的默认 `sat_mode` 选择,再为不同类型对走到正确的 VPTO 路径。 + + +## 2. 当前语义 + +### 2.1 PTOAS 侧 + +当前 PTOAS 的 `pto.tcvt` 只有 `rmode` attribute,没有 `sat_mode` attribute。 + +这意味着从 PTOAS 传到 TileLib 的静态信息只有 round mode。默认饱和策略需要模板库自己补齐。 + +### 2.2 A5 `pto-isa` 侧 + +A5 侧已有多组 `TCVT_IMPL` 重载,包括: + +- `TCVT_IMPL(dst, src, mode)` +- `TCVT_IMPL(dst, src, mode, satMode)` +- `TCVT_IMPL(dst, src, tmp, mode)` +- `TCVT_IMPL(dst, src, tmp, mode, satMode)` + +其中 `TCVT_IMPL(dst, src, tmp, mode)` 在 A5 上只是转调无 `tmp` 的版本,`tmp` 本身不参与实现。这里保留 `tmp`,主要是为了和 A2/A3 的接口形态保持兼容。 + +如果只聚焦当前 `pto.tcvt` 真正需要对齐的那条入口,也就是: + +```cpp +TCVT_IMPL(dst, src, mode) +``` + +那么 A5 `pto-isa` 里的主要过程可以概括成下面这条链路: + +1. 先按 `(src_dtype, dst_dtype)` 选默认 `satMode` + 也就是这条入口本身先做一层类型分派,把当前 type pair 映射成默认 + `satMode=ON` 或 `OFF`。 + +2. 再转调到显式 `satMode` 的主实现入口 + +```cpp +TCVT_IMPL(dst, src, mode, satMode) +``` + +3. 在显式 `satMode` 入口里,先根据 `(src_dtype, dst_dtype, satMode)` 计算当前需要设置哪些 CTRL 位 + 这里会调用 `determineSaturationCtrlBits(...)`,然后再调用 + `applySaturationCtrlBits(...)` 把这些 CTRL 位写进去。 + +4. CTRL 位设置完成后,再按 `round_mode` 做一层 switch 分派 + 例如分到 `RoundRType` / `RoundAType` / `RoundFType` / `RoundCType` / + `RoundZType` / `RoundOType`,最后统一调用: + +```cpp +implTCVT(...) +``` + +5. `implTCVT(...)` 内部再按 type pair 落到具体 helper + 例如: + - `cast32to32` + - `cast32to16` + - `cast16to32` + - `cast16to16` + - `cast16to8` + - 以及 `NonSatTorch` 那几条专门 helper + +6. 最后恢复之前改过的 CTRL 位 + 也就是在主实现入口的尾部调用 `restoreSaturationCtrlBits(...)`。 + +把这段代码实现压成一条线来看,就是: + +```text +TCVT_IMPL(dst, src, mode) + -> 按类型对选默认 satMode + -> TCVT_IMPL(dst, src, mode, satMode) + -> determineSaturationCtrlBits(...) + -> applySaturationCtrlBits(...) + -> switch(round_mode) + -> implTCVT(...) + -> cast helper / NonSatTorch helper + -> restoreSaturationCtrlBits(...) +``` + +对 TileLib 来说,真正需要复现的就是这条框架,而不是只把 `rmode` 直接透传给某一个 +`vcvt` 就结束。 + +因此,对当前 A5 来说,`pto.tcvt` 需要对齐的真实语义是: + +1. 外部只显式给 `rmode` +2. 库内部按类型对选择默认 `sat_mode` +3. 再按类型对和 `sat_mode` 进入具体实现路径 + +## 3. A5 实现要点 + +### 3.1 默认 `sat_mode` + +A5 的 round-only `TCVT_IMPL(dst, src, mode)` 对下面这些类型对默认使用 `sat_mode=OFF`: + +| 源类型 | 目标类型 | 默认 `sat_mode` | 说明 | +|---|---|---|---| +| `f16` | `u8` | `OFF` | A5 现有默认行为 | +| `f16` | `i8` | `OFF` | A5 现有默认行为 | +| `f32` | `i16` | `OFF` | A5 现有默认行为 | +| `f16` | `i16` | `OFF` | A5 现有默认行为 | +| `i64` | `i32` | `OFF` | A5 现有默认行为 | +| `i32` | `i16` | `OFF` | A5 现有默认行为 | + +除上表外,其余类型对默认使用 `sat_mode=ON`。 + +这部分规则应直接在 TileLib 模板内部复现,不应依赖 PTOAS 额外传参。 + +### 3.2 A5 `TCVT` 整体支持表 + +按三个实现维度分类: + +- 是否受 `round_mode` 影响 +- 是否受 `sat_mode` 影响 +- 是否需要 `NonSatTorch` 对齐 + +这里根据 `pto-isa/include/pto/npu/a5/TCvt.hpp` 整理,不等于当前 +PTOAS + TileLib 已经全部打通。 + +下面各表最后一列 `TileLib是否支持` 以当前 +`PTOAS/lib/TileOps/tcvt_template.py` 实际实现为准。当前已打通的先标 `已支持`, +其余暂时留空。 + +#### 3.2.1 不受 `round_mode` / `sat_mode` 影响,也不需要 `NonSatTorch` + +这组最适合优先实现,基本都是 expand / unpack 路径。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 备注 | TileLib是否支持 | +|---|---|---|---|---| +| `f16` | `f32` | 1D+2D,`vcvt + part` | type expand | `已支持` | +| `bf16` | `f32` | 1D+2D,`vcvt + part` | type expand | `已支持` | +| `i16` | `f32` / `i32` / `u32` | 1D+2D,expand helper | widening path | `已支持` | +| `i32` | `i64` | 1D+2D,expand helper | | `已支持` | +| `u8` | `f16` / `u16` | 1D only,expand helper | 当前只看到 1D helper | `已支持` | +| `i8` | `f16` / `i16` / `i32` | 1D only,expand helper | 当前只看到 1D helper | `已支持` | +| `fp8_e4m3` / `fp8_e5m2` / `h8` | `f32` | 1D+2D,expand helper | source 8-bit float | | +| `fp4_e1m2x2` / `fp4_e2m1x2` | `bf16` | 1D+2D,专用 unpack helper | 4-bit packed source | | + +#### 3.2.2 受 `round_mode` 影响,不受 `sat_mode` 影响,也不需要 `NonSatTorch` + +这组属于 round-only 路径。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 备注 | TileLib是否支持 | +|---|---|---|---|---| +| `f32` | `f32` | 1D+2D,`vtrc` | 保持 `f32`,做 integer-valued float rounding | `已支持` | +| `f16` | `i32` | 1D+2D,`vcvt + part` | | `已支持` | +| `i16` | `f16` | 1D+2D,`vcvt` | | `已支持` | +| `i32` | `f32` | 1D+2D,`vcvt` | | `已支持` | +| `i64` | `f32` | 1D+2D,`vcvt + part` | | `已支持` | +| `bf16` | `fp4_e1m2x2` / `fp4_e2m1x2` | 1D+2D,专用 packed helper | 不是普通 `vcvt` 套餐,但不吃 `sat_mode` | | + +#### 3.2.3 不受 `round_mode` 影响,受 `sat_mode` 影响,不需要 `NonSatTorch` + +这组主要是整数窄化。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---| +| `i16` | `u8` | 1D+2D,`vcvt + part` | `ON` | | `已支持` | +| `i32` | `i16` | 1D+2D,`vcvt + part` | `OFF` | | `已支持` | +| `i32` | `u16` / `u8` | 1D+2D,`vcvt + part` | `ON` | | `已支持` | +| `u32` | `i16` / `u16` / `u8` | 1D+2D,`vcvt + part` | `ON` | | `已支持` | +| `i64` | `i32` | 1D+2D,`vcvt + part` | `OFF` | | `已支持` | + +#### 3.2.4 同时受 `round_mode` 和 `sat_mode` 影响,但不需要 `NonSatTorch` + +这组是常规 `tcvt` 主干路径。当前先打通的 `f32 -> i32` 就属于这一类。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---| +| `f32` | `f16` / `bf16` | 1D+2D,`vcvt + part` | `ON` | 窄化 float | `已支持` | +| `f32` | `i32` | 1D+2D,`vcvt` | `ON` | 当前已先打通这一类普通路径 | `已支持` | +| `f32` | `i64` | 1D+2D,`vcvt + part` | `ON` | | `已支持` | +| `f32` | `fp8_e4m3` / `fp8_e5m2` | 1D+2D,`vcvt + part` | `ON` | | | +| `f16` | `u8` | 1D+2D,`vcvt + part` | `OFF` | | `已支持` | +| `bf16` | `i32` | 1D+2D,`vcvt + part` | `ON` | | `已支持` | +| `bf16` | `f16` | 1D+2D,`vcvt` | `ON` | helper 内部是 `SAT_ROUND` 顺序 | `已支持` | + +#### 3.2.5 同时受 `round_mode` 和 `sat_mode` 影响,且需要 `NonSatTorch` + +这组后面要单独收口。不能把它们直接等价成普通 `vcvt(..., sat=NOSAT)`。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | `NonSatTorch` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---|---| +| `f32` | `i16` | 1D+2D,`vcvt + part` | `OFF` | 是 | `OFF` 时走 `NonSatTorch` | `已支持` | +| `f16` | `i16` | 1D+2D,`vcvt` | `OFF` | 是 | | `已支持` | +| `f16` | `i8` | 1D+2D,`vcvt + part` | `OFF` | 是 | | `已支持` | + +#### 3.2.6 专用 helper,`round_mode` 受限 + +这组不建议和普通路径一起排第一批。A5 helper 虽然形式上带模板参数,但当前实现实际固定在特定 round 行为上。 + +| 源类型 | 目标类型 | A5 helper 覆盖 | 默认 `effective_sat_mode` | 备注 | TileLib是否支持 | +|---|---|---|---|---|---| +| `f32` | `h8` | 1D+2D,专用 helper | `ON` | helper 实际固定 `ROUND_A` | | +| `f16` | `h8` | 1D+2D,专用 helper | `ON` | helper 实际固定 `ROUND_A` | | + +这里再记三点: + +- `f16 -> fp8_e4m3/e5m2` 当前 A5 `pto-isa` 明确未实现;`f16` 这边只提供了 `h8` 专用 helper。 +- `h8`、`fp4` 这类路径不是普通 `vcvt` 套餐,后面做 TileLib 时不建议和常规 `f32/f16/bf16/int` 主干混在第一批一起做。 +- 这里说“受 / 不受 `round_mode` 影响”指的是该 pair 的 A5 helper 是否真的消费 round 语义,不是说 PTOAS 这层拿不到 `rmode`。 + +### 3.3 `round_mode` 映射表 + +当前 `pto.tcvt` 这条链路里,round mode 至少会经过四层名字: + +1. PTOAS op attr:`#pto` +2. `ExpandTileOp` 传给 TileLang 的上下文字符串:`round_mode` +3. TileLang DSL 前端:`pto.VcvtRoundMode.*` +4. VPTO / A5 lowering:`rnd = "R"` 这一类 token,或 `RoundMode::CAST_*` + +建议文档和实现都按下面这张表统一,不要在不同层写不同别名。 + +| PTOAS `rmode` | `ExpandTileOp` 传值 | DSL 前端 | VPTO token | A5 / EmitC | 语义 | +|---|---|---|---|---|---| +| `NONE` | `RINT` | `pto.VcvtRoundMode.R` | `R` / `ROUND_R` | `RoundMode::CAST_RINT` | round to nearest, ties to even | +| `RINT` | `RINT` | `pto.VcvtRoundMode.R` | `R` / `ROUND_R` | `RoundMode::CAST_RINT` | round to nearest, ties to even | +| `CAST_RINT` | `RINT` | `pto.VcvtRoundMode.R` | `R` / `ROUND_R` | `RoundMode::CAST_RINT` | round to nearest, ties to even | +| `ROUND` | `ROUND` | `pto.VcvtRoundMode.A` | `A` / `ROUND_A` | `RoundMode::CAST_ROUND` | round away from zero | +| `FLOOR` | `FLOOR` | `pto.VcvtRoundMode.F` | `F` / `ROUND_F` | `RoundMode::CAST_FLOOR` | round toward negative infinity | +| `CEIL` | `CEIL` | `pto.VcvtRoundMode.C` | `C` / `ROUND_C` | `RoundMode::CAST_CEIL` | round toward positive infinity | +| `TRUNC` | `TRUNC` | `pto.VcvtRoundMode.Z` | `Z` / `ROUND_Z` | `RoundMode::CAST_TRUNC` | round toward zero | +| `ODD` | `ODD` | `pto.VcvtRoundMode.O` | `O` / `ROUND_O` | `RoundMode::CAST_ODD` | round to odd | + +这里再补三条实现上要注意的点: + +- `ExpandTileOp` 当前应把 `NONE` / `RINT` / `CAST_RINT` 统一归一成 `RINT`,这样模板内部只需要处理一套默认 round-to-nearest 语义。 +- `PTO_IR_manual` 里对 `ROUND` 的描述偏旧,当前实现和 VPTO 规格应按 “away from zero” 理解。 +- `f32 -> f32` 这条 `vtrc` 路径不能直接照抄上表全部 token。当前 VPTO `vtrc` 规格只明确列了 `R/A/F/C/Z`,`ODD` 需要单独看目标语义,不应默认跟 `vcvt` 完全等价。 + +### 3.4 不同类型对的处理路径 + +从模板实现角度看,更重要的不是 A5 内部怎么切 CTRL 位,而是不同类型对最终该走哪条路径。建议按下面这张表组织 TileLib 逻辑: + +| 类型对 | 默认路径 | 备注 | +|---|---|---| +| `f32 -> f32` | `vtrc` | 这是 round-to-int-valued-float,不应走 `vcvt` | +| `f32 -> i16` 且 `sat_mode=OFF` | `NonSatTorch` helper | 需要对齐 A5 现有边界值行为 | +| `f16 -> i16` 且 `sat_mode=OFF` | `NonSatTorch` helper | 需要对齐 A5 现有边界值行为 | +| `f16 -> i8` 且 `sat_mode=OFF` | `NonSatTorch` helper | 需要对齐 A5 现有边界值行为 | +| 其余合法类型对 | `vcvt` | 具体带哪些 attr 取决于 VPTO contract | + +`NonSatTorch` 这三条路径不能简单等价成普通 `vcvt(..., sat=NOSAT)`。A5 这里保留了专门实现,是为了在 `inf`、`nan`、`overflow` 这些边界值上对齐当前行为。 + +### 3.5 `vcvt` 的 attr 约束 + +TileLib 侧即使已经推导出了 `sat_mode`,也不能无条件给 `vcvt` 传 `rnd/sat/part`。这些 attr 是否应该出现,仍然要服从 VPTO `vcvt` 的 verifier 约束。 + +下面列几个模板里一定会碰到的典型路径: + +| 类型对 | `rnd` | `sat` | `part` | 建议路径 | +|---|---|---|---|---| +| `f32 -> i32` | 需要 | 需要 | 不需要 | `vcvt` | +| `i32 -> f32` | 需要 | 不需要 | 不需要 | `vcvt` | +| `f32 -> f16/bf16` | 需要 | 需要 | 需要 | `vcvt` | +| `f16/bf16 -> f32` | 不需要 | 不需要 | 需要 | `vcvt` | +| `f32 -> f32` | 不适用 | 不适用 | 不适用 | `vtrc` | + +因此,模板里最好把“默认 `sat_mode` 推导”和“`vcvt` attr 组织”拆成两层,不要混在一起写。 + +## 4. TileLib 设计建议 + +### 4.1 模板主流程 + +TileLib 中的 `pto.tcvt` 模板建议保持下面这个结构: + +```python +@pto.vkernel(target="a5", op="pto.tcvt") +def template_tcvt(src: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + dst_dtype = dst.element_type + + round_mode = pto.get_op_attr("round_mode", "RINT") + sat_mode = _a5_default_tcvt_sat_mode(src_dtype, dst_dtype) + + if _needs_nonsat_torch(src_dtype, dst_dtype, sat_mode): + return _emit_nonsat_torch_tcvt(src, dst, round_mode) + + return _emit_regular_tcvt(src, dst, round_mode, sat_mode) +``` + +这里建议把逻辑拆成三个内部 helper: + +- `_a5_default_tcvt_sat_mode(src_dtype, dst_dtype)` +- `_needs_nonsat_torch(src_dtype, dst_dtype, sat_mode)` +- `_emit_regular_tcvt(...)` + +这样写更容易和 A5 `pto-isa` 的现有规则对齐,也方便后面做单测。 + +### 4.2 普通路径的分派原则 + +`_emit_regular_tcvt(...)` 里建议只做两件事: + +1. 判断当前类型对应该走 `vtrc` 还是 `vcvt` +2. 如果走 `vcvt`,按 VPTO contract 决定是否附带 `rnd`、`sat`、`part` + +不要直接按 A5 C++ helper 名称去分派 TileLang DSL。TileLib 需要对齐的是最终语义,而不是逐个复刻底层 helper 名。 + +### 4.3 `NonSatTorch` 的定位 + +`NonSatTorch` 在这里应视为模板内部实现细节,不是新的对外接口。 + +可以先完成普通路径,再补 `NonSatTorch`。如果目标是和当前 A5 行为严格对齐,这三条特殊路径需要在第一版就一起补上。 + +## 5. 工作项 + +### 5.1 TileLib 模板库 + +需要补一份 `pto.tcvt` TileLib 模板,实现以下逻辑: + +| 工作项 | 说明 | +|---|---| +| 读取 `round_mode` | 通过 `pto.get_op_attr("round_mode", "RINT")` 获取 | +| 推导默认 `sat_mode` | 严格按 A5 类型对规则实现 | +| 支持 `vtrc` 路径 | 至少覆盖 `f32 -> f32` | +| 支持普通 `vcvt` 路径 | 并满足 VPTO verifier 对 attr 的要求 | +| 支持 `NonSatTorch` 路径 | 至少覆盖 `f32 -> i16`、`f16 -> i16`、`f16 -> i8` 且默认 `OFF` 的场景 | + +### 5.2 DSL / ExpandHelper / `ExpandTileOp` + +除了模板本身,还需要把下面几处配套能力接上: + +| 模块 | 工作项 | +|---|---| +| TileLang DSL | 支持 `pto.get_op_attr("round_mode", ...)` | +| TileLang DSL | 为 `pto.vtrc` 补 round-mode surface,避免 `f32 -> f32` 卡住 | +| ExpandHelper | 传递 `round_mode` 到模板上下文 | +| `ExpandTileOp` | `SpecKey` 纳入 `round_mode`,避免不同 `rmode` 错误复用实例 | + +当前没有必要把 `sat_mode` 加进 `SpecKey`,因为在现有语义下,它完全由 `(src_dtype, dst_dtype)` 决定,而这部分已经包含在操作数 specialization 里。 + +### 5.3 测试 + +建议测试按三类准备: + +| 测试类型 | 关注点 | +|---|---| +| 模板选择与缓存 | 相同类型对、不同 `rmode` 不应复用同一实例 | +| 模板展开 | `round_mode` 能正确进入 `vtrc` / `vcvt` | +| 数值行为 | 默认 `OFF` 类型对、`NonSatTorch` 特殊路径、`f32 -> f32` 路径 | + +最少应覆盖下面这些代表性 case: + +- `f32 -> f32` +- `f32 -> i16` +- `f16 -> i16` +- `f16 -> i8` +- `f32 -> i32` +- `f32 -> f16` +- `i32 -> f32` + +## 6. 结论 + +这项工作的关键不是“把 `rmode` 传给一个 `vcvt`”这么简单,而是把当前 A5 `pto-isa` 在 round-only `TCVT_IMPL` 里隐含的默认 `sat_mode` 规则和类型分派规则一起带到 TileLib。 + +对当前 PTOAS `pto.tcvt` 而言,模板库应复现下面这条主线: + +1. 从 PTOAS 读取 `round_mode` +2. 在模板内部按 `(src_dtype, dst_dtype)` 推导默认 `sat_mode` +3. 按类型对分派到 `vtrc`、普通 `vcvt` 或 `NonSatTorch` helper + +这样实现出来的 TileLib 模板库,才能和 A5 `pto-isa` 现有行为保持一致。 diff --git a/docs/designs/acc-store-structured-interface.md b/docs/designs/acc-store-structured-interface.md new file mode 100644 index 000000000..2b8eccacc --- /dev/null +++ b/docs/designs/acc-store-structured-interface.md @@ -0,0 +1,202 @@ +# `acc_store` 统一接口设计 + +## 1. 目标方案 + +`acc_store` 的目标接口只保留 `target_profile` 可用的 `L0C -> OUT` 结构化语义: + +```mlir +pto.mte_l0c_l1 %src, %dst, %m, %n, %src_stride, %dst_stride, + unit_flag(check_only | check_and_clear)?, + pre_quant(%scalar_or_fb_addr, mode = ...)?, + pre_relu(%alpha_or_fb_addr, mode = ..., clip = %clip_value ?)?, + nz2nd? | nz2dn(%loop0_src_stride)? | nz2nz(%split)?, + loop3(%count, %src_stride, %dst_stride)?, + (sat | nosat)? +``` + +其中: + +- `%src` 必须是 `!pto.ptr<..., l0c>` +- `%dst` 允许是 `!pto.ptr<..., gm>` `!pto.ptr<..., vec>`或 `!pto.ptr<..., l1>` +- 这些扩展字段全部是可选的,未出现就表示不启用 +- `nz2nd / nz2dn / nz2nz` 和 `loop3` 是并列的 layout 相关参数 + +## 2. 字段形态 + +这些是这版结构化接口里各字段的可选/必填关系: + +- `unit_flag(check_only | check_and_clear)?` + - 不写表示 `off` + - `check_only` 对应先检查后不清零 + - `check_and_clear` 对应先检查后清零 + - `2'b01` 是 ISA 保留值,不进入合法 keyword +- `pre_quant(..., mode = ...)` + - `mode` 必填 + - `mode` 可选值:`no_convert`、`f32_f16`、`qf322hif8_pre_vec`、`qf322hif8_pre_scalar`、`qf322hif8_pre_hybrid_vec`、`qf322hif8_pre_hybrid_scalar`、`deqs32_int_vec`、`deqs32_int_scalar`、`req8_vec`、`req8_scalar`、`deqf16_vec`、`deqf16_scalar`、`qf322fp8_pre_vec`、`qf322fp8_pre_scalar`、`qf322f32_pre_vec`、`qf322f32_pre_scalar`、`f32_bf16`、`qf162b8_pre_vec`、`qf162b8_pre_scalar`、`qf162s4_pre_vec`、`qf162s4_pre_scalar`、`req4_vec`、`req4_scalar`、`qf322b8_pre_vec`、`qf322b8_pre_scalar`、`qf322s4_pre_vec`、`qf322s4_pre_scalar`、`deqs16_vec`、`deqs16_scalar`、`qf162s16_pre_vec`、`qf162s16_pre_scalar`、`qf322f16_pre_vec`、`qf322f16_pre_scalar`、`qf322bf16_pre_vec`、`qf322bf16_pre_scalar`、`qs322bf16_pre_vec`、`qs322bf16_pre_scalar` + - `%scalar_or_fb_addr` 由 `mode` 决定解释方式 + - scalar 类模式下,`%scalar_or_fb_addr` 是量化参数值,允许直接传 `f16`、`bf16`、`f32` + - `f16`/`bf16` scalar payload 会先扩成 `f32`,再按 `SPR.QUANT_PRE` 需要的 32-bit 浮点 bit pattern 编码 + - `f32` scalar payload 直接按 32-bit 浮点 bit pattern 编码到 `SPR.QUANT_PRE` + - vector 类模式下,`%scalar_or_fb_addr` 是 FB1 地址,映射到 `SPR.FPC[15:8] / Quant_PRE_ADDR` + - `mode` 还必须与 `acc_store*` 的源/目的元素类型匹配;例如 `f32 -> f16` 应选 `qf322f16_pre_vec/scalar`,`req8_vec/scalar` 只适用于 `i32 -> i8/u8` + - 无额外可选子参数 +- `pre_relu(%alpha_or_fb_addr, mode = ..., clip = %clip_value)?` + - `mode` 必填 + - `mode` 可选值:`no_relu`、`normal_relu`、`scalar_relu`、`vector_relu`、`pwl` + - payload 不是对所有 mode 都必填: + - `mode = no_relu` 或 `mode = normal_relu` 时,不带 payload + - `mode = scalar_relu` 时,必须带 `%alpha_or_fb_addr`,允许直接传 `f16`、`bf16`、`f32` + - `f16`/`bf16` scalar alpha 会先扩成 `f32`,再按 `SPR.RELU_ALPHA` 需要的 32-bit 浮点 bit pattern 编码 + - `f32` scalar alpha 直接按 32-bit 浮点 bit pattern 编码到 `SPR.RELU_ALPHA` + - `mode = vector_relu` 时,必须带 `%alpha_or_fb_addr`,其值作为 FB1 地址,映射到 `SPR.FPC[7:0] / RELU_PRE_ADDR` + - `clip = %clip_value` 为可选子句:表示启用 pre-stage clip,并把 `%clip_value` 映射到 `SPR.FIX_CLIP_RELU` + - `clip` 只允许用于手册明确覆盖的目标类型:`f16`、`ui8`、`s4/s8/s16` +- `nz2nd?` + - 无额外参数 +- `nz2dn(%loop0_src_stride)?` + - `loop0_src_stride` 必填 +- `nz2nz(%split)?` + - `split` 可选 + - 不写 `split` 表示不做 F32 channel split +- `loop3(%count, %src_stride, %dst_stride)?` + - 这三个参数都必填 + - 无额外可选子参数 +- `(sat | nosat)?` + - 可选 flag + - 不写表示不显式配置饱和控制,沿用进入 op 前的状态 + - 写 `sat` 表示本 op 内选择饱和行为 + - 写 `nosat` 表示本 op 内选择非饱和行为 + - `sat` 和 `nosat` 互斥 +- `atomic(type = ..., op = ...)?` + - 仅 `acc_store_gm` 支持 + - `type` 必填,可选值:`f32`、`f16`、`s16`、`s32`、`s8`、`bf16` + - `op` 必填,可选值:`add`、`max`、`min` + - 不写 `atomic(...)` 表示普通覆盖写回,不启用 OUT atomic read-modify-write + +## 3. 约束 + +`target_profile` 下,这版接口保留的有效项是: + +- `pre_quant(...)` +- `pre_relu(..., clip = %clip_value)?` +- `unit_flag(...)` +- `nz2nd` +- `nz2dn(%loop0_src_stride)` +- `nz2nz(%split)?` +- `loop3(...)` +- `sat` / `nosat` +- `atomic(...)`(仅 `acc_store_gm`) + +其中: + +- `pre_quant` 的 scalar 类模式走 `SPR.QUANT_PRE` +- `pre_quant` 的 vector 类模式走 `SPR.FPC[15:8] / Quant_PRE_ADDR`,对应 FB1 mem_block0 +- `pre_relu(%alpha_or_fb_addr, mode = scalar_relu)` 走 `SPR.RELU_ALPHA[31:13]`,不走 FB1 地址 +- `pre_relu(%alpha_or_fb_addr, mode = vector_relu)` 走 FB1 mem_block1,并通过 `SPR.FPC[7:0]` 选择 `RELU_PRE_ADDR` +- `pre_relu(..., clip = %clip_value)` 的 `clip` 子句走 `SPR.FIX_CLIP_RELU` +- `unit_flag` 不走 FB1 +- `split` 不走 FB1 +- `sat` / `nosat` 走 `SPR.CTRL` +- `atomic` 仅在 `acc_store_gm` 上走 `SPR.CTRL` +- `post-stage`、`element-wise`、`LoopEnhance` 相关扩展不纳入本版接口 + +注意:`NZ2DN` 和 `unit_flag` 不是无条件兼容的。`loop0_src_stride != 1` 时,`unit_flag` 必须关闭。 + +`target_profile` 下不是禁止 `NZ2ND / NZ2DN` 的参数。相反,`FIX_L0C_TO_OUT.f32/s32` 明确标了 `NZ2ND Mode` 和 `NZ2DN Mode` valid;其中 `nz2dn(%loop0_src_stride)` 仍然需要把 `loop0_src_stride` 写入 `CHANNEL_PARA[63:48]`,单位是 `C0_SIZE`。 + +`nz2nz(%split)` 只允许用于 `f32` 输出。`SPLIT_EN = 1` 且输出类型不是 `f32` 时是非法配置。 + +`loop3(...)` 不是 `nz2dn` 或 `nz2nd` 的别名,它是单独的参数组,只在 `nz2nd` 或 `nz2dn` 场景下使用。 + +## 4. 映射 + +- `pre_quant(%scalar_or_fb_addr, mode = ...)` 映射到 `SPR.QUANT_PRE` 或 `SPR.FPC[15:8] / Quant_PRE_ADDR` +- `pre_relu(%alpha_or_fb_addr, mode = ...)` 映射到 `X_t[41:39] / ReLU_PRE`,并按模式进一步映射到 `SPR.RELU_ALPHA[31:13]` 或 `SPR.FPC[7:0] / RELU_PRE_ADDR` +- `pre_relu(..., clip = %clip_value)?` 映射到 `X_t[31:30] / Clip_ReLU_PRE`(使能)以及 `SPR.FIX_CLIP_RELU[15:0]` +- `unit_flag(check_only | check_and_clear)?` 映射到 `X_t[33:32] / unit_flag` +- `nz2nz(%split)?` 映射到 `X_t[42] / SPLIT_EN` +- `nz2dn(%loop0_src_stride)` 映射到 `CHANNEL_PARA[63:48]` +- `loop3(...)` 映射到 `SPR.LOOP3_PARA` +- `sat` / `nosat` 映射到 `SPR.CTRL[48] / ctrl_sat_ctrl` +- `atomic(type = ..., op = ...)?` 仅对 `acc_store_gm` 有效,映射到 `SPR.CTRL[8:6] / ctrl_atomic_en` 和 `SPR.CTRL[10:9] / ctrl_atomic_op` + +## 5. Keyword + +当前结构化接口使用语义 keyword,不直接暴露 bit 编码: + +- `pre_relu.mode` + - `no_relu` -> `3'b000` + - `normal_relu` -> `3'b001` + - `scalar_relu` -> `3'b010` + - `vector_relu` -> `3'b011` + - `pwl` -> `3'b100` +- `pre_quant.mode` + - `no_convert` -> `6'b000000` + - `f32_f16` -> `6'b000001` + - `qf322hif8_pre_vec` -> `6'b000010` + - `qf322hif8_pre_scalar` -> `6'b000011` + - `qf322hif8_pre_hybrid_vec` -> `6'b000100` + - `qf322hif8_pre_hybrid_scalar` -> `6'b000101` + - `deqs32_int_vec` -> `6'b000110` + - `deqs32_int_scalar` -> `6'b000111` + - `req8_vec` -> `6'b001000` + - `req8_scalar` -> `6'b001001` + - `deqf16_vec` -> `6'b001010` + - `deqf16_scalar` -> `6'b001011` + - `qf322fp8_pre_vec` -> `6'b001100` + - `qf322fp8_pre_scalar` -> `6'b001101` + - `qf322f32_pre_vec` -> `6'b001110` + - `qf322f32_pre_scalar` -> `6'b001111` + - `f32_bf16` -> `6'b010000` + - `qf162b8_pre_vec` -> `6'b010001` + - `qf162b8_pre_scalar` -> `6'b010010` + - `qf162s4_pre_vec` -> `6'b010011` + - `qf162s4_pre_scalar` -> `6'b010100` + - `req4_vec` -> `6'b010101` + - `req4_scalar` -> `6'b010110` + - `qf322b8_pre_vec` -> `6'b010111` + - `qf322b8_pre_scalar` -> `6'b011000` + - `qf322s4_pre_vec` -> `6'b011001` + - `qf322s4_pre_scalar` -> `6'b011010` + - `deqs16_vec` -> `6'b011011` + - `deqs16_scalar` -> `6'b011100` + - `qf162s16_pre_vec` -> `6'b011101` + - `qf162s16_pre_scalar` -> `6'b011110` + - `qf322f16_pre_vec` -> `6'b011111` + - `qf322f16_pre_scalar` -> `6'b100000` + - `qf322bf16_pre_vec` -> `6'b100001` + - `qf322bf16_pre_scalar` -> `6'b100010` + - `qs322bf16_pre_vec` -> `6'b100011` + - `qs322bf16_pre_scalar` -> `6'b100100` +- `pre_quant.scalar` + - the specific scalar payload is mode-dependent and lives in `SPR.QUANT_PRE` +- `pre_quant.fb_addr` + - the specific parameter array address is mode-dependent and lives in `SPR.FPC[15:8]` +- `pre_relu.clip`(是否出现 `clip = %clip_value` 子句) + - 未出现 -> `2'b00` + - 出现 -> `2'b01` +- `unit_flag` + - absent -> `2'b00` + - `check_only` -> `2'b10` + - `check_and_clear` -> `2'b11` +- `atomic.type` + - `f32` -> `3'b001` + - `f16` -> `3'b010` + - `s16` -> `3'b011` + - `s32` -> `3'b100` + - `s8` -> `3'b101` + - `bf16` -> `3'b110` +- `atomic.op` + - `add` -> `2'b00` + - `max` -> `2'b01` + - `min` -> `2'b10` +- `sat` / `nosat` + - absent -> no explicit `CTRL[48]` override + - `sat` -> `CTRL[48] = 1'b0` + - `nosat` -> `CTRL[48] = 1'b1` + +## 6. 说明 + +这份文档只描述目标方案,不保留旧扁平接口的过渡写法,也不展开 `profile1` 的后处理、element-wise 和 LoopEnhance 字段列表。 + +这里不再引入 `fixpipe(...)` 大包;这些项直接作为 `acc_store` 的结构化语义字段出现,避免把 source access、layout transform 和 writeback control 都误解成同一个固定 pipeline stage。 diff --git a/docs/designs/cube-load-nd2nz-dn2nz-interface.md b/docs/designs/cube-load-nd2nz-dn2nz-interface.md new file mode 100644 index 000000000..488b1b3ec --- /dev/null +++ b/docs/designs/cube-load-nd2nz-dn2nz-interface.md @@ -0,0 +1,572 @@ +# `cube_load_nd2nz` / `cube_load_dn2nz` 接口统一性整理 + +## 1. 目标 + +本文只做一件事:基于 `pto-isa` 的真实 A5 用法,整理 `cube_load_nd2nz` 和 `cube_load_dn2nz` 对应底层接口的参数语义与典型使用场景,评估两者是否可以收敛到同一套上层接口模型。 + +本文不讨论 release 文档写法,也不讨论 LLVM emitter 细节,只关注: + +- `pto-isa` 里底层 intrinsic 是怎么被调用的 +- 每个场景下每个参数实际表达什么 +- 哪些参数天然共通 +- 哪些差异需要保留为 mode 区分 + +## 2. 底层接口长什么样 + +在 A5 `pto-isa` 中,这两条路径最终都走 `TLoadCubeInstr`,再分发到底层 intrinsic: + +- `ND` 路径: `copy_gm_to_cbuf_multi_nd2nz` +- `DN` 路径: `copy_gm_to_cbuf_multi_dn2nz` + +参考: + +- [`include/pto/npu/a5/TLoad.hpp:235`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L235) + +两者在 A5 上的调用形态基本一致: + +```cpp +copy_gm_to_cbuf_multi_*d2nz(dst, src, + 0 /*sid*/, + loop1SrcStride, + 0 /*l2_cache_ctrl*/, + nValue, + dValue, + loop4SrcStride, + false /*smallc0_en*/); +``` + +这里有两个重要事实: + +1. `sid` 在 `pto-isa` 的这些场景里固定为 `0` +2. 目标侧的 NZ 落点结构,并不是通过 intrinsic 参数直接完整表达,而是预先通过 `set_mte2_nz_para(...)` 编程 + +也就是说,真实语义来自两部分: + +- intrinsic 实参: 源侧遍历方式 + 一部分搬运形状 +- `MTE2_NZ_PARA`: 目标侧 NZ 存放结构 + +## 3. 统一参数视角 + +虽然底层名字分成 `nd2nz` 和 `dn2nz`,但从 `pto-isa` 的真实使用看,它们可以先抽象成同一组语义参数,并收敛到统一的上层接口 `cube_load_frac`: + +### 3.1 intrinsic 侧参数 + +| 统一名称 | A5 底层字段 | 语义 | +|---|---|---| +| `src` | `src` | GM 源指针 | +| `dst` | `dst` | CBUF/L1 目标指针 | +| `l2_cache_ctrl` | `l2_cache_ctrl` | L2 cache control 配置位 | +| `src_inner_stride` | `loop1SrcStride` | 源侧最内层重复单元之间的跨度,单位 byte | +| `n_value` | `nValue` | 一次连续搬运的内层长度 | +| `d_value` | `dValue` | 被打包进 NZ/C0 结构的那一维大小,常见是 `C`、`K` 或 `validRow/validCol` 中的一维 | +| `src_outer_stride` | `loop4SrcStride` | 源侧更外一层重复单元之间的跨度,单位 byte;无外层时通常为 `0` | +| `smallc0_en` | `smallc0_mode` | small C0 mode 开关;仅在 `D <= 4` 时可开启 | + +### 3.2 `MTE2_NZ_PARA` 侧参数 + +`pto-isa` 中目标侧结构通过 `set_mte2_nz_para(...)` 传入: + +```text +MTE2_NZ_PARA[63:48] = loop4DstStride +MTE2_NZ_PARA[47:32] = loop3DstStride +MTE2_NZ_PARA[31:16] = loop2DstStride +MTE2_NZ_PARA[15:0] = groupCount +``` + +这里的 `loop2/3/4` 目标 stride 单位都不是 byte,而是 `C0_size`。 + +在 A5 上,`C0_size` 是硬件固定的 32B 地址单位。 +因此: + +- `dst_loop*_stride = 1` 表示目标地址前进 `32B` +- `dst_loop*_stride = 4` 表示目标地址前进 `128B` + +需要注意,`C0_size` 固定为 32B,但一个 `C0` 中包含多少个元素,取决于元素类型大小: + +| 元素类型 | 每个 `C0` 可容纳的元素数 | +|---|---| +| `i8` / `u8` | `32` | +| `f16` / `i16` | `16` | +| `f32` / `i32` | `8` | + +这里最后 16bit 在不同 mode 下叫法不同: + +- `nd2nz` 场景里通常叫 `ndNum` +- `dn2nz` 场景里通常叫 `dnNum` + +但从统一建模的角度,它们本质上都可以看成: + +- `group_count`: 目标侧 NZ 排布中,由硬件一次处理的外层组数 + +因此目标侧可以统一抽象成: + +| 统一名称 | A5 底层字段 | 语义 | +|---|---|---| +| `group_count` | `MTE2_NZ_PARA[15:0]` | 最内层之上的目标分组数;在不同场景里具体映射为 `ndNum` 或 `dnNum` | +| `dst_loop2_stride` | `MTE2_NZ_PARA[31:16]` | 目标 NZ 结构的 loop2 步长 | +| `dst_loop3_stride` | `MTE2_NZ_PARA[47:32]` | 目标 NZ 结构的 loop3 步长 | +| `dst_loop4_stride` | `MTE2_NZ_PARA[63:48]` | 目标 NZ 结构的 loop4 步长 | + +## 4. `nd2nz` 的真实使用场景 + +### 4.1 场景 A: `MX_A_ND -> ZZ` + +对应 `pto-isa`: + +- `TLoadMxCubeADN2ZZ` +- [`include/pto/npu/a5/TLoad.hpp:723`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L723) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `n_value` | `validCol >> 1` | 每次搬运的列向长度 | +| `d_value` | `validRow` | 每次搬运的行向长度 | +| `src_inner_stride` | `GetByteSize(dtype, gStride4) * sizeof(uint16_t)` | 源内层相邻片段跨度 | +| `src_outer_stride` | `0` | 该场景无更外一层源重复 | +| `group_count` | `1` | 单组 | +| `dst_loop2_stride` | `1` | 固定 | +| `dst_loop3_stride` | `TileData::Cols >> 1` | 目标列方向 NZ 布局步长 | +| `dst_loop4_stride` | `0` | 无更外层目标重复 | + +这个场景本质上是: + +- 源是 ND 风格遍历 +- 目标写入左矩阵使用的 ZZ 型 NZ 布局 + +### 4.2 场景 B: `MX_B_ND -> NN` + +对应 `pto-isa`: + +- `TLoadMxCubeBND2NN` +- [`include/pto/npu/a5/TLoad.hpp:744`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L744) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `n_value` | `validRow >> 1` | 每次搬运的行向长度 | +| `d_value` | `validCol` | 每次搬运的列向长度 | +| `src_inner_stride` | `GetByteSize(dtype, gStride3) * sizeof(uint16_t)` | 源内层相邻片段跨度 | +| `src_outer_stride` | `0` | 无更外层源重复 | +| `group_count` | `1` | 单组 | +| `dst_loop2_stride` | `1` | 固定 | +| `dst_loop3_stride` | `TileData::Rows >> 1` | 目标行方向 NZ 布局步长 | +| `dst_loop4_stride` | `0` | 无更外层目标重复 | + +这个场景和上一条基本同构,只是 `A/B` 左右矩阵语义不同,导致 `n_value` / `d_value` 与目标 stride 的映射不同。 + +### 4.3 场景 C: 通用 `ND -> [N,C1,H,W,C0]` + +对应 `pto-isa`: + +- 通用 ND 到卷积 tile 的路径 +- [`include/pto/npu/a5/TLoad.hpp:1000`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L1000) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `group_count` | `srcShape2` | 这里就是 `ndNum = H` | +| `n_value` | `srcShape3` | 这里是 `W` | +| `d_value` | `srcShape4` | 这里是 `C` | +| `src_inner_stride` | `bytes(gStride3)` | W 维相邻行组跨度 | +| `src_outer_stride` | `bytes(gStride2)` | H 维相邻组跨度 | +| `dst_loop2_stride` | `1` | 固定 | +| `dst_loop3_stride` | `dstShape2 * dstShape3` | 目标 `H*W` 组跨度 | +| `dst_loop4_stride` | `dstShape3` | 目标 `W` 步长 | + +这个场景最能体现 `nd2nz` 的共性: + +- `group_count` 真正承担的是一个外层 ND 组数 +- `src_outer_stride` 在这里是真实有意义的,不是所有场景都能省掉 + +## 5. `dn2nz` 的真实使用场景 + +### 5.1 场景 A: `MX_A_DN -> ZZ` + +对应 `pto-isa`: + +- `TLoadMxCubeAND2ZZ` +- [`include/pto/npu/a5/TLoad.hpp:664`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L664) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `n_value` | `validCol >> 1` | 每次搬运的列向长度 | +| `d_value` | `validRow` | 每次搬运的行向长度 | +| `src_inner_stride` | `bytes(gStride3)` | 源内层相邻片段跨度 | +| `src_outer_stride` | `0` | 无更外层源重复 | +| `group_count` | `1` | 这里就是 `dnNum = 1` | +| `dst_loop2_stride` | `1` | 固定 | +| `dst_loop3_stride` | `TileData::Cols >> 1` | 目标列方向 NZ 布局步长 | +| `dst_loop4_stride` | `0` | 无更外层目标重复 | + +### 5.2 场景 B: `MX_B_DN -> NN` + +对应 `pto-isa`: + +- `TLoadMxCubeBDN2NN` +- [`include/pto/npu/a5/TLoad.hpp:765`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L765) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `n_value` | `validRow >> 1` | 每次搬运的行向长度 | +| `d_value` | `validCol` | 每次搬运的列向长度 | +| `src_inner_stride` | `bytes(gStride4)` | 源内层相邻片段跨度 | +| `src_outer_stride` | `0` | 无更外层源重复 | +| `group_count` | `1` | 单组 | +| `dst_loop2_stride` | `1` | 固定 | +| `dst_loop3_stride` | `TileData::Rows >> 1` | 目标行方向 NZ 布局步长 | +| `dst_loop4_stride` | `0` | 无更外层目标重复 | + +### 5.3 场景 C: `NCHW -> [N,C1,H,W,C0]` + +对应 `pto-isa`: + +- `TLoadNCHW` +- [`include/pto/npu/a5/TLoad.hpp:1027`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L1027) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `group_count` | `1` | 这里固定 `dnNum = 1` | +| `n_value` | `srcW` 或 `srcH * srcW` | 内层搬运单元;W 连续时可并成 `H*W` | +| `d_value` | `srcC` | 被 pack 进 `C0` 的通道数 | +| `src_inner_stride` | `bytes(gStride2)` | 相邻 `C` 分片对应的源跨度 | +| `src_outer_stride` | `0` | 该路径外层循环通常软件展开在外面 | +| `dst_loop2_stride` | `1` | 固定 | +| `dst_loop3_stride` | `dstH * dstW` | 目标 HW 组跨度 | +| `dst_loop4_stride` | `dstW` | 目标 W 步长 | + +这里 `dn2nz` 的特点是: + +- `group_count` 往往不是 `H` / `D` 这种大维度 +- 外层 `H` 或 `N` 的重复,很多时候不是塞进 intrinsic,而是由外层 for 循环包住 + +### 5.4 场景 D: `NCDHW -> [N,D,C1,H,W,C0]` + +对应 `pto-isa`: + +- `TLoadNCDHW2NDC1HWC0` +- [`include/pto/npu/a5/TLoad.hpp:1128`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L1128) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `group_count` | `1` | 这里固定 `dnNum = 1` | +| `n_value` | `srcH * srcW` 或退化为 `srcW` | H/W 是否连续决定内层搬运长度 | +| `d_value` | `srcC` | 被 pack 进 `C0` 的通道数 | +| `src_inner_stride` | `bytes(gStride1)` | 相邻 `C` 分片的源跨度 | +| `src_outer_stride` | `0` | `D` / `H` 外层重复通常由外部循环承担 | +| `dst_loop2_stride` | `1` | 固定 | +| `dst_loop3_stride` | `dstH * dstW` | 目标 HW 组跨度 | +| `dst_loop4_stride` | `dstW` | 目标 W 步长 | + +### 5.5 场景 E: `NCHW -> FractalZ` + +对应 `pto-isa`: + +- `TLoadNCHW2FractalZ` +- [`include/pto/npu/a5/TLoad.hpp:1085`](../../../../gitlab.com/cann/pto-isa/include/pto/npu/a5/TLoad.hpp#L1085) + +参数映射: + +| 参数 | 取值方式 | 含义 | +|---|---|---| +| `group_count` | `srcShape1` | 这里 `dnNum = N` | +| `n_value` | `gStride2` | 一次搬完整个 `H*W` | +| `d_value` | `srcShape2` | 这里是 `C` | +| `src_inner_stride` | `bytes(gStride2)` | 一个 `N` 组对应的源跨度 | +| `src_outer_stride` | `bytes(gStride1)` | 相邻更外层组跨度 | +| `dst_loop2_stride` | `dstShape1 * dstShape2` | 目标 loop2 步长 | +| `dst_loop3_stride` | `loop2DstStride * dstHW` | 目标 loop3 步长 | +| `dst_loop4_stride` | `1` | 连续存放 | + +这个场景说明: + +- `dn2nz` 也不是只能处理 `group_count = 1` +- `src_outer_stride` 也不是 `dn2nz` 专属的无效参数 + +## 6. 对比结论 + +从 `pto-isa` 的真实使用看,`nd2nz` 和 `dn2nz` 的差异并不在于“参数种类不同”,而在于“同一组参数对应的源布局遍历语义不同”。 + +两者共通点: + +- 都需要 `src_inner_stride` +- 都需要 `n_value` +- 都需要 `d_value` +- 都可能需要 `src_outer_stride` +- 都需要目标侧 `group_count / dst_loop2_stride / dst_loop3_stride / dst_loop4_stride` +- 都经常配合外层软件循环使用 + +两者核心差异: + +- `group_count` 在 `nd2nz` 中更像 ND 分组数,在 `dn2nz` 中更像 DN 分组数 +- 源张量哪一维映射到 `n_value` / `d_value` / `src_inner_stride`,取决于源布局模式 +- 某些 `dn2nz` 场景会把 `H` / `D` 外层维度拆到软件循环,而不是塞进 `group_count` + +因此,如果只从“参数列表”看,这两条接口是可以统一的;真正需要保留差异的是: + +- 一个显式 `nd2nz | dn2nz` mode keyword +- 每个 mode 自己的 shape-to-parameter 映射规则 + +## 7. 一个最小搬移示意 + +这里用一个最小例子,把这些参数如何驱动“多组 2D 矩阵 -> 多组 NZ 分形”的搬移过程画出来。 + +设: + +- `group_count = 2` +- `n_value = 3` +- `d_value = 5` +- `src_inner_stride = 32B` +- `src_outer_stride = 256B` +- `dst_loop2_stride = 1` +- `dst_loop3_stride = 4` +- `dst_loop4_stride = 20` + +这里可以把源理解成两组逻辑 2D 矩阵,每组都是 `N x D = 3 x 5`。 + +### 7.1 源侧视角 + +先不区分 `nd2nz` / `dn2nz` 的地址解释差异,只看统一抽象下的“分组 + 内层步长”: + +```text +group 0 base = src + 0 * src_outer_stride +group 1 base = src + 1 * src_outer_stride + +group g 内部有 3 个 N 单元: + +N0 base = group_base + 0 * src_inner_stride +N1 base = group_base + 1 * src_inner_stride +N2 base = group_base + 2 * src_inner_stride + +每个 N 单元里有 D=5 个元素: + +N0: [d0 d1 d2 d3 d4] +N1: [d0 d1 d2 d3 d4] +N2: [d0 d1 d2 d3 d4] +``` + +如果画成两组源矩阵,可以看成: + +```text +group 0: + N0 -> [00 01 02 03 04] + N1 -> [10 11 12 13 14] + N2 -> [20 21 22 23 24] + +group 1: + N0 -> [30 31 32 33 34] + N1 -> [40 41 42 43 44] + N2 -> [50 51 52 53 54] +``` + +这里: + +- `src_inner_stride` 决定 `N0 -> N1 -> N2` 怎么跳 +- `src_outer_stride` 决定 `group 0 -> group 1` 怎么跳 +- `n_value = 3` 决定每组取 3 条 N +- `d_value = 5` 决定每条 N 上取 5 个 D 元素 + +### 7.2 目标 NZ 视角 + +目标不是平铺成普通二维矩阵,而是按 NZ 分形排布到 L1。 + +可以先把它抽象成: + +```text +group g 的目标基址 = dst + g * dst_loop4_stride * C0_size + +group 内部: + 第 i 个 D-block 的目标基址 = group_dst_base + i * dst_loop3_stride * C0_size + 第 j 个 N 单元的目标基址 = d_block_dst_base + j * dst_loop2_stride * C0_size +``` + +在这个例子里: + +- `dst_loop2_stride = 1` 表示相邻 N 单元在目标上紧邻排布 +- `dst_loop3_stride = 4` 表示相邻 D-block 之间隔 4 个 `C0_size` +- `dst_loop4_stride = 20` 表示相邻 group 的整块矩阵在目标上隔 20 个 `C0_size` + +如果只画逻辑落点关系,不展开完整 `C0`,可以看成: + +```text +group 0 NZ: + D-block 0: + N0 <- [00 01 02 03 ...] + N1 <- [10 11 12 13 ...] + N2 <- [20 21 22 23 ...] + D-block 1: + N0 <- [04 pad pad pad ...] + N1 <- [14 pad pad pad ...] + N2 <- [24 pad pad pad ...] + +group 1 NZ: + D-block 0: + N0 <- [30 31 32 33 ...] + N1 <- [40 41 42 43 ...] + N2 <- [50 51 52 53 ...] + D-block 1: + N0 <- [34 pad pad pad ...] + N1 <- [44 pad pad pad ...] + N2 <- [54 pad pad pad ...] +``` + +这里故意选 `d_value = 5`,就是为了看出尾块不满时的行为: + +- 第一块装下前 4 个 D 元素 +- 第二块只剩第 5 个元素 +- 尾部由硬件补 pad + +### 7.3 参数到底在控制什么 + +把这个例子压缩成一句话: + +- `n_value` 决定每组有多少条 N 线要搬 +- `d_value` 决定每条 N 线上有多少个 D 元素要 pack 进分形 +- `src_inner_stride` 决定源上相邻两条 N 线怎么跳 +- `src_outer_stride` 决定源上相邻两组矩阵怎么跳 +- `dst_loop2_stride` 决定目标上相邻 N 线怎么摆 +- `dst_loop3_stride` 决定目标上相邻 D-block 怎么摆 +- `dst_loop4_stride` 决定目标上相邻 group 怎么摆 + +### 7.4 `nd2nz` 和 `dn2nz` 真正差在哪 + +上面的图故意只画了统一抽象,因为两条指令的参数框架本身是一样的。 + +真正的差异在于:源侧地址解释顺序不同。 + +- `nd2nz`: 更像把源看成 ND 矩阵,再按 `N x D` 逻辑去取数 +- `dn2nz`: 更像把源看成 DN 矩阵,再按另一套源地址递推顺序去取数 + +但无论哪一种: + +- `n_value` / `d_value` 仍然定义“这一组搬多大” +- `src_inner_stride` / `src_outer_stride` 仍然定义“源怎么走” +- `dst_loop2/3/4_stride` 仍然定义“NZ 分形怎么落” + +因此从上层接口看,它们完全可以共享同一组参数模型,只在 `mode` 上区分源布局解释规则。 + +## 8. 建议的统一抽象 + +如果上层想统一接口,建议先统一成“参数语义层”,而不是强行复用现有底层名字。 + +可以考虑的统一抽象如下: + +```text +cube_load_frac( + src, + dst, + nd2nz | dn2nz, + shape(n_value, d_value), + src_layout(src_inner_stride, src_outer_stride?), + dst_group(group_count, dst_loop2_stride, dst_loop3_stride, dst_loop4_stride), + ctrl(l2_cache_ctrl, smallc0_en) +) +``` + +其中: + +- `shape(...)` 只描述一次分形搬移的逻辑 `N x D` 大小 +- `src_layout(...)` 只描述源侧地址递推 +- `dst_group(...)` 只描述目标 NZ 分形排布 +- `ctrl(...)` 只描述底层控制位 + +如果 `src_outer_stride` 不提供,则默认按 `0` 处理。 + +这套抽象的好处: + +- `nd2nz` / `dn2nz` 共享同一组结构化参数 +- 底层是否走 `copy_gm_to_cbuf_multi_nd2nz` 还是 `copy_gm_to_cbuf_multi_dn2nz`,由 `mode` 决定 +- `MTE2_NZ_PARA` 的 4 个字段可以原样保留,不需要再隐式推导 +这类接口不单独暴露 `padding` 参数。 + +- 当 `d_value` 不能完整填满目标分形时,尾部补齐由硬件按 zero padding 完成 +- 当 `smallc0_en = true` 时,small C0 mode 会改变补齐与对齐方式,但仍然不是用户可配置的 pad value + +因此,这里的 padding 语义属于指令内建行为,而不是像 `dma_load` 那样的显式接口参数。 + +如果直接写成接近 VPTO 的 syntax 草案,可以是: + +```text +pto.mte_gm_l1_frac %src, %dst, + nd2nz | dn2nz, + shape(%n_value, %d_value), + src_layout(%src_inner_stride[, %src_outer_stride]), + dst_group(%group_count, %dst_loop2_stride, %dst_loop3_stride, %dst_loop4_stride), + ctrl(%l2_cache_ctrl, %smallc0_en) + : !pto.ptr<..., gm>, !pto.ptr<..., l1>, + nd2nz | dn2nz, + shape i64, i64, + src_layout(i64[, i64]), + dst_group i64, i64, i64, i64, + ctrl i64, i1 +``` + +推荐的 builder 视角也和语法保持一致: + +```text +cube_load_frac( + src, dst, + mode, + shape(n_value, d_value), + src_layout(src_inner_stride, src_outer_stride = 0), + dst_group(group_count, dst_loop2_stride, dst_loop3_stride, dst_loop4_stride), + ctrl(l2_cache_ctrl, smallc0_en) +) +``` + +## 9. 哪些参数可以默认,哪些最好显式暴露 + +从 `pto-isa` 的现状看: + +### 9.1 可以默认的 + +- `sid = 0` + +`sid` 在当前调研到的 A5 `pto-isa` 使用点里都是固定值。 + +### 9.2 建议显式暴露的 + +- `mode` +- `shape(n_value, d_value)` +- `src_layout(src_inner_stride, src_outer_stride?)` +- `dst_group(group_count, dst_loop2_stride, dst_loop3_stride, dst_loop4_stride)` +- `ctrl(l2_cache_ctrl, smallc0_en)` + +这些都是底层接口真实存在、并且会影响行为或未来扩展空间的参数。其中: + +- `l2_cache_ctrl` 当前 `pto-isa` A5 用法里固定传 `0` +- `smallc0_en` 当前 `pto-isa` A5 用法里固定传 `false` +- 但从 `disa-cube.json` 看,这两个字段都属于原始接口语义的一部分,不应在统一接口里直接消失 + +### 9.3 可选暴露的 + +- `src_outer_stride` + +这个参数不是每个场景都需要,但一旦做通用接口,最好保留。 +在结构化接口里,`src_outer_stride` 仍属于 `src_layout(...)` 的一部分,只是允许省略。 + +## 10. 初步判断 + +结论很直接: + +- `cube_load_nd2nz` 和 `cube_load_dn2nz` 在参数语义层是可以统一的 +- 不能统一掉的不是参数列表,而是 `mode` 对源布局遍历规则的解释 +- 如果后续要做 VPTO 新接口,建议抽象成一个统一的 `cube_load_frac` 接口,再保留 `nd2nz | dn2nz` 这两个 mode keyword +- 这套统一接口更适合使用结构化分组: + - `shape(...)` + - `src_layout(...)` + - `dst_group(...)` + - `ctrl(...)` +- 在这套统一接口里,`l2_cache_ctrl` 和 `smallc0_en` 也应保留为显式参数;只有 `sid` 可以继续固定隐藏 + +如果下一步需要,我可以继续把这份设计文档再往前推进一层,直接写成一版面向 VPTO op 设计的 syntax 草案和 verifier 约束。 diff --git a/docs/designs/mad-lowering-contract-design.md b/docs/designs/mad-lowering-contract-design.md new file mode 100644 index 000000000..48a337056 --- /dev/null +++ b/docs/designs/mad-lowering-contract-design.md @@ -0,0 +1,600 @@ +# `mad` 族泛化 lowering 规则设计 + +## 问题 + +`mad` 族当前已经拆成 semantic op 和 raw op,但 lowering 仍然不够泛化: + +- ordinary MAD 与 MX MAD 的 callee 选择依赖局部 if / fallback,导致 FP8 场景容易串线。 +- `X_t`、`CTRL`、bias packing、callee dispatch 分散在多个 helper 中,新增一种类型或模式时不清楚应该改哪一层。 +- 部分类型识别依赖字符串匹配,并散落在 emitter 逻辑里。 + +要解决的问题不是再加一个“更大的 descriptor”,而是定义一组泛化 lowering 规则: + +- 不复制 operand。 +- 不把可由类型推导的信息再枚举存一份。 +- 不让 raw-to-LLVM 重新解释 semantic clause。 +- 不允许 ordinary/MX family 互相 fallback。 + +## 核心原则 + +### 1. IR 本身是事实源,op interface 是访问入口 + +lowering 不引入承载 operand 的 descriptor。operand、type、attribute 仍然只存在于原 op 上。 +但 lowering 也不应该到处写 `isa` 这种 class 判断。 + +需要给 semantic MAD 和 raw MAD 各定义一个 op interface,让不同 op class 暴露同一组 +accessor 和派生语义: + +```c++ +enum class MadFamily { Ordinary, Mx }; +enum class MadAccumulation { ZeroInit, Accumulate, BiasInit }; +enum class MadRawKind { Ordinary, OrdinaryBias, Mx, MxBias }; + +class MadSemanticOpInterface { + Value getLhs(); + Value getRhs(); + Value getDst(); + Value getM(); + Value getN(); + Value getK(); + + bool hasBiasOperand(); + Value getBiasOrNull(); + bool supportsTf32Mode(); + bool readsAccumulator(); + bool initializesAccumulatorWithZero(); + bool initializesAccumulatorWithBias(); + + std::optional getUnitFlagMode(); + bool getDisableGemv(); + std::optional getSatMode(); + std::optional getTf32Mode(); + bool getNDir(); + + MadFamily getMadFamily(); + MadAccumulation getMadAccumulation(); +}; + +class MadRawOpInterface { + Value getLhs(); + Value getRhs(); + Value getDst(); + Value getXt(); + + bool hasBiasOperand(); + Value getBiasOrNull(); + + MadRawKind getMadRawKind(); + MadFamily getMadFamily(); + bool readsAccumulator(); + bool initializesAccumulatorWithZero(); + bool initializesAccumulatorWithBias(); +}; +``` + +interface method 可以由 ODS 的 extra class declaration 或 C++ method 实现。关键是: +lowering pattern 只匹配 interface,不直接按 6 个 semantic op class 和 4 个 raw op +class 分别写逻辑。 + +允许在 interface method 的实现内部有一次 class 分发,因为那是 op 定义层的局部事实; +不允许在 lowering 主流程里散落 class 分发。 + +MLIR 不会因为原 op 已经有同名 getter 就自动认为它实现了 interface。需要在 ODS +里显式把 interface 加到 op traits 上。interface method 的实现有两种方式: + +```tablegen +def MadSemanticOpInterface : OpInterface<"MadSemanticOpInterface"> { + let cppNamespace = "::mlir::pto"; + let methods = [ + InterfaceMethod<"lhs", "::mlir::Value", "getLhs">, + ... + ]; +} + +def PTO_MadOp : PTO_Op<"mad", [ + MadSemanticOpInterface, + DeclareOpInterfaceMethods +]> { ... } +``` + +如果 interface method 没有 default implementation,ODS 只会为实现该 interface 的 +op 生成声明,具体定义需要在 C++ 里补齐。若所有实现 op 都有相同名字和相同语义的 +generated accessor,可以在 interface method 上写 default implementation: + +```tablegen +InterfaceMethod< + "lhs", + "::mlir::Value", + "getMadLhs", + (ins), + [{}], + [{ return $_op.getLhs(); }] +> +``` + +对 `lhs/rhs/dst/m/n/k` 这类所有 semantic MAD 都同名的字段,可以用 default +implementation 直接转发已有 accessor。对 `bias`、`tf32_mode` 这类并非所有 op +都有的字段,interface 必须提供 capability 方法: + +- `getMadFamily()`:ordinary 或 MX,决定 raw family 和 callee lookup family。 +- `getMadAccumulation()`:`ZeroInit / Accumulate / BiasInit`,决定 accumulator 初值语义。 +- `readsAccumulator()`:只有 acc 模式返回 true;它表示 C 初值来自现有 `%dst`。 +- `initializesAccumulatorWithZero()`:zero-init 模式返回 true;它决定 `X_t.c_init = 1`。 +- `initializesAccumulatorWithBias()`:bias-init 模式返回 true;它决定 `X_t.c_src = 1`。 +- `hasBiasOperand()`:只有 bias-init op 返回 true。 +- `getBiasOrNull()`:`hasBiasOperand() == false` 时返回空值。 +- `supportsTf32Mode()`:ordinary MAD op 返回 true,MX MAD op 返回 false。 +- `getTf32Mode()`:`supportsTf32Mode() == false` 时必须返回 `std::nullopt`。 + +lowering 只能先看 capability,再使用 optional accessor;不能直接假设所有实现 op +都有 `getBias()` 或 `getTf32ModeAttr()`。 + +这两个 capability 是正交的,不能用一个 op class 分支同时处理: + +| op | family | accumulation | reads acc | zero init | bias init | bias operand | TF32 | +|---|---|---|---:|---:|---:|---:|---:| +| `pto.mad` | Ordinary | ZeroInit | false | true | false | false | true | +| `pto.mad_acc` | Ordinary | Accumulate | true | false | false | false | true | +| `pto.mad_bias` | Ordinary | BiasInit | false | false | true | true | true | +| `pto.mad_mx` | MX | ZeroInit | false | true | false | false | false | +| `pto.mad_mx_acc` | MX | Accumulate | true | false | false | false | false | +| `pto.mad_mx_bias` | MX | BiasInit | false | false | true | true | false | + +因此 `mad_bias` 同时是 bias op 和 TF32-capable op;`mad_mx_bias` 是 bias op +但不是 TF32-capable op。`mad_acc` 没有额外 operand,但它是唯一会读取现有 accumulator +作为 C 初值的模式。lowering 不能把 “没有 bias operand” 直接等价成 “zero-init”, +也不能把 “bias op” 和 “不支持 TF32” 绑定在一起。 + +实现上也不要让 interface default implementation 调用某个并非所有 op 都存在的 +generated getter。建议: + +- `getLhs/getRhs/getDst` 可以用固定 operand index `0/1/2`。 +- `getBiasOrNull` 根据 `hasBiasOperand()` 决定是否返回 operand `3`。 +- `getM/getN/getK` 根据 `hasBiasOperand()` 决定从 operand `3/4/5` 或 `4/5/6` 读取。 +- `readsAccumulator/initializesAccumulatorWithZero/initializesAccumulatorWithBias` + 从 `getMadAccumulation()` 派生,三者必须互斥且恰好一个为 true。 +- `getTf32Mode` 通过通用 attribute lookup 读取 `"tf32_mode"`,而不是调用 generated + `getTf32Mode()`;MX op 没有这个 attr 时自然返回空。 +- verifier 保证 `supportsTf32Mode() == false` 的 op 不能携带 `"tf32_mode"`。 + +也就是说,interface 的统一性来自“固定 operand organization + capability”,不是来自 +假设所有 op 都有完全相同的 generated C++ getter。 + +因此这里不是“自己写 C++ 继承类”。正确方式是: + +1. 在 `PTOInterfaces.td` 定义 op interface。 +2. 在每个 MAD ODS op 的 traits 中声明实现该 interface。 +3. 能统一转发的 getter 用 interface default implementation。 +4. 形态相关的字段,例如 family / accumulation / bias/tf32 capability,用每个 op + 的小实现显式给出。 + +### 2. family 由 op kind 决定,不由类型猜 + +ordinary / MX 是 op 语义,不是类型语义: + +- `pto.mad*` semantic op 只能 lower 到 ordinary raw family。 +- `pto.mad_mx*` semantic op 只能 lower 到 MX raw family。 +- `pto.mad_raw` / `pto.mad_bias_raw` 只能发 ordinary MAD。 +- `pto.mad_mx_raw` / `pto.mad_mx_bias_raw` 只能发 MX MAD。 + +类型只用于选择同一 family 内的具体 typed intrinsic。类型不能改变 family。 + +这是防止普通 FP8 和 MX FP8 串线的关键规则。 + +### 3. lowering 使用规则函数,不使用大 descriptor + +semantic-to-raw 必须有一个统一入口。这个入口负责从 semantic op 生成 raw op +需要的两个运行时值: + +- `xt`:raw MAD 的 packed shape/config operand,由 semantic op 的 + `m/n/k` 和 clause 生成。 +- `ctrl_for_mad`:本次 MAD 临时使用的控制状态,由 semantic op 的 + numeric/layout clause 和指针类型生成。 + +规则 helper 只服务这个统一入口,并且接收 interface,而不是裸 `Operation *`: + +```c++ +MadRawKind deriveRawKind(MadSemanticOpInterface op); +Value buildMadXtFromSemanticOp(MadSemanticOpInterface op, + PatternRewriter &rewriter); +Value emitCtrlForMad(MadSemanticOpInterface op, Value ctrlSaved, + PatternRewriter &rewriter); +StringRef lookupMadIntrinsic(MadRawOpInterface op); +``` + +这些函数不返回“重新包装过的 op”。它们只返回当前阶段真正需要的产物。 + +## semantic-to-raw 规则 + +semantic-to-raw 的统一入口是: + +```c++ +LogicalResult lowerMadSemanticOp(MadSemanticOpInterface op, + PatternRewriter &rewriter); +``` + +这个函数是唯一创建 `xt` 的地方。`xt` 不是外部传入的,也不是 raw-to-LLVM +再生成的;它在 semantic-to-raw 期间由原 semantic op 的 operands/attributes +构造出来,然后作为 operand 传给 raw op。 + +输入是 semantic op,输出是: + +```text +get_ctrl +set_ctrl(ctrl_for_this_mad) +raw op(..., xt) +set_ctrl(ctrl_saved) +``` + +### raw op 选择 + +raw op 只由 semantic op 名字决定: + +| semantic op | raw op | +|---|---| +| `pto.mad` | `pto.mad_raw` | +| `pto.mad_acc` | `pto.mad_raw` | +| `pto.mad_bias` | `pto.mad_bias_raw` | +| `pto.mad_mx` | `pto.mad_mx_raw` | +| `pto.mad_mx_acc` | `pto.mad_mx_raw` | +| `pto.mad_mx_bias` | `pto.mad_mx_bias_raw` | + +这里不需要 descriptor。pattern 使用一个通用 +`lowerMadSemanticOp(MadSemanticOpInterface op)`,通过 interface 的 +`getMadFamily()` 和 `getMadAccumulation()` 选择 raw op。 + +### `X_t` 生成 + +`X_t` 是 raw op 的 packed `xt` operand。它只在 +`buildMadXtFromSemanticOp(op)` 中生成,来源是 semantic op 本身: + +```text +X_t.M = op.m +X_t.K = op.k +X_t.N = op.n +X_t.unit_flag = op.unit_flag or 0 +X_t.disable_gemv = op.has(disable_gemv) +X_t.c_src = op.initializesAccumulatorWithBias() +X_t.c_init = op.initializesAccumulatorWithZero() +``` + +其中 accumulation 由 semantic op 自己通过 interface 暴露: + +```text +mad / mad_mx -> ZeroInit +mad_acc / mad_mx_acc -> Accumulate +mad_bias / mad_mx_bias -> BiasInit +``` + +这条规则避免把 `c_src/c_init` 存进另一个结构。它们是 op kind 的派生语义。 + +### `CTRL` 生成 + +`CTRL` 只由 semantic clause 和指针类型生成: + +```text +CTRL[HiF8] = isHiF8(lhs.type, rhs.type) +CTRL[TF32 enable/round] = op.supportsTf32Mode ? op.tf32_mode/default : disabled +CTRL[sat] = op.sat_mode only if explicitly present +CTRL[n_dir] = op.has(n_dir) +``` + +规则: + +- HiF8 必须从 lhs/rhs 指针元素类型推导,不能作为独立 operand 或 enum 保存。 +- TF32 只允许 `supportsTf32Mode() == true` 的 ordinary `f32 x f32 -> f32` + semantic op 使用;MX op 的 `supportsTf32Mode()` 必须为 false,`getTf32Mode()` + 必须返回空值。 +- `sat|nosat` 不写时不覆盖对应状态;写了才覆盖。 +- `n_dir` 不写时显式设置为默认方向,避免污染后续 MAD。 +- semantic-to-raw 必须保存并恢复进入 op 前的 `CTRL`。 + +### semantic-to-raw 伪码 + +```c++ +LogicalResult lowerMadSemanticOp(MadSemanticOpInterface op, + PatternRewriter &rewriter) { + // One entry for the entire semantic-to-raw conversion. + // Every value consumed by the raw op is produced here. + MadRawKind rawKind = deriveRawKind(op); // only from interface family/accumulation + Value xt = buildMadXtFromSemanticOp(op, rewriter); // op.m/n/k + clauses + + Value ctrlSaved = emitGetCtrl(); + Value ctrlForOp = emitCtrlForMad(op, ctrlSaved, rewriter); + emitSetCtrl(ctrlForOp); + + emitRawOp(rawKind, op, xt, rewriter); // forwards existing operands + + emitSetCtrl(ctrlSaved); + erase op; +} +``` + +注意这里的 `emitRawOp(rawKind, op, xt, rewriter)` 只是把原 op 的现有 +data operands 加上刚生成的 `xt` 转发给 raw op,不创建一份新的 operand model。 + +更具体地说,几个规则函数应当长这样: + +```c++ +MadRawKind deriveRawKind(MadSemanticOpInterface op) { + switch (op.getMadFamily()) { + case MadFamily::Ordinary: + return op.getMadAccumulation() == MadAccumulation::BiasInit + ? MadRawKind::OrdinaryBias + : MadRawKind::Ordinary; + case MadFamily::Mx: + return op.getMadAccumulation() == MadAccumulation::BiasInit + ? MadRawKind::MxBias + : MadRawKind::Mx; + } +} + +Value buildMadXtFromSemanticOp(MadSemanticOpInterface op, + PatternRewriter &rewriter) { + Value m = op.getM(); + Value n = op.getN(); + Value k = op.getK(); + + Value xt = zextOrCastI64(m); + xt = bitOr(xt, shl(zextOrCastI64(k), 12)); + xt = bitOr(xt, shl(zextOrCastI64(n), 24)); + + if (auto mode = op.getUnitFlagMode()) { + uint64_t bits = *mode == pto::MadUnitFlagMode::CheckOnly ? 2 : 3; + xt = bitOr(xt, shl(i64(bits), 55)); + } + + if (op.getDisableGemv()) + xt = bitOr(xt, shl(i64(1), 61)); + + if (op.initializesAccumulatorWithBias()) + xt = bitOr(xt, shl(i64(1), 62)); + + if (op.initializesAccumulatorWithZero()) + xt = bitOr(xt, shl(i64(1), 63)); + + return xt; +} + +Value emitCtrlForMad(MadSemanticOpInterface op, Value ctrlSaved, + PatternRewriter &rewriter) { + Value ctrl = ctrlSaved; + + // HiF8 is inferred from the existing pointer element types. + bool hif8 = isHiF8Type(getPtrElementType(op.getLhs())); + ctrl = setCtrlBit(ctrl, kCtrlHiF8, hif8); + + if (op.supportsTf32Mode()) { + auto tf32 = op.getTf32Mode(); + ctrl = setCtrlBit(ctrl, kCtrlTf32Enable, true); + ctrl = setCtrlBit(ctrl, kCtrlTf32RoundAway, + tf32 && *tf32 == pto::Tf32Mode::RoundAway); + } else { + ctrl = setCtrlBit(ctrl, kCtrlTf32Enable, false); + ctrl = setCtrlBit(ctrl, kCtrlTf32RoundAway, false); + } + + // sat/nosat is only an override when the semantic op spells it explicitly. + if (auto sat = op.getSatMode()) + ctrl = setCtrlBit(ctrl, kCtrlNoSat, + *sat == pto::MadSatMode::NoSat); + + ctrl = setCtrlBit(ctrl, kCtrlNDir, op.getNDir()); + return ctrl; +} + +void emitRawOp(MadRawKind rawKind, MadSemanticOpInterface op, Value xt, + PatternRewriter &rewriter) { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value dst = op.getDst(); + + switch (rawKind) { + case MadRawKind::Ordinary: + rewriter.create(op.getLoc(), lhs, rhs, dst, xt); + return; + case MadRawKind::OrdinaryBias: + assert(op.hasBiasOperand()); + rewriter.create(op.getLoc(), lhs, rhs, dst, + op.getBiasOrNull(), xt); + return; + case MadRawKind::Mx: + rewriter.create(op.getLoc(), lhs, rhs, dst, xt); + return; + case MadRawKind::MxBias: + assert(op.hasBiasOperand()); + rewriter.create(op.getLoc(), lhs, rhs, dst, + op.getBiasOrNull(), xt); + return; + } +} +``` + +这里的 `op.getLhs()/op.getM()/op.getUnitFlagMode()` 都来自 interface。它们只从原 op +读取 operand 或 attribute,不缓存、不重组、不创建新的语义对象。 + +## raw-to-LLVM 规则 + +raw-to-LLVM 的输入是 raw op。它只做四件事: + +1. 从 raw op kind 得到 family。 +2. 从 raw op operand type 生成 intrinsic type suffix。 +3. 查 family-local intrinsic 表。 +4. 发 call。 + +### family-local dispatch + +callee lookup 必须拆成两个互不 fallback 的入口: + +```c++ +FailureOr lookupOrdinaryMadIntrinsic(Type lhs, Type rhs, Type dst); +FailureOr lookupMxMadIntrinsic(Type lhs, Type rhs, Type dst); +``` + +调用规则: + +```c++ +if (op.getMadFamily() == MadFamily::Ordinary) + callee = lookupOrdinaryMadIntrinsic(lhsElem, rhsElem, dstElem); +else if (op.getMadFamily() == MadFamily::Mx) + callee = lookupMxMadIntrinsic(lhsElem, rhsElem, dstElem); +``` + +禁止: + +```c++ +ordinary lookup failed -> try MX lookup +MX lookup failed -> try ordinary lookup +``` + +这比 `MadElementFamily` 更直接:类型 suffix 是从 operand type 现场推导的,不需要先存成 enum。 + +### typed suffix 推导 + +suffix 推导只回答“这个 type 在当前 family 下叫什么”: + +```c++ +FailureOr getOrdinaryMadTypeSuffix(Type lhsElem, Type rhsElem, + Type dstElem); + +FailureOr getMxMadTypeSuffix(Type lhsElem, Type rhsElem, + Type dstElem); +``` + +示例: + +```text +ordinary: + f16, f16, f32 -> "f162f32.c310" + bf16, bf16, f32 -> "bf162f32.c310" + f32, f32, f32 -> "f322f32.c310" + e4m3, e4m3, f32 -> "e4m3e4m3.c310" + +MX: + e4m3, e4m3, f32 -> "e4m3e4m3" + e4m3, e5m2, f32 -> "e4m3e5m2" +``` + +同样的 FP8 类型组合在 ordinary 和 MX 下可以映射到不同 intrinsic stem,但这个差异由 family-local lookup 决定,不由类型自己决定。 + +### HiF8 处理 + +HiF8 不参与 raw-to-LLVM callee 区分: + +- HiF8 ordinary MAD 使用 ordinary FP8 typed suffix。 +- HiF8 的执行解释由 semantic-to-raw 的 `CTRL` 修改表达。 +- raw-to-LLVM 不读取 HiF8 semantic mode,也不设置 `CTRL`。 + +这保证 HiF8 不会因为 callee 名称选择污染 ordinary FP8。 + +### bias packing + +bias packing 是 raw kind 的机械规则: + +```text +mad_raw / mad_mx_raw: + call dst = dst + +mad_bias_raw / mad_mx_bias_raw: + call dst = pack(dst, bias) +``` + +它不参与 callee 选择,也不影响 ordinary/MX family。 + +### raw-to-LLVM 伪码 + +```c++ +LogicalResult emitMadRaw(MadRawOpInterface op, + ConversionPatternRewriter &rewriter) { + Type lhsElem = getPtrElementType(op.getLhs()); + Type rhsElem = getPtrElementType(op.getRhs()); + Type dstElem = getPtrElementType(op.getDst()); + + FailureOr callee = + op.getMadFamily() == MadFamily::Mx + ? lookupMxMadIntrinsic(lhsElem, rhsElem, dstElem) + : lookupOrdinaryMadIntrinsic(lhsElem, rhsElem, dstElem); + if (failed(callee)) + return failure(); + + Value lhs = castToLeft(op.getLhs()); + Value rhs = castToRight(op.getRhs()); + Value dst = castToAcc(op.getDst()); + Value callDst = op.hasBiasOperand() + ? packDstAndBias(dst, castToBias(op.getBiasOrNull())) + : dst; + + emitCall(*callee, callDst, lhs, rhs, op.getXt()); +} +``` + +## 类型识别规则 + +类型识别不应该在 emitter 中到处 `contains("e4m3")`。需要收敛成 family-local type suffix helper: + +```c++ +FailureOr getOrdinaryMadElemToken(Type elem); +FailureOr getMxMadElemToken(Type elem); +bool isHiF8Type(Type elem); +``` + +约束: + +- 优先使用 PTO type API。 +- 如果某些 FP8/HiF8 类型暂时没有稳定 API,允许在这个 helper 内部有兼容字符串匹配。 +- 字符串匹配不得出现在 callee lookup、pattern rewrite、raw lowering 主流程里。 +- unsupported target-profile type 在 helper 中失败,不进入 fallback。 + +这样新增类型只改 type token helper 和对应 family-local suffix 规则。 + +## 实现组织 + +建议新增轻量 helper 和 op interface,而不是新增大 descriptor: + +```text +include/PTO/IR/PTOInterfaces.td +include/PTO/Transforms/MadLoweringRules.h +lib/PTO/Transforms/MadLoweringRules.cpp +``` + +放入: + +- `MadSemanticOpInterface` / `MadRawOpInterface` +- semantic-to-raw 规则:`deriveRawKind`、`buildMadXtFromSemanticOp`、`emitCtrlForMad` +- raw-to-LLVM 规则:`lookupOrdinaryMadIntrinsic`、`lookupMxMadIntrinsic` +- type token helper:`getOrdinaryMadElemToken`、`getMxMadElemToken`、`isHiF8Type` + +不放入: + +- operand 副本 +- type-family enum 副本 +- 与具体 rewriter 强绑定的大型状态对象 +- lowering 主流程里的 repeated op class 判断 + +`VPTOExpandWrapperOps.cpp` 保留 IR 构造和 pattern 注册。 +`VPTOLLVMEmitter.cpp` 保留 LLVM address-space cast、bias packing、call emission。 + +## 验收标准 + +结构验收: + +- ordinary raw lowering 只调用 `lookupOrdinaryMadIntrinsic`。 +- MX raw lowering 只调用 `lookupMxMadIntrinsic`。 +- 两个 lookup 之间没有 fallback。 +- semantic-to-raw 主流程匹配 `MadSemanticOpInterface`,不是逐个 op class 模板实例。 +- raw-to-LLVM 主流程匹配 `MadRawOpInterface`,不是逐个 raw op class 分支。 +- semantic-to-raw 不构造保存 operand 的 descriptor。 +- raw-to-LLVM 不读取 semantic clause。 +- FP8/HiF8 字符串识别如果存在,只存在于 type token helper。 + +行为验收: + +- MAD SIM 全量通过。 +- ordinary FP8 `mad_raw` 静态导向 ordinary `MAD.e4m3e4m3`。 +- `mad_mx_raw` / `mad_mx_bias_raw` 静态导向 `MMAD.MX.*`。 +- HiF8 + 后续 ordinary FP8 的同 kernel SIM 通过,证明 `CTRL` 不泄漏。 +- `sat|nosat`、`tf32_mode`、`n_dir` 的现有 SIM 覆盖仍通过。 + +## 非目标 + +本设计不改用户可见 MAD op 语法,不新增 MX scale operand,不改 acc_store 族接口,也不重新定义 `sat|nosat` 数值语义。 diff --git a/docs/designs/mad-semantic-op-design.md b/docs/designs/mad-semantic-op-design.md new file mode 100644 index 000000000..e5c9a3a7b --- /dev/null +++ b/docs/designs/mad-semantic-op-design.md @@ -0,0 +1,539 @@ +# `mad` 族语义化 op 设计 + +## 目标 + +把 `pto.mad*` / `pto.mad_mx*` 从“按 ISA 位域拼装”收敛成“语义自描述” op。 + +设计原则: + +- op 直接表达计算语义 +- 影响结果的因素必须可见 +- 能从类型推导的,不再单独暴露 +- 不能从类型推导的,必须显式成 clause +- target profile 先做闭包,不把 profile1 / reserved 字段混进来 + +这里讨论的是 `disa-cube.json` 对应的 target profile 语义。 + +## 1. 语义来源 + +### 1.1 从指针类型推导 + +`mad` / `mad_mx` 的矩阵类型应由指针元素类型推导,而不是再单独放一个 `type` 参数。 + +### 1.2 必须显式表达 + +- `unit_flag` +- `disable_gemv` +- `sat` / `nosat` +- `tf32_mode` +- `n_dir` +- `bias` + +`C` 的初值语义不单独做成 clause,而是由 op 本身区分: + +- `pto.mad`:zero-init +- `pto.mad_acc`:accumulate-init +- `pto.mad_bias`:bias-init + +### 1.3 通过规则约束,不作为独立 operand + +- `mad_mx` 的 scale 地址 +- 对齐 / fractal / layout 约束 +- GEMV 条件 + +## 2. `mad` 族完整 op 集 + +### 2.1 `pto.mad` + +```mlir +pto.mad %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 +``` + +语义: + +```text +dst = lhs * rhs +``` + +### 2.2 `pto.mad_acc` + +```mlir +pto.mad_acc %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 +``` + +语义: + +```text +dst = dst + lhs * rhs +``` + +### 2.3 `pto.mad_bias` + +```mlir +pto.mad_bias %lhs, %rhs, %dst, %bias, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, + !pto.ptr<..., bt>, i64, i64, i64 +``` + +语义: + +```text +dst = bias + lhs * rhs +``` + +### 2.4 `pto.mad_mx` + +```mlir +pto.mad_mx %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 +``` + +语义: + +```text +dst = (ScaleA * lhs) * (ScaleB * rhs) +``` + +说明: + +- `ScaleA` / `ScaleB` 不作为显式 operand +- 它们通过 `lhs` / `rhs` 的地址派生到 `L0A_MX / L0B_MX` +- `lhs` 与 `rhs` 的 MX scale 存储必须已被外部加载并与 data tile 对齐 + +### 2.5 `pto.mad_mx_acc` + +```mlir +pto.mad_mx_acc %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 +``` + +语义: + +```text +dst = dst + (ScaleA * lhs) * (ScaleB * rhs) +``` + +### 2.6 `pto.mad_mx_bias` + +```mlir +pto.mad_mx_bias %lhs, %rhs, %dst, %bias, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, + !pto.ptr<..., bt>, i64, i64, i64 +``` + +语义: + +```text +dst = bias + (ScaleA * lhs) * (ScaleB * rhs) +``` + +## 3. raw op 接口 + +semantic `mad` 族会展开成: + +```text +CTRL update + raw MAD/MMAD op +``` + +raw op 只承载底层 MAD/MMAD 指令本身,不承载 `CTRL` 语义。 + +### 3.1 raw op 集合 + +为了保留 typed pointer 和 memory effect 信息,raw 层不直接做成全寄存器 +`i64, i64, i64, i64` 形式,而是使用 typed pointer 加 packed `X_t`: + +```mlir +pto.mad_raw %lhs, %rhs, %dst, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64 + +pto.mad_bias_raw %lhs, %rhs, %dst, %bias, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, + !pto.ptr<..., bt>, i64 + +pto.mad_mx_raw %lhs, %rhs, %dst, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64 + +pto.mad_mx_bias_raw %lhs, %rhs, %dst, %bias, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, + !pto.ptr<..., bt>, i64 +``` + +`mad_acc` 和 `mad_mx_acc` 不需要单独 raw op;它们使用 +`pto.mad_raw` / `pto.mad_mx_raw`,区别只在 `%xt` 里的 `c_init` 位。 + +### 3.2 raw operand 语义 + +- `%lhs`:底层 `[X_n]`,Matrix A,必须是 `left` +- `%rhs`:底层 `[X_m]`,Matrix B,必须是 `right` +- `%dst`:底层 `[X_d[31:0]]`,Matrix C in L0C,必须是 `acc` +- `%bias`:底层 `[X_d[63:32]]`,bias table buffer,必须是 `bias` +- `%xt`:已经 packed 的 `X_t`,类型必须是 `i64` + +raw op 不再接收: + +- `unit_flag(...)` +- `disable_gemv` +- `sat` +- `tf32_mode(...)` +- `n_dir` +- `m / n / k` + +这些都必须在 semantic-to-raw 展开之前完成编码。 + +### 3.3 `%xt` packed bit 约定 + +`%xt` 是底层 `X_t`: + +- `[11:0]`:M +- `[23:12]`:K +- `[35:24]`:N +- `[56:55]`:unit-flag control bits,合法值 `0 / 2 / 3` +- `[61]`:GEMV disable +- `[62]`:C source,`0` 表示 L0C,`1` 表示 bias table +- `[63]`:C init,`1` 表示 C 初值为 0,`0` 表示读取 C + +semantic op 到 raw op 的 `%xt` 生成规则: + +| semantic op | raw op | `X_t[62] / c_src` | `X_t[63] / c_init` | +|---|---|---:|---:| +| `pto.mad` | `pto.mad_raw` | `0` | `1` | +| `pto.mad_acc` | `pto.mad_raw` | `0` | `0` | +| `pto.mad_bias` | `pto.mad_bias_raw` | `1` | `0` | +| `pto.mad_mx` | `pto.mad_mx_raw` | `0` | `1` | +| `pto.mad_mx_acc` | `pto.mad_mx_raw` | `0` | `0` | +| `pto.mad_mx_bias` | `pto.mad_mx_bias_raw` | `1` | `0` | + +### 3.4 raw op 不负责的内容 + +raw op 不配置 `SPR.CTRL`。下面这些语义必须由 semantic-to-raw 展开显式插入 +`get_ctrl / sbitset0 / sbitset1 / set_ctrl`: + +- `hif8` ptr element type -> `CTRL[45]` +- `tf32_mode(...)` -> `CTRL[46] / CTRL[47]` +- `sat` / `nosat` -> `CTRL[48]` +- `n_dir` -> `CTRL[51]` + +其中 `hif8 / tf32_mode / n_dir` 有明确的 off/on 语义,所以 semantic-to-raw +不能只在开启时置 1,也要在关闭时置 0: + +- 普通 `fp8_e4m3` -> `CTRL[45] = 0` +- `hif8` -> `CTRL[45] = 1` +- 普通 `f322f32` -> `CTRL[46] = 0` +- `tf32_mode(round_even)` -> `CTRL[46] = 1, CTRL[47] = 0` +- `tf32_mode(round_away)` -> `CTRL[46] = 1, CTRL[47] = 1` +- 不写 `n_dir` -> `CTRL[51] = 0` +- 写 `n_dir` -> `CTRL[51] = 1` + +`sat` / `nosat` 当前仍是显式 flag:写 `sat` 生成饱和语义配置,写 `nosat` +生成非饱和语义配置;不写不覆盖 `CTRL[48]`。 + +raw op 也不负责 MX scale 地址组织;`mad_mx_raw` 仍然按 `lhs / rhs` 地址派生 +`L0A_MX / L0B_MX`,并通过 verifier 约束 scale 布局。 + +### 3.5 raw verifier 规则 + +- raw op 不允许出现任何 semantic clause +- `%lhs / %rhs / %dst` 必须是 typed `!pto.ptr` +- `%lhs` 地址空间必须是 `left` +- `%rhs` 地址空间必须是 `right` +- `%dst` 地址空间必须是 `acc` +- bias raw op 的 `%bias` 地址空间必须是 `bias` +- bias raw op 的 `%bias` 元素类型必须和 `%dst` 元素类型一致 +- `%xt` 必须是 `i64` +- 如果 `%xt` 是常量: + - raw non-bias op 要求 `X_t[62] = 0` + - raw bias op 要求 `X_t[62] = 1` + - `X_t[56:55]` 只能是 `0 / 2 / 3` + +## 4. Type 语义 + +### 4.1 `mad` 家族 target profile 可用类型 + +| Family | lhs/rhs | dst | 备注 | +|---|---|---|---| +| `s8` | `s8` | `s32` | 可由 ptr 元素类型推导 | +| `f162f32` | `f16` | `f32` | 可由 ptr 元素类型推导 | +| `bf162f32` | `bf16` | `f32` | 可由 ptr 元素类型推导 | +| `f322f32` | `f32` | `f32` | 普通 FP32 可由 ptr 元素类型推导;TF32 需要显式 `tf32_mode(...)` | +| `e4m3e4m3` | `fp8_e4m3` / `hif8` | `f32` | 普通 FP8 和 HiF8 由 ptr 元素类型区分 | +| `e4m3e5m2` | `fp8_e4m3` / `fp8_e5m2` | `f32` | 可由 ptr 元素类型推导 | +| `e5m2e4m3` | `fp8_e5m2` / `fp8_e4m3` | `f32` | 可由 ptr 元素类型推导 | +| `e5m2e5m2` | `fp8_e5m2` | `f32` | 可由 ptr 元素类型推导 | + +`u8`、`s4`、`s16s8`、`f162f16`、`f16u2`、`u8s8`、`b8u2`、`MMAD_SP` 不纳入 target-profile 设计。 + +### 4.2 `mad_mx` 家族 target profile 可用类型 + +| Family | lhs/rhs | dst | 备注 | +|---|---|---|---| +| `e1m2e1m2` | `fp4_e1m2` | `f32` | 可由 ptr 元素类型推导 | +| `e1m2e2m1` | `fp4_e1m2` / `fp4_e2m1` | `f32` | 可由 ptr 元素类型推导 | +| `e2m1e1m2` | `fp4_e2m1` / `fp4_e1m2` | `f32` | 可由 ptr 元素类型推导 | +| `e2m1e2m1` | `fp4_e2m1` | `f32` | 可由 ptr 元素类型推导 | +| `e4m3e4m3` | `fp8_e4m3` | `f32` | 可由 ptr 元素类型推导 | +| `e4m3e5m2` | `fp8_e4m3` / `fp8_e5m2` | `f32` | 可由 ptr 元素类型推导 | +| `e5m2e4m3` | `fp8_e5m2` / `fp8_e4m3` | `f32` | 可由 ptr 元素类型推导 | +| `e5m2e5m2` | `fp8_e5m2` | `f32` | 可由 ptr 元素类型推导 | + +## 5. Clause 语义 + +### 5.1 `unit_flag(...)` + +这是 producer 侧的 L0C block 语义。 + +- 不写 `unit_flag(...)`:关闭 +- `unit_flag(check_only)`:检查,不设置 +- `unit_flag(check_and_set)`:检查并设置 + +`check_and_set` 是 `mad` 侧对应语义;consumer 侧 `acc_store` 才使用 `check_and_clear`。 + +### 5.2 `disable_gemv?` + +- 不写:允许 GEMV +- 写:禁止 GEMV + +### 5.3 `sat?` / `nosat?` + +表示 CUBE 的饱和/传播语义。 + +- 不写:保留 target-profile 下的默认 numeric policy +- 写:显式请求 saturate 语义 +- 写 `nosat`:显式请求 non-saturate 语义 + +### 5.4 `tf32_mode(...)` + +只对不能从指针元素类型推导的执行模式出现: + +- `tf32_mode(round_even | round_away)`:只对 `f322f32` 有意义 + +`hif8` 不放进 `tf32_mode(...)`。后续引入独立 HiF8 元素类型后,`hif8` 语义由 `lhs / rhs` 的 ptr 元素类型推导;普通 `fp8_e4m3` 仍表示普通 E4M3 解释。 + +其他 family 不应携带 `tf32_mode(...)`。 + +### 5.5 `n_dir?` + +这是 `CTRL[51]` 的语义化表达,用来约束 CUBE 输出 L0C 的方向顺序。 + +- 不写:`CTRL[51] = 1'b0`,先 M 后 N +- 写 `n_dir`:`CTRL[51] = 1'b1`,先 N 后 M + +这个 clause 主要和后续 `acc_store*` 的 layout transform / unit-flag 语义配合,不改变数学结果。 + +## 6. `mad_mx` 的 scale 规则 + +`mad_mx` 不提供 scale pointer operand。 + +scale 通过输入地址派生: + +- `lhs` 对应 `L0A_MX` +- `rhs` 对应 `L0B_MX` +- scale 基址由 data tile 地址派生,形如 `addr / 16` + +也就是说,`mad_mx` 只负责声明“我要做 MX 语义”,不负责再把 scale 地址作为独立数据流传进来。 + +### 6.1 约束 + +- scale dtype 固定为 `e8m0` +- MX-fp4 家族的 data tile 为 `K0 = 64`,对应 scale tile 为 `16 x 2` +- MX-fp8 家族的 data tile 为 `K0 = 32`,对应 scale tile 为 `16 x 2` +- 每 32 个 K 元素共享同一个 scale +- `L0A_MX / L0B_MX` 必须和 `L0A / L0B` 地址对齐 +- MX-fp4 / MX-fp8 的 K0 和 fractal 布局必须满足 target-profile 约束 + +## 7. 设计约束 + +### 7.1 `mad_bias` + +- `bias` 必须是 `BIAS` 地址空间 +- `bias` 元素类型与 `dst` 一致 + +### 7.2 `mad_mx` + +- 不能把 scale 当作独立 operand +- scale 必须通过派生规则和 verifier 约束表达 + +### 7.3 `tf32_mode` + +- `f322f32` 不能只靠 ptr 类型表达 +- 必须显式带 `tf32_mode(...)` + +### 7.4 `hif8` + +- `hif8` 由指针元素类型表达,不作为独立 clause +- `hif8` 只允许用于 `e4m3e4m3` family +- `lhs / rhs` 必须同时是普通 `fp8_e4m3` 或同时是 `hif8`;不允许一边普通 E4M3、一边 HiF8 + +### 7.5 `CTRL` 派生枚举 + +这部分是 verifier / lowering 需要固定住的关键词,不直接暴露 bit 编码: + +- `unit_flag` + - `check_only` -> `2'b10` + - `check_and_set` -> `2'b11` +- `disable_gemv` + - present -> `X_t[61] = 1'b1` +- `hif8` ptr element type + - present on both `lhs / rhs` -> `CTRL[45] = 1'b1` + - absent -> `CTRL[45] = 1'b0` +- `tf32_mode(round_even | round_away)` + - `round_even` -> `CTRL[46]=1'b1, CTRL[47]=1'b0` + - `round_away` -> `CTRL[46]=1'b1, CTRL[47]=1'b1` +- `sat` / `nosat` + - `sat` -> `CTRL[48] = 1'b0` + - `nosat` -> `CTRL[48] = 1'b1` +- `n_dir` + - absent -> `CTRL[51] = 1'b0` + - present -> `CTRL[51] = 1'b1` + +### 7.6 `sat` / `nosat` + +- `sat` 和 `nosat` 是互斥的显式语义开关 +- 不写时保留 target-profile 默认行为 +- 写时表示希望显式控制饱和语义,不要依赖隐式约定 + +### 7.7 `n_dir` + +- `n_dir` 只表达输出方向 +- 不改变数值含义 +- 需要和 `acc_store*` 的 layout 设计一致 + +## 8. 推荐 verifier 规则 + +### 8.1 通用 + +- `lhs / rhs / dst` 必须是 typed `!pto.ptr` +- `m / n / k` 必须是可转成 i64 的整型值 +- `unit_flag(...)` 只能是 `check_only` 或 `check_and_set` +- `disable_gemv` 只能作为 flag 出现 +- `n_dir` 只能作为 flag 出现 + +### 8.2 `mad_bias` + +- `bias` 必须是 `BIAS` 地址空间 +- `bias` 元素类型和 `dst` 一致 + +### 8.3 `mad_mx` + +- `lhs` / `rhs` 需满足 MX family 类型表 +- `dst` 必须是 `f32` +- scale 派生地址必须与 data tile 地址匹配 +- scale 布局和 K0 规则必须满足 MX family 约束 + +## 9. target profile 排除项 + +以下不纳入本版设计: + +- `Feature Map Offset` / `fm_offset` +- `Weight Matrix Offset` / `wt_offset` +- `smask_addr` +- `sub_dtype` +- `right_shift_en` +- `MMAD_SP` +- 其他 reserved / profile1-only 字段 + +## 10. 最终接口形状 + +semantic op: + +```mlir +pto.mad %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 + +pto.mad_acc %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 + +pto.mad_bias %lhs, %rhs, %dst, %bias, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, !pto.ptr<..., bt>, i64, i64, i64 + +pto.mad_mx %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 + +pto.mad_mx_acc %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64, i64, i64 + +pto.mad_mx_bias %lhs, %rhs, %dst, %bias, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, !pto.ptr<..., bt>, i64, i64, i64 +``` + +raw op: + +```mlir +pto.mad_raw %lhs, %rhs, %dst, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64 + +pto.mad_bias_raw %lhs, %rhs, %dst, %bias, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, !pto.ptr<..., bt>, i64 + +pto.mad_mx_raw %lhs, %rhs, %dst, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, i64 + +pto.mad_mx_bias_raw %lhs, %rhs, %dst, %bias, %xt + : !pto.ptr<..., l0a>, !pto.ptr<..., l0b>, !pto.ptr<..., l0c>, !pto.ptr<..., bt>, i64 +``` + +这版设计的核心变化是: + +- type 由指针推导 +- `unit_flag` 改成 producer 语义 `check_only` / `check_and_set`,不再混入 `check_and_clear` +- `disable_gemv` 改成 flag +- 新增 raw op 层,semantic op 不再直接 lowered 到 HIVM intrinsic +- raw op 只消费 typed pointer 和 packed `%xt` +- `mad_mx` 不再把 scale 当成独立 operand +- `sat`、`tf32_mode(...)`、`n_dir` 作为显式语义 clause +- `hif8` 从指针元素类型推导,不作为独立 clause diff --git a/docs/designs/ptoas-emit-fatobj.md b/docs/designs/ptoas-emit-fatobj.md new file mode 100644 index 000000000..50348b2f0 --- /dev/null +++ b/docs/designs/ptoas-emit-fatobj.md @@ -0,0 +1,129 @@ +# `ptoas emit fatobj` + +## 输入输出形式 + +### 输入 + +```mlir +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @helper(...) { + } + func.func @foo(...) attributes {pto.kernel} { + ... + } +} + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @helper(...) {} + func.func @foo(...) attributes {pto.kernel} { + ... + } +} +``` + +允许只有一个 module,但是最多每个 kernel_kind 一个 module,只有vector|cube两类 module + +带 `pto.kernel` 的函数在输入中保留逻辑函数名,不要求输入侧手工拼接 `_mix_aic` / `_mix_aiv` 后缀。旧属性名 `pto.aicore` 仍被兼容识别,但新输入应使用 `pto.kernel`。 + +### 输出 + +fatobj 对象文件 + +## 每个模块的工作 + +### `ptoas` + +`--pto-backend=vpto` 直接输出 fatobj。这里移除的是对外的 llvm ir/bc 输出模式。所有修改发生在 vpto 路径内。emitc 路径不要动,不是我们关心的内容。 + +0. 对于两个并列 module,mlir 的 parser 会自动做一个嵌套。为了保证 pass pipeline 不产生分歧,单个 module 的场景也主动包裹为相同的嵌套结构。 + +1. 当前 `ptoas.cpp` 中进入 VPTO 路径有两处。这两处入口都需要改造并支持 fatobj 输出能力,但这里强调的是“两处都要改”,不是把它们在控制流上强行归并成一个入口: + +- 第一处是 `effectiveBackend == PTOBackend::VPTO && inputIsVPTOIR && !hasTileOpsToExpand` 的直入 VPTO 分支。这条路径在 [`tools/ptoas/ptoas.cpp:1605`](/home/mouliangyu/projects/github.com/mouliangyu/PTOAS-3/tools/ptoas/ptoas.cpp#L1605) 附近,当前会在必要时先做 `inlineTilelangHelpersOnVPTOInput`,然后直接 `return emitVPTOBackendResult(...)`。 +- 第二处是通用 PTO 前端 pipeline 跑完之后的 `effectiveBackend == PTOBackend::VPTO` 分支。这条路径在 [`tools/ptoas/ptoas.cpp:1670`](/home/mouliangyu/projects/github.com/mouliangyu/PTOAS-3/tools/ptoas/ptoas.cpp#L1670) 附近,当前会先打印 seam IR、再 `lowerPTOToVPTOBackend`,最后 `return emitVPTOBackendResult(...)`。 + +2. 进入 VPTO 路径后,首先保证 module 自动嵌套为如下形式,如果只有单个 module,就手动加一层。 + +```mlir +module { + module { + } + module { ; 如果只有一个 module 就没有这第二个子 module + } +} +``` + +3. `pto.kernel_kind` 需要位于最底层的 module。 + +4. 所有 pass 都通过 nest pm 驱动,不允许手动切分 module 分别跑 pass pipeline。这是嵌套 module 的统一驱动方式。 + +5. `ptoas` 负责统一调度,但不负责具体链接细节。它负责: + +- 调用 `VPTOHostStubEmission` 生成 stub 源码字符串 +- 调用 `VPTOLLVMEmitter` 生成 cube|vector 两个 llvm module 结构 +- 将 stub、cube、vector 组件输入给 fatobj emission 组件,并直接把最终结果写入 `outputFile` + +6. 不要修改 emitc 路径的代码。 + +### `VPTOHostStubEmission` + +1. 负责 stub 源码字符串的生成,依据输入中 `pto.kernel` 函数的签名和符号约定生成对应 stub + +2. cube 和 vector module 中的同名 `pto.kernel` 函数共享同一个 stub 函数 + +### `VPTOLLVMEmitter` + +1. 负责 llvm module 生成工作,prepare 和 translate 合并,通过同一个 nest pm 驱动。不允许手动切分 module 分别跑 pass pipeline + +2. 对外职责是接收嵌套 module 输入,并输出按 `kernel_kind` 拆分好的 vector / cube llvm module + +3. 对带 `pto.kernel` 的函数,按所属 `kernel_kind` 自动补真实 device 符号后缀: + +- vector 补 `_mix_aiv` +- cube 补 `_mix_aic` + +输入侧只保留逻辑函数名,不在输入 IR 中手工编码这个后缀。 + +4. 当前文件中有很多函数本身不是 module pass,无法直接注册到 nest pm 中,需要用 pass 封装后再进入统一 pipeline + +5. `runPipeline` 是这个模块内部的统一驱动入口,pass 注册集中发生在这里 + +### `VPTOFatobjEmission` + +1. 负责和工具链、临时文件、最终 fatobj 输出打交道,将 vector、cube、stub 组件组织并产出 fatobj + +2. 负责临时文件管理。这里禁止使用“临时目录托管 + 目录递归删除”的模型,而是只管理单个临时文件。原因不是实现 bug,而是目录模型本身具有更高的删除风险:一旦路径判断错误,目录删除天然带有批量删除和隐式路径解释的风险;文件级清理不存在这种大范围破坏面。 + +3. 将 vector/cube LLVM module 和 stub 字符串按工具链需要写入临时文件。文件落盘是统一主模型。 + +4. 参考 test/vpto 下的脚本搭建编译流程,并参考 LLVM/Clang driver 的工具链调用模式实现本地封装。这里需要有一组统一接口负责: + +- 创建并注册临时文件,便于统一清理 +- 调用 `llvm::sys::ExecuteAndWait(...)` 执行外部工具 +- 在底层工具支持时,将已经落盘的临时文件内容通过标准输入重定向给子进程 +- 在底层工具不支持时,回退为显式临时文件输入 + +5. 上述封装的目标不是消灭临时文件,而是统一管理“临时文件创建 / 注册 / 重定向 / 清理”,让 toolchain 交互方式稳定收敛。 + +6. 链接过程整体参考 test/vpto 下的脚本,最终输出为 `-o` 的参数。 + + +## 测试约束 + +### `test/vpto` 脚本 + +1. `test/vpto` 中的测试脚本需要统一为使用 `ptoas` 直接吐出的 fatobj。 + +2. 脚本不再自己分别编译 device llvm、device obj、host stub 再手动打包,而是直接消费 `ptoas` 的 fatobj 输出结果。 + +3. 脚本中的 mixed / non-mixed、cube / vector、单独 `cube.pto` 等特判路径都需要移除,统一走同一种编译与链接模型。 + +### `test/vpto` case 组织 + +1. 每个 case 只保留一个 `kernel.pto`。 + +2. 原来 `cube.pto` 中的代码需要挪到 `kernel.pto` 里的 cube module 中。 + +3. `kernel.pto` 中允许同时包含 vector module 和 cube module,并通过 `pto.kernel_kind` 区分。 + +4. 测试数据生成、host stub、launch、compare 等配套文件继续按现有 case 目录组织保留,不在这次改造中改变其职责。 diff --git a/docs/designs/ptoas-tileop-expand-design.md b/docs/designs/ptoas-tileop-expand-design.md index 1ad4ad63d..9c5c49c57 100644 --- a/docs/designs/ptoas-tileop-expand-design.md +++ b/docs/designs/ptoas-tileop-expand-design.md @@ -669,8 +669,8 @@ Fold pass 处理两族 intrinsic,通过严格的模式匹配将它们解析回 ##### tile_buf 系列折叠 -每一个被折叠的 tile_buf intrinsic,其 `tile_buf` 操作数必须由如下固定链定义 -(由 `MemrefToTileBuf` pass 保证),否则 pass 直接报错并失败: +每一个被折叠的 tile_buf intrinsic,其 `tile_buf` 操作数必须能解析到调用点 +的 materialized tile handle,否则 pass 直接报错并失败: ```mlir %0 = pto.pointer_cast(%addr) {config = ...} @@ -882,7 +882,7 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): - Python DSL 模板编写和实例化的单元测试 以当前 `lib/TileOps/tadd_template.py` 为例,新增/维护 - `test/basic/expand_tile_op_tilelang.pto` + `test/lit/vpto/expand_tile_op_tilelang.pto` 作为 `pto.tadd` TileLang 模板实例化的基础回归。该用例覆盖: 1. `ExpandTileOp` 是否能匹配 `pto.tadd` 并调用 Python DSL helper; 2. 模板实例化后的 `func.call` 是否能被 inline; @@ -922,7 +922,7 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): // expands pto.tadd via the default TileLang Python DSL template // lib/TileOps/tadd_template.py. // - // Pipeline: MemrefToTileBuf -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics + // Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics // // RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s @@ -958,38 +958,23 @@ def template_xxx(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): } } ``` -- Expand TileOp pass 的端到端测试(`pto.tadd` → Vector IR) - 使用以下命令生成最终 LLVM IR,并继续交给 Bisheng 做设备侧编译校验: +- Expand TileOp pass 的端到端测试(`pto.tadd` → VPTO fatobj) + 使用以下命令生成最终 fatobj,并由 `ptoas` 内部完成 VPTO lowering、device 编译、stub 生成和打包: ```bash - ./build/tools/ptoas/ptoas test/basic/expand_tile_op_tilelang.pto \ + ./build/tools/ptoas/ptoas test/lit/vpto/expand_tile_op_tilelang.pto \ --pto-arch=a5 \ --pto-backend=vpto \ --enable-tile-op-expand \ - --vpto-emit-hivm-llvm \ - -o - \ - > add.ll - ``` - - 说明: - - `stdout` 中的最终产物是 textual LLVM IR,因此这里使用 `-o - > add.ll` 显式落盘。 - - 随后将生成的 `add.ll` 交给 Bisheng: - - ```bash - bisheng \ - --target=hiipu64-hisilicon-cce \ - -march=dav-c310-vec \ - --cce-aicore-arch=dav-c310-vec \ - --cce-aicore-only \ - -c -x ir add.ll \ -o add.o ``` - 若上述命令成功生成 `add.o`,则说明当前 `pto.tadd` 的向量库模板已经完成: + 说明: + - `add.o` 是 host 可链接的 fatobj 对象。 + - 若上述命令成功生成 `add.o`,则说明当前 `pto.tadd` 的向量库模板已经完成: - TileLang 模板实例化; - - `pto.tadd -> Vector IR -> LLVM IR` 的端到端 lowering; - - Bisheng 设备侧编译校验。 + - `pto.tadd -> VPTO -> LLVM` 的端到端 lowering; + - device 编译、stub 生成和 fatobj 打包。 - 融合场景测试(多个 Tile op 连续使用后的 VF Fusion) - 更新 `PTO_IR_manual.md` 和 TileLang DSL Guide @@ -1016,17 +1001,13 @@ run_st.py │ └─ 配置 simulator / NPU 运行环境 ├─ build_project() │ ├─ cmake -DRUN_MODE=... -DSOC_VERSION=... -DTEST_CASE=... -DPTOAS_BIN=... - │ ├─ ptoas: .pto -> _kernel.ll + │ ├─ ptoas: .pto -> _kernel.o │ │ flags: │ │ --pto-arch=a5 │ │ --pto-backend=vpto │ │ --enable-insert-sync │ │ --enable-tile-op-expand - │ │ --vpto-emit-hivm-llvm - │ ├─ bisheng -x ir: _kernel.ll -> _kernel_device.o - │ ├─ repack_tilelang_kernel.sh: - │ │ _kernel_device.o -> _kernel_repack.o - │ ├─ bisheng -xcce: launch.cpp + _kernel_repack.o -> lib_kernel.so + │ ├─ bisheng -xcce: launch.cpp + _kernel.o -> lib_kernel.so │ └─ bisheng -xc++: main.cpp -> ├─ run_gen_data() │ └─ 在 build/testcase// 下生成每个 case 的 input/golden @@ -1040,19 +1021,15 @@ run_st.py ```text .pto - ──ptoas──> _kernel.ll (LLVM IR) - ──bisheng -x ir──> _kernel_device.o (device-only 对象) - ──repack_tilelang_kernel.sh──> _kernel_repack.o - (host-linkable fatobj) - ──bisheng -xcce launch.cpp + repack.o──> lib_kernel.so + ──ptoas──> _kernel.o (host-linkable fatobj) + ──bisheng -xcce launch.cpp + _kernel.o──> lib_kernel.so (共享库) ──bisheng -xc++ main.cpp + .so──> (host 可执行文件) ``` -其中 repack 步骤是 TileLang ST 与 pto-isa ST 的核心区别:`ptoas + bisheng -x ir` 产出的 -`*_kernel_device.o` 是 device-only 对象,不能直接作为 host 侧链接输入。repack 脚本从 -`launch.cpp` 中抽取 kernel 声明生成 stub,通过 `-fcce-include-aibinary` 嵌入 device binary, -产出 host 可链接的 fatobj。 +`ptoas` 直接产出 host 可链接的 fatobj。TileLang ST 不再维护 `kernel.ll -> device.o -> repack` +这条旧中间链路,`launch.cpp` 只负责 host 侧 kernel 声明和 wrapper,最终由 `bisheng -xcce` +把 `launch.cpp` 与 fatobj 链接成共享库。 运行阶段同样是 ST 框架的一部分,而不是“编译完以后开发者手工处理”的额外步骤: @@ -1146,7 +1123,7 @@ void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream) { ``` 关键约束: -- `extern "C" __global__ AICORE void ...` 这一声明形态不可改变,repack 脚本用 sed 从中抽取 stub +- `extern "C" __global__ AICORE void ...` 声明需要和 `.pto` 中的 `pto.kernel` 函数签名对应 - kernel 参数类型和顺序必须与 `.pto` 中函数签名一致 - `<<<1, nullptr, stream>>>` 表示单核启动 @@ -1333,8 +1310,7 @@ build/testcase/tsub/ | 阶段 | 排查方向 | |---|---| | `ptoas` 编译失败 | 检查 `.pto` 语法、TileLang 模板是否匹配、是否缺少 `--enable-tile-op-expand` | -| `bisheng -x ir` 失败 | 检查 `build/testcase//_kernel.ll` 中的 LLVM IR | -| repack 失败 | 检查 `launch.cpp` 中的 kernel 声明是否符合 `extern "C" __global__ AICORE void` 格式 | +| fatobj 生成失败 | 检查 `ptoas` stderr、`.pto` 语义和 `pto.kernel` 函数签名 | | 链接失败 | 检查共享库符号名一致性、ACL 运行时依赖 | | kernel 执行失败 | 确认 `build/testcase///input*.bin` 是否已生成 | | compare fail | 先检查 `output.bin` vs `golden.bin` 差异,再检查 `.pto` 语义和参数顺序 | diff --git a/docs/designs/tilelang-cube-dsl-design.md b/docs/designs/tilelang-cube-dsl-design.md new file mode 100644 index 000000000..d01afa17d --- /dev/null +++ b/docs/designs/tilelang-cube-dsl-design.md @@ -0,0 +1,700 @@ +# TileLang Cube DSL Design + +> **状态:** 需求对齐完成,尚未实现 +> **范围:** Python 前端语法设计,不涉及后端 lowering 实现细节 + +--- + +## 1. 背景与动机 + +### 1.1 硬件背景 + +PTOAS 目标硬件包含两种独立的计算单元: + +| 单元 | 硬件核心 | IR kernel_kind | 编译宏 | 典型操作 | +|------|---------|----------------|--------|---------| +| **Vector** | AIV | `#pto.kernel_kind` | `__DAV_VEC__` | 向量加载/存储/ALU/谓词 | +| **Cube** | AIC | `#pto.kernel_kind` | `__DAV_CUBE__` | 矩阵乘法 (MAD)、分形数据搬运 | + +**关键约束:两种指令不能出现在同一个函数中。** 这是硬件限制,编译器验证器已在 IR 层强制执行(`verifyFrontendKernelKind` 检查),DSL 设计必须在 Python 语法层面体现这一分离。 + +### 1.2 当前状态 + +- **Vector DSL**:已有完整的 `@vkernel` 装饰器 + `pto.vecscope` / `pto.strict_vecscope` 作用域机制,支持 basic/advanced 两层 API 面 +- **Cube IR**:VPTO bridge 层指令(`pto.mte_gm_l1`、`pto.mad`、`pto.mte_l0c_l1` 等)已在 IR 层完整定义,有 lowering 和 LLVM 发射支持 +- **缺失环节**:没有对应的 Python DSL 前端,程序员无法用 Python 写出 Cube 指令 + +### 1.3 设计目标 + +1. 提供 `@ckernel` 装饰器,与 `@vkernel` 并列,从入口层面区分硬件单元 +2. 暴露完整的 VPTO bridge 层 Cube 操作(数据搬运 + 矩阵计算) +3. 支持模板槽位 `pto.tpl()` 机制,复用 Vector DSL 的设计模式 +4. 在 DSL 语义分析阶段就阻止 Cube/Vector 指令混用 + +### 1.4 设计原则 + +- **GM 数据用 TensorView / PartitionTensorView 表示**:Cube tileop 的 GM 输入数据通过 `TensorView`(逻辑张量视图)或 `PartitionTensorView`(分块视图)表达,不使用 `Tile` 表示 GM 数据 +- **Tile 用于特定地址空间的缓冲区**:`Tile` 类型表示在特定硬件地址空间(LEFT/RIGHT/ACC/MAT/BIAS)中分配的 tile buffer +- **VPTO bridge 层使用 ptr 表示**:Cube bridge 操作数使用 `pto.ptr` 原始指针,通过 `.as_ptr()` 从 Tile/TensorView 获取 +- **通过 `pto.Tile` 构造器分配带地址空间和布局配置的 tile buffer**:通过 `pto.Tile` 构造器分配带地址空间和布局配置的 tile buffer +- **本次不涉及同步操作**:只关注 Cube 指令本身的 DSL 暴露,同步由 `--enable-insert-sync` 自动插入 +- **参数顺序与 IR 保持一致**:避免心智负担 + +--- + +## 2. @ckernel 装饰器 + +### 2.1 基本语法 + +```python +from tilelang_dsl import ckernel, Tile, MemorySpace, select_kernel + +@ckernel( + op="pto.mad", # 单 op 名称 + dtypes=[(pto.f16, pto.f16, pto.f32)], # 支持的 dtype 组合 + name="my_matmul", # 模板名称 + # 以下为可选参数 + ops=["mad", "mad_acc", "mad_bias"], # 多 op 模板槽位 + templates={ # 槽位 → 具体 op 映射 + "compute": { + "mad": "mad", + "mad_acc": "mad_acc", + "mad_bias": "mad_bias", + } + }, +) +def kernel( + a_tv: PartitionTensorView, # GM 输入,通过 PartitionTensorView 表达 + b_tv: PartitionTensorView, + c_tv: PartitionTensorView, # GM 输出 + M: int, K: int, N: int, +): + ... +``` + +### 2.2 参数说明 + +| 参数 | 类型 | 必填 | 说明 | +|------|------|------|------| +| `op` | str | 与 `ops` 二选一 | 单 op 名称,如 `"pto.mad"` | +| `ops` | list[str] | 与 `op` 二选一 | 多 op 名称列表,启用模板槽位机制 | +| `dtypes` | list[tuple] | 是 | 支持的 dtype 组合,如 `[(f16, f16, f32)]` | +| `name` | str | 是 | 模板名称,用于注册和选择 | +| `templates` | dict | 否 | 模板槽位映射,将 `pto.tpl("slot", ...)` 映射到具体 op | +| `target` | str | 否 | 目标架构,默认 `"a5"` | + +### 2.3 函数参数类型约定 + +Cube 内核的参数类型反映它们在数据流中的角色: + +| 参数类型 | 用途 | 说明 | +|----------|------|------| +| `PartitionTensorView` | GM 上的分块输入/输出 | 由调用方从完整 `TensorView` 通过 `PartitionViewOp` 切出子块传入 | +| `TensorView` | GM 上的完整逻辑张量 | 用于无需分块的场景 | +| `Tile`(特定 addr space) | 已分配的硬件 tile buffer | 当调用方已经分配好 LEFT/RIGHT/ACC 等 tile 时传入 | +| `int` | 维度参数 | M, K, N 等矩阵维度 | +| `pto.f16` / `pto.f32` 等 | 标量参数 | 如 threshold、alpha 等 | +| `pto.ptr` | 原始指针 | 需要直接操作指针时(如 GM pointer) | + +### 2.4 与 @vkernel 的关键差异 + +| 特性 | @vkernel | @ckernel | +|------|----------|----------| +| 硬件单元 | Vector (AIV) | Cube (AIC) | +| 执行作用域 | `pto.vecscope` / `pto.strict_vecscope` | **无需作用域**,函数体直接是 Cube 线性代码 | +| GM 数据表示 | `TensorView` / `Tile` | `TensorView` / `PartitionTensorView` | +| 缓冲区 | Tile (UB/VEC) | Tile (MAT/LEFT/RIGHT/ACC/BIAS) | +| 操作数抽象 | Tile + VecScope 内的向量寄存器和 mask | `pto.ptr` 原始指针 | +| 核心操作 | 向量 ALU、加载/存储 | 数据搬运 + 矩阵乘法 (mad) | +| 生成 IR 属性 | `#pto.kernel_kind` | `#pto.kernel_kind` | + +--- + +## 3. Cube 编程模型 + +### 3.1 数据流 + +``` +PartitionTensorView (GM) + │ + ├──(cube_load)──> L1/cbuf (MAT) ──(left_load)──> L0A (LEFT) + │ │ + ├──(cube_load)──> L1/cbuf (MAT) ──(right_load)──> L0B (RIGHT) + │ │ + │ ┌────┘ + │ ▼ + │ ┌──────────┐ + │ │ pto.mad │ + │ └──────────┘ + │ │ + │ ▼ + │ L1/cbuf (MAT) <──(acc_store)── L0C (ACC) + │ │ │ + │ ├──(cube_store)──> UB (VEC) │ + │ ├──(acc_store_gm)──> GM <───────────┘ + │ └──(acc_store_ub)──> UB + │ + ▼ +PartitionTensorView (GM, 写回) +``` + +### 3.2 地址空间 + +| 地址空间 | 枚举值 | 说明 | 对应 IR 类型 | +|----------|--------|------|-------------| +| `GM` | `MemorySpace.GM` | 全局内存 | `!pto.ptr` | +| `MAT` | `MemorySpace.MAT` | L1 缓冲区 (cbuf) | `!pto.ptr` | +| `LEFT` | `MemorySpace.LEFT` | L0A 矩阵左乘数缓冲区 | `!pto.ptr` | +| `RIGHT` | `MemorySpace.RIGHT` | L0B 矩阵右乘数缓冲区 | `!pto.ptr` | +| `ACC` | `MemorySpace.ACC` | L0C 累加器缓冲区 | `!pto.ptr` | +| `BIAS` | `MemorySpace.BIAS` | Bias 表 | `!pto.ptr` | +| `UB` | `MemorySpace.UB` | 统一缓冲区 (Vector 侧) | `!pto.ptr` | + +### 3.3 缓冲区分配接口 + +#### `pto.Tile` 构造器 + +```python +pto.Tile( + shape: tuple[int, ...], # 缓冲区形状 (必填) + dtype: pto dtype, # 元素类型 (必填) + memory_space: MemorySpace, # 地址空间 (必填) + valid_shape: tuple[int, ...] | None = None, # 有效区域,默认等于 shape + blayout: BLayout | None = None, # B 布局,默认按地址空间自动选择 + slayout: SLayout | None = None, # S 布局,默认按地址空间自动选择 + fractal_size: int | None = None, # 分形大小,默认按地址空间自动选择 + pad_value: PadValue = PadValue.Null, # 填充策略 + compact_mode: CompactMode = CompactMode.Null, # 压缩模式 + addr: int | None = None, # 预分配地址(level3 使用) +) -> Tile +``` + +**布局配置默认值(按地址空间):** + +| 地址空间 | blayout | slayout | fractal_size | +|----------|---------|---------|-------------| +| `MAT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | +| `LEFT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | +| `RIGHT` | `RowMajor` | `ColMajor` | `TileConfig.fractalABSize` (512) | +| `ACC` | `ColMajor` | `RowMajor` | `TileConfig.fractalCSize` (1024) | +| `BIAS` | `RowMajor` | `NoneBox` | `TileConfig.fractalABSize` (512) | + +**枚举值定义:** + +| 枚举类型 | 可选值 | +|----------|--------| +| `BLayout` | `ColMajor` (0), `RowMajor` (1) | +| `SLayout` | `NoneBox` (0), `RowMajor` (1), `ColMajor` (2) | +| `PadValue` | `Null` (0), `Zero` (1), `Max` (2), `Min` (3) | +| `CompactMode` | `Null` (0), `Normal` (1), `RowPlusOne` (2) | + +#### `.as_ptr()` + +从 Tile 或 TensorView/PartitionTensorView 获取原始指针(方法调用): + +```python +# 从 Tile 获取指针(地址空间由 Tile 的类型决定) +l0a_ptr = l0a_tile.as_ptr() # Tile[LEFT] → pto.ptr + +# 从 TensorView / PartitionTensorView 获取 GM 指针 +gm_ptr = tensor_view.as_ptr() # TensorView → pto.ptr +a_ptr = a_tv.as_ptr() # PartitionTensorView → pto.ptr +``` + +### 3.4 指针偏移 + +子矩阵寻址通过 `pto.addptr` 实现,偏移量以元素为单位: + +```python +a_k = pto.addptr(a_ptr, k_off) # 偏移 k_off 个元素 +``` + +不引入 tile slice 语法糖,保持与 VPTO 层的 ptr 抽象一致。 + +### 3.5 典型编程模式 + +```python +@ckernel(op="pto.mad", dtypes=[(pto.f16, pto.f16, pto.f32)], name="gemm") +def gemm(a_tv: PartitionTensorView, # GM 输入 A [M, K] + b_tv: PartitionTensorView, # GM 输入 B [K, N] + c_tv: PartitionTensorView, # GM 输出 C [M, N] + M: int, K: int, N: int): + # 1. 从 PartitionTensorView 获取 GM 指针 + a_ptr = a_tv.as_ptr() # -> pto.ptr + b_ptr = b_tv.as_ptr() # -> pto.ptr + c_ptr = c_tv.as_ptr() # -> pto.ptr + + # 2. 分配 L1 (MAT) tile buffer 并获取指针 + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + + # 3. 分配 L0 tile buffer 并获取指针 + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # 4. GM -> L1 数据搬运 + pto.mte_gm_l1(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.mte_gm_l1(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + + # 5. L1 -> L0 数据搬运 + pto.mte_l1_l0a(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.mte_l1_l0b(l1_b.as_ptr(), l0b.as_ptr(), K, N) + + # 6. 矩阵乘法 + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + # 7. L0C -> GM 结果写回 + pto.mte_l0c_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, + mode="nz2nd") +``` + +--- + +## 4. Cube 操作 API 面 + +以下为 `@ckernel` 函数体内支持的 `pto.*` 调用。所有操作数使用 `pto.ptr` 指针类型。 + +### 4.1 矩阵计算操作 + +#### `pto.mad` — 零初始化矩阵乘法 + +```python +pto.mad(lhs: pto.ptr, rhs: pto.ptr, dst: pto.ptr, + m: int, n: int, k: int, + unit_flag_ctrl: int = 0, disable_gemv: bool = False) +``` + +语义:`dst = lhs * rhs`(零初始化累加器后计算) + +#### `pto.mad_acc` — 累加矩阵乘法 + +```python +pto.mad_acc(lhs: pto.ptr, rhs: pto.ptr, dst: pto.ptr, + m: int, n: int, k: int, + unit_flag_ctrl: int = 0, disable_gemv: bool = False) +``` + +语义:`dst += lhs * rhs` + +#### `pto.mad_bias` — 带 Bias 矩阵乘法 + +```python +pto.mad_bias(lhs: pto.ptr, rhs: pto.ptr, dst: pto.ptr, + bias: pto.ptr, + m: int, n: int, k: int, + unit_flag_ctrl: int = 0, disable_gemv: bool = False) +``` + +语义:`dst = lhs * rhs + bias` + +#### `pto.mad_mx` / `pto.mad_mx_acc` / `pto.mad_mx_bias` + +MX micro-scaling 变体,参数与对应非 MX 版本相同,用于 `f8` 等 MX 数据类型。 + +### 4.2 数据搬运操作 + +#### `pto.mte_gm_l1` — GM → L1 (cbuf) + +```python +pto.mte_gm_l1(src: pto.ptr, dst: pto.ptr, + len_burst: int, + nburst: tuple[int, int, int] = (1, 0, 0), + loops: list[tuple[int, int, int]] | None = None) +``` + +#### `pto.mte_l1_ub` — L1 (cbuf) → UB + +```python +pto.mte_l1_ub(src: pto.ptr, dst: pto.ptr, + len_burst: int, + nburst: tuple[int, int, int] = (1, 0, 0), + loops: list[tuple[int, int, int]] | None = None) +``` + +#### `pto.mte_gm_l1_frac` — 分形加载 (nd2nz / dn2nz) + +```python +pto.mte_gm_l1_frac(src: pto.ptr, dst: pto.ptr, + mode: str, # "nd2nz" | "dn2nz" + shape: tuple[int, int], # (n_value, d_value) + src_layout: tuple[int, int], # (inner_stride, outer_stride) + dst_group: tuple[int, int, int, int], # (count, l2s, l3s, l4s) + ctrl: tuple[int, bool]) # (l2_cache_ctrl, smallc0_en) +``` + +#### `pto.mte_l1_bt` — L1 (cbuf) → Bias 表 + +```python +pto.mte_l1_bt(src: pto.ptr, dst: pto.ptr, + len_burst: int, + nburst: tuple[int, int, int] = (1, 0, 0)) +``` + +#### `pto.mte_l1_l0a` — L1 (cbuf) → L0A + +```python +pto.mte_l1_l0a(src: pto.ptr, dst: pto.ptr, + m: int, k: int) +``` + +#### `pto.mte_l1_l0b` — L1 (cbuf) → L0B + +```python +pto.mte_l1_l0b(src: pto.ptr, dst: pto.ptr, + k: int, n: int) +``` + +#### `pto.mte_l1_l0a_mx` / `pto.mte_l1_l0b_mx` + +MX 模式 L1→L0A/L0B 搬运,参数同非 MX 版本。 + +### 4.3 结果写回操作 + +#### `pto.mte_l0c_l1` — L0C (acc) → L1 (cbuf) + +```python +pto.mte_l0c_l1(src: pto.ptr, dst: pto.ptr, + m: int, n: int, + src_stride: int, dst_stride: int, + mode: str = "nz2nd", # "nz2nd" | "nz2dn" | "nz2nz" + loop0_src_stride: int | None = None, # mode="nz2dn" 时需要 + split: int | None = None, # mode="nz2nz" 时需要 + loop3: tuple[int, int, int] | None = None) +``` + +#### `pto.mte_l0c_gm` — L0C (acc) → GM + +```python +pto.mte_l0c_gm(src: pto.ptr, dst: pto.ptr, + m: int, n: int, + src_stride: int, dst_stride: int, + sid: int = 0, l2_cache_ctrl: int = 0, + mode: str = "nz2nd", + loop0_src_stride: int | None = None, + split: int | None = None, + loop3: tuple[int, int, int] | None = None) +``` + +#### `pto.mte_l0c_ub` — L0C (acc) → UB + +```python +pto.mte_l0c_ub(src: pto.ptr, dst: pto.ptr, + m: int, n: int, + src_stride: int, dst_stride: int, + dual_dst_mode: int = 0, sub_blockid: int = 0, + mode: str = "nz2nd", + loop0_src_stride: int | None = None, + channel_split_en: int | None = None, # mode="nz2nz" 时需要 + loop3: tuple[int, int, int] | None = None) +``` + +--- + +## 5. 模板槽位机制 + +### 5.1 设计 + +复用 Vector DSL 的 `pto.tpl()` 机制,允许一个 Cube kernel 模板适配多种 mad 操作变体。 + +### 5.2 语法 + +```python +@ckernel( + ops=["mad", "mad_acc"], + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_template", + templates={ + "compute": {"mad": "mad", "mad_acc": "mad_acc"}, + }, +) +def gemm_template(a_tv: PartitionTensorView, b_tv: PartitionTensorView, + c_tv: PartitionTensorView, M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + pto.mte_gm_l1(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.mte_gm_l1(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.mte_l1_l0a(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.mte_l1_l0b(l1_b.as_ptr(), l0b.as_ptr(), K, N) + + # 模板槽位:根据 selected_op 自动替换为 mad 或 mad_acc + pto.tpl("compute", l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + pto.mte_l0c_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +使用方式: + +```python +k_mad = select_kernel("a5", "gemm_template", selected_op="mad") +k_acc = select_kernel("a5", "gemm_template", selected_op="mad_acc") +``` + +### 5.3 约束 + +模板槽位中的变体必须参数签名一致: + +| 槽位组 | 成员 | 参数 | +|--------|------|------| +| `compute` | `mad`, `mad_acc` | `(lhs, rhs, dst, m, n, k)` | +| `compute_bias` | `mad_bias` | `(lhs, rhs, dst, bias, m, n, k)` | +| `compute_mx` | `mad_mx`, `mad_mx_acc` | `(lhs, rhs, dst, m, n, k)` | + +参数不一致的变体(如 mad vs mad_bias)不能放在同一个槽位中。 + +--- + +## 6. 硬件分离规则 + +### 6.1 函数级别隔离 + +- `@ckernel` 生成的函数带有 `#pto.kernel_kind` 属性 +- `@vkernel` 生成的函数带有 `#pto.kernel_kind` 属性 +- 验证器在 IR 层阻止两种指令出现在同一函数中 + +### 6.2 DSL 层面强制 + +在语义分析阶段: + +1. `@ckernel` 函数体内不允许出现 Vector 专有操作(`vlds`、`vadd` 等) +2. `@ckernel` 函数体内不允许出现 `pto.vecscope` / `pto.strict_vecscope` +3. CKernel 不能调用 VKernel 的 inline_proc,反之亦然 + +### 6.3 模块级别 + +- 同一个 `.py` 文件中可以同时定义 `@ckernel` 和 `@vkernel` +- 每个函数独立编译,由 EmitC 后端通过 `__DAV_CUBE__` / `__DAV_VEC__` 宏守卫条件编译 + +--- + +## 7. 与 Vector DSL 的共享基础设施 + +| 设施 | 说明 | +|------|------| +| `TensorView` / `PartitionTensorView` | GM 数据的高层视图,两者通用 | +| `Tile` 类型 | 缓冲区类型标注,通过 `MemorySpace` 区分地址空间 | +| `select_kernel()` / `KernelRegistry` | Kernel 注册和选择 | +| `MaterializedMLIRModule` | 具体化后的 MLIR 模块 | +| `pto.ptr` / `pto.castptr` / `pto.addptr` | 指针操作 | +| `MemorySpace` | 地址空间枚举(已含 MAT/LEFT/RIGHT/ACC/BIAS) | +| `Tile` 构造器 | 缓冲区分配(通过 `pto.Tile()` 构造) | +| `TileConfig` | 分形大小等常量 | + +--- + +## 8. 完整示例 + +### 8.1 基础 GEMM + +```python +from tilelang_dsl import ckernel, Tile, MemorySpace + +@ckernel( + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm", +) +def gemm(a_tv: PartitionTensorView, # [M, K] in GM + b_tv: PartitionTensorView, # [K, N] in GM + c_tv: PartitionTensorView, # [M, N] in GM, output + M: int, K: int, N: int): + # Get GM pointers from PartitionTensorViews + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + # Allocate tiles in respective address spaces + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # Data movement + pto.mte_gm_l1(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.mte_gm_l1(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.mte_l1_l0a(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.mte_l1_l0b(l1_b.as_ptr(), l0b.as_ptr(), K, N) + + # Compute + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + # Writeback + pto.mte_l0c_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### 8.2 Split-K GEMM + +```python +@ckernel( + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_splitk", +) +def gemm_splitk(a_tv: PartitionTensorView, # [M, K] + b_tv: PartitionTensorView, # [K, N] + c_tv: PartitionTensorView, # [M, N] + M: int, K: int, N: int, BASEK: int): + iters = K // BASEK + + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + l1_a = pto.Tile([M, BASEK], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([BASEK, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, BASEK], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([BASEK, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + for k_step in range(iters): + k_off = k_step * BASEK + a_k = pto.addptr(a_ptr, k_off) + b_k = pto.addptr(b_ptr, k_off) + + pto.mte_gm_l1(a_k, l1_a.as_ptr(), BASEK, nburst=(1, 0, 0)) + pto.mte_gm_l1(b_k, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.mte_l1_l0a(l1_a.as_ptr(), l0a.as_ptr(), M, BASEK) + pto.mte_l1_l0b(l1_b.as_ptr(), l0b.as_ptr(), BASEK, N) + + if k_step == 0: + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) + else: + pto.mad_acc(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) + + pto.mte_l0c_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### 8.3 带 Bias 的矩阵乘法 + +```python +@ckernel( + op="pto.mad_bias", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_bias", +) +def gemm_bias(a_tv: PartitionTensorView, b_tv: PartitionTensorView, + c_tv: PartitionTensorView, bias_tv: PartitionTensorView, + M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + bias_ptr = bias_tv.as_ptr() + + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l1_bias = pto.Tile([1, N], pto.f32, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + bt = pto.Tile([1, N], pto.f32, MemorySpace.BIAS) + + pto.mte_gm_l1(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.mte_gm_l1(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.mte_gm_l1(bias_ptr, l1_bias.as_ptr(), N, nburst=(1, 0, 0)) + pto.mte_l1_bt(l1_bias.as_ptr(), bt.as_ptr(), N, nburst=(1, 0, 0)) + + pto.mte_l1_l0a(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.mte_l1_l0b(l1_b.as_ptr(), l0b.as_ptr(), K, N) + pto.mad_bias(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), bt.as_ptr(), M, N, K) + + pto.mte_l0c_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### 8.4 分形加载 (nd2nz) 示例 + +```python +@ckernel( + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_frac", +) +def gemm_frac(a_tv: PartitionTensorView, b_tv: PartitionTensorView, + c_tv: PartitionTensorView, M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + pto.mte_gm_l1_frac(a_ptr, l1_a.as_ptr(), "nd2nz", + shape=(M, K), + src_layout=(K,), + dst_group=(1, 0, 0, 0), + ctrl=(0, False)) + pto.mte_gm_l1(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + + pto.mte_l1_l0a(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.mte_l1_l0b(l1_b.as_ptr(), l0b.as_ptr(), K, N) + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + pto.mte_l0c_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +--- + +## 9. Lowering 流程 + +### 9.1 与 Vector DSL 的对比 + +| 阶段 | Vector DSL | Cube DSL | +|------|-----------|----------| +| AST 解析 | `frontend_ast.py` → `FrontendKernelNode` | 增加 `FrontendCKernelNode` | +| 语义分析 | `semantic.py` → `SemanticKernel`(含 vecscope 分析) | 增加 Cube 语义分析(无 vecscope,线性 IR) | +| MLIR 发射 | `lowering.py` → MLIR 文本(含 `vecscope` 块) | 增加 Cube lowering(直接发射线性 VPTO IR) | +| IR 属性 | `#pto.kernel_kind` | `#pto.kernel_kind` | +| 目标 march | `dav-c310-vec` | `dav-c310-cube` | + +### 9.2 Cube 特有问题 + +1. **无 vecscope 作用域**:Cube 函数体直接是线性 IR 序列 +2. **地址空间验证**:每个 Cube op 对操作数的地址空间有严格要求 +3. **ptr 管理**:`.as_ptr()` 从 Tile/TensorView 取地址、`pto.addptr` 指针偏移需要在语义阶段正确处理 +4. **Tile 构造器配置**:`pto.Tile()` 按地址空间自动推导布局默认值 + +--- + +## 10. 分阶段实施建议 + +### Phase 1:最小可用面 (MVP) + +- `@ckernel` 装饰器 +- `pto.Tile` 构造器 + `.as_ptr()` 缓冲区分配和指针获取 +- `pto.mad` / `pto.mad_acc` / `pto.mad_bias` +- `pto.mte_gm_l1` / `pto.mte_l1_ub` +- `pto.mte_l1_l0a` / `pto.mte_l1_l0b` +- `pto.mte_l0c_gm` +- 模板槽位 `pto.tpl()` 基本支持 + +### Phase 2:完整 bridge 面 + +- `pto.mad_mx` / `pto.mad_mx_acc` / `pto.mad_mx_bias` +- `pto.mte_gm_l1_frac` +- `pto.mte_l1_bt` +- `pto.mte_l1_l0a_mx` / `pto.mte_l1_l0b_mx` +- `pto.mte_l0c_l1` / `pto.mte_l0c_ub` +- `pto.addptr` 指针偏移 + +### Phase 3:高级特性 + +- Split-K 循环语法糖 +- 分形参数自动推导 +- Tile 构造器布局全自动推断 diff --git a/docs/designs/tilelang-st-framework.md b/docs/designs/tilelang-st-framework.md new file mode 100644 index 000000000..60a48bb05 --- /dev/null +++ b/docs/designs/tilelang-st-framework.md @@ -0,0 +1,585 @@ +# TileLang ST 精度验证框架 + +## 1. 文档目标 + +本文从 TileLang 库开发者的视角介绍当前 `test/tilelang_st` 框架的使用方式。 + +这份框架的目标不是做单纯的 IR 回归,而是回答下面两个更贴近开发的问题: + +1. 我新写的 TileLang 模板库实现,展开到 PTO / VPTO / LLVM IR 之后,最终在 simulator 或 NPU 上跑出来的数值是否正确。 +2. 如果我要为一个新 op 增加 ST 用例,最少需要准备哪些文件,运行链路会经过哪些阶段。 + +当前框架已经具备下面这些能力: + +- 从 `.pto` 直接驱动 `ptoas`,不需要手写 `kernel.cpp` 或中间 `.ll` +- 支持在一个 testcase 下放多个 case +- 支持 `sim` / `npu` 两种运行模式 +- 支持单 case 过滤 +- 支持 `src` / `dst` 逻辑 shape 不一致的 testcase(例如 `trowsum` 这类 reduction) +- 支持把输入、golden、output 隔离到 `build/testcase//` 下,避免不同 testcase 之间互相覆盖 + +## 2. 框架定位 + +TileLang ST 参考了 `pto-isa` 的 ST 目录组织方式,但编译链路不同。 + +| 维度 | pto-isa ST | TileLang ST | +|---|---|---| +| kernel 来源 | 手写 `kernel.cpp` | 手写 `.pto`,由 `ptoas` 展开 TileLang DSL 模板 | +| 编译入口 | `bisheng -xcce kernel.cpp` | `ptoas .pto -> fatobj` | +| device 对象接入 host | 编译器一步直接生成 fatobj | `ptoas` 直接生成 host-linkable fatobj | +| 精度比较 | GTest / C++ 比较逻辑 | `compare.py` + `numpy.allclose` | +| 多 case 组织 | 多个 GTest case | 一个 testcase 下多个 kernel 函数 + host case table | + +换句话说,TileLang ST 更适合验证“库模板展开后的端到端运行正确性”,而不是验证某一段单独的 CCE kernel.cpp。 + +## 3. 当前执行流程 + +统一入口是: + +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd +``` + +cube 类 kernel 也可以直接走同一入口,例如: + +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tmatmul +``` + +完整链路如下: + +```text +run_st.py + ├─ set_env_variables() + │ └─ 配置 simulator / NPU 运行环境 + ├─ build_project() + │ ├─ cmake -DRUN_MODE=... -DSOC_VERSION=... -DTEST_CASE=... -DPTOAS_BIN=... + │ ├─ ptoas: .pto -> _kernel.o + │ │ flags: + │ │ --pto-arch=a5 + │ │ --pto-backend=vpto + │ │ --enable-insert-sync + │ │ --enable-tile-op-expand + │ ├─ bisheng -xcce: launch.cpp + _kernel.o -> lib_kernel.so + │ └─ bisheng -xc++: main.cpp -> + ├─ run_gen_data() + │ └─ 在 build/testcase// 下生成每个 case 的 input/golden + ├─ run_binary() + │ └─ 在 build/testcase// 下执行 ../../bin/ [case] + └─ run_compare() + └─ 在 build/testcase// 下逐 case 比较 golden/output +``` + +### 3.1 关于 fatobj 直连 + +TileLang ST 现在不再经过 `kernel.ll -> device.o -> repack` 的中间路径。 + +`ptoas` 直接输出 host 可链接的 fatobj 对象,`launch.cpp` 只负责提供 host 侧的 kernel 声明和 wrapper,然后由 `bisheng -xcce` 直接把 `launch.cpp` 和 fatobj 链接成 `lib_kernel.so`。 + +如果 fatobj 输出没有生成成功,后续 host 链接自然也不会成功,因此排查时优先看 `ptoas` 的输出是否完整。 + +### 3.2 关于 case 的执行和比较顺序 + +默认情况下: + +1. `gen_data.py` 会先为 testcase 下的所有 case 生成输入和 golden +2. `./bin/` 会依次跑完所有 case +3. `compare.py` 再依次比较所有 case 的 `golden.bin` 和 `output.bin` + +如果使用 `-c `,则运行和比较都会只针对这个 case。 + +## 4. 目录结构与职责 + +当前目录结构如下: + +```text +test/tilelang_st/ + ├── script/ +│ ├── run_st.py +│ ├── run_all_st.py +│ └── run_ci.sh +└── npu/ + └── a5/ + └── src/st/ + ├── CMakeLists.txt + └── testcase/ + ├── CMakeLists.txt + ├── run_ptoas_to_file.cmake + ├── st_common.py + └── tadd/ + ├── CMakeLists.txt + ├── cases.py + ├── tadd.pto + ├── launch.cpp + ├── main.cpp + ├── gen_data.py + └── compare.py +``` + +各文件职责如下: + +| 文件 | 职责 | +|---|---| +| `script/run_st.py` | 统一入口,负责编译、生成数据、执行二进制、比较结果 | +| `script/run_all_st.py` | 汇总执行所有 testcase 的入口 | +| `script/run_ci.sh` | CI 入口包装 | +| `src/st/CMakeLists.txt` | 顶层 CMake,设置编译器、环境和依赖 | +| `testcase/CMakeLists.txt` | 定义 `pto_tilelang_vec_st()` 宏,并注册所有 testcase | +| `testcase/run_ptoas_to_file.cmake` | 封装 `ptoas` 调用,把 `.pto` 编译成 fatobj | +| `testcase/st_common.py` | 所有 testcase 共享的 Python 公共模块(case 校验、数据生成辅助、`result_cmp`、终端着色) | +| `testcase//cases.py` | **case 定义的单一来源**,`gen_data.py` 和 `compare.py` 均从此导入;默认使用 `shape`/`valid_shape`,像 `trowsum` 这类输出 shape 不同的 op 再额外补 `dst_shape`/`dst_valid_shape` | +| `testcase//.pto` | testcase 的 kernel 描述,通常一个文件中放多个 case 对应的函数 | +| `testcase//launch.cpp` | kernel 声明和 launch wrapper | +| `testcase//main.cpp` | host driver,负责分配内存、launch kernel、回写 output(`ACL_CHECK` 宏由公共头 `test_common.h` 提供) | +| `testcase//gen_data.py` | 生成 input 与 golden,从 `cases.py` 读取 case 列表 | +| `testcase//compare.py` | 每个 testcase 自己的比较脚本,决定读取哪些 bin、reshape 成什么形状、裁哪一块数据,再调用公共 `result_cmp()` | + +## 5. 日常使用方式 + +### 5.0 前置条件 + +运行 TileLang ST 之前,建议先确认下面几件事: + +- 仓库里的 `ptoas` 已经编出来,默认路径是 `build/tools/ptoas/ptoas` +- `ASCEND_HOME_PATH` 已经设置正确 +- 如果需要手工跑 `ptoas`、`bisheng` 或 lit,优先先执行: + +```bash +source scripts/ptoas_env.sh +``` + +`run_st.py` 会在运行时补充 simulator / NPU 相关环境,但它不会替你构建 `ptoas`。 + +### 5.1 运行已有 testcase + +```bash +# simulator 上跑 tadd 全部 case +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd + +# NPU 上跑 tadd 全部 case +python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd + +# 只跑一个 case +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 + +# 复用已有 build 目录,不重新编译 +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -w +``` + +### 5.2 常用参数 + +| 参数 | 含义 | +|---|---| +| `-r, --run-mode` | 运行模式,`sim` 或 `npu` | +| `-v, --soc-version` | SoC 版本,目前只支持 `a5` | +| `-t, --testcase` | testcase 名称,对应 `testcase//` | +| `-c, --case` | 只运行一个 case | +| `-p, --ptoas-bin` | 指定 `ptoas` 路径 | +| `-w, --without-build` | 跳过构建,直接复用已有 `build/` | + +### 5.3 产物在哪 + +testcase 的运行时数据不再写到 `build/` 根目录,而是写到: + +```text +test/tilelang_st/npu/a5/src/st/build/testcase// +``` + +以 `tadd` 为例: + +```text +build/testcase/tadd/ +├── gen_data.py +├── compare.py +├── f32_16x64/ +│ ├── input1.bin +│ ├── input2.bin +│ ├── golden.bin +│ └── output.bin +└── f32_32x32/ + ├── input1.bin + ├── input2.bin + ├── golden.bin + └── output.bin +``` + +这个布局的好处是: + +- 不同 testcase 之间不会因为 case 同名而互相覆盖 +- 方便开发者直接进入 `build/testcase//` 复查输入、输出和 golden +- 使用 `-w` 时,不容易把旧 testcase 的残留数据误认为当前结果 + +### 5.4 比较输出 + +`compare.py` 会对 pass/fail 做明显提示: + +- pass:粗体绿色 +- fail:粗体红色 + +比较逻辑目前使用 `numpy.allclose`。建议阈值: + +| dtype | 建议 eps | +|---|---| +| `float32` | `1e-6` | +| `float16` | `1e-3` | +| `bfloat16` | `1e-2` | +| `int8/int16/int32` | `0` | + +## 6. 作为库开发者,如何增加一个新 op testcase + +这一节回答“我开发了一个新的 TileLang 库实现,怎么用 ST 框架验证它”。 + +以新增 `pto.tsub` 为例,最少需要准备下面这些文件: + +| 文件 | 是否新增/修改 | 说明 | +|---|---|---| +| `testcase/tsub/CMakeLists.txt` | 新增 | 一般只有一行 `pto_tilelang_vec_st(tsub)` | +| `testcase/tsub/cases.py` | 新增 | **case 定义的单一来源**:每个 case 必须指定 `name`/`dtype`/`shape`/`valid_shape`/`eps`;如果输出 shape 不同,再额外补 `dst_shape`/`dst_valid_shape` | +| `testcase/tsub/tsub.pto` | 新增 | 定义一个或多个 case 的 kernel 函数 | +| `testcase/tsub/launch.cpp` | 新增 | 为每个 kernel 函数声明 entry 并提供 launch wrapper | +| `testcase/tsub/main.cpp` | 新增 | host driver,负责 case table、内存拷贝、launch 和 output 落盘 | +| `testcase/tsub/gen_data.py` | 新增 | 生成每个 case 的输入和 golden,从 `cases.py` 导入 `CASES` | +| `testcase/tsub/compare.py` | 新增 | testcase 自己决定比较哪些输出数据,再调用公共 `result_cmp()` | +| `testcase/CMakeLists.txt` | 修改 | 把 `tsub` 加入 `ALL_TESTCASES` | + +通常不需要修改: + +- `script/run_st.py` +- `src/st/CMakeLists.txt` +- `testcase/st_common.py` +- `testcase/run_ptoas_to_file.cmake` +- `testcase` 目录下的旧 `.ll` / `device.o` / `repack` 产物 + +除非你在改框架本身,而不是新增一个 testcase。 + +## 7. 以 `pto.tadd` 为例,需要改哪些文件 + +当前仓库里 `tadd` 已经是一个完整样例。把它当成模板即可。 + +### 7.1 `testcase/tadd/CMakeLists.txt` + +这个文件通常最简单: + +```cmake +pto_tilelang_vec_st(tadd) +``` + +含义是让公共宏接管 `tadd.pto -> tadd_kernel.o -> libtadd_kernel.so -> tadd` 这一整条流水线。 + +### 7.2 `testcase/tadd/tadd.pto` + +这是最核心的文件。你需要在这里写出要验证的 kernel 形态。 + +当前 `tadd.pto` 的特点是: + +- 一个文件中包含多个 case +- 每个 case 对应一个 `func.func @TADD__x(...)` +- 函数体里显式写出 `make_tensor_view`、`partition_view`、`alloc_tile`、`tload`、`pto.tadd`、`tstore` + +如果你在开发 `pto.tadd` 库实现,最关键的是先把你要覆盖的 case 设计好。例如: + +- `f32` / `f16` / `bf16` +- 不同 tile 形状 +- 边界 valid 行列不是整 tile 的情况 + +这里的函数命名建议统一成: + +```text +TADD__x +``` + +例如: + +```text +TADD_f32_16x64 +TADD_f32_32x32 +``` + +### 7.3 `testcase/tadd/launch.cpp` + +这个文件的职责只有两个: + +1. 声明 kernel entry +2. 为 host driver 提供 `Launch*` wrapper + +当前推荐写法和 `tadd` 一致: + +```cpp +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { + TADD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} +``` + +注意点: + +- `launch.cpp` 不需要包含 PTO 头文件 +- `AICORE` 直接本地定义为 `[aicore]` +- 这里的 kernel 声明必须和 `tadd.pto` 中的 `pto.kernel` 函数签名对应,供 `bisheng -xcce` 直接链接 fatobj +- kernel 参数顺序必须和 `.pto` 中函数签名保持一致 + +### 7.4 `testcase/tadd/main.cpp` + +这个文件负责 host 侧调度。 + +你需要做的事主要有三类: + +1. 声明所有 `LaunchTADD_*` wrapper +2. 在 `kCases[]` 中列出每个 case 的名字、launch 函数、输入/输出 shape、valid shape、元素大小 +3. 在 `RunCase()` 中完成: + - 从 `.//input*.bin` 读取输入 + - `aclrtMemcpy` 把输入拷到 device + - 调用 `tc.launch(...)` + - `aclrtSynchronizeStream` + - 把输出拷回 host + - 写 `.//output.bin` + +当前 `tadd/main.cpp` 的 case table 形式如下: + +```cpp +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +``` + +注意:`ACL_CHECK` 宏已移至公共头文件 `test_common.h`(需在 `acl/acl.h` 之后包含),不需要在每个 testcase 的 `main.cpp` 中重复定义。 + +你在新增 case 时,必须同步更新这个表。 + +- 对 `tadd` 这类同 shape op,字段需与 `cases.py` 的 `shape` / `valid_shape` 保持一致。 +- 对 `trowsum` 这类输出 shape 不同的 op,host 侧需要把输入大小和输出大小分开计算。 + +### 7.5 `testcase/tadd/cases.py` + +这是 case 定义的**单一来源**,`gen_data.py` 和 `compare.py` 均从此导入 `CASES`。 + +每个 case 必须包含以下字段: + +```python +"name" +"dtype" +"shape" +"valid_shape" +"eps" +``` + +```python +CASES = [ + { + "name": "f32_16x64", # case 标识,对应运行时子目录和 main.cpp kCases[] 中的 name + "dtype": np.float32, # numpy dtype + "shape": (16, 64), # 分配的 tile 维度 (rows, cols) + "valid_shape": (16, 64), # 有效计算区域 (valid_rows, valid_cols) + "eps": 1e-6, # numpy.allclose 容差 + }, +] +``` + +`valid_shape` 为必填字段。当 valid shape 等于 tile shape 时也必须显式写出。 + +如果输出 shape 不同,可以额外补下面两个字段: + +```python +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), # 输入 tensor shape + "valid_shape": (16, 64), # 输入有效区域 + "dst_shape": (16, 1), # 输出 tensor shape(GM 可见形状) + "dst_valid_shape": (16, 1), # 输出有效区域 + "eps": 1e-5, + }, +] +``` + +这也是 `trowsum` 推荐使用的写法。注意 `dst_shape` 描述的是写回 GM 后的实际结果形状,而不是片上 tile 的物理展开形状。 + +### 7.6 `testcase/tadd/gen_data.py` + +这个文件负责为每个 case 生成输入和 golden。从 `cases.py` 导入 `CASES`, +从 `st_common.py` 导入辅助函数(`setup_case_rng`、`save_case_data`)。 + +以 `pto.tadd` 为例,每个 case 的核心逻辑: + +```python +golden = np.zeros(shape, dtype=dtype) +vr, vc = case["valid_shape"] +golden[:vr, :vc] = (input1[:vr, :vc] + input2[:vr, :vc]).astype(dtype, copy=False) +``` + +golden 只在 `valid_shape` 区域内计算,区域外保持零值。 + +如果是 `trowsum` 这类输出 shape 不同的 op,则 `gen_data.py` 应该按 `dst_shape` 生成 `golden`,按 `valid_shape` 完成规约计算。例如: + +```python +shape = case["shape"] +valid_shape = case["valid_shape"] +dst_shape = case["dst_shape"] +dst_valid_shape = case["dst_valid_shape"] +input1 = np.random.randint(1, 10, size=shape).astype(dtype) +golden = np.zeros(dst_shape, dtype=dtype) +golden[:dst_valid_shape[0], 0] = np.sum( + input1[:valid_shape[0], :valid_shape[1]], axis=1 +).astype(dtype, copy=False)[:dst_valid_shape[0]] +``` + +比较阶段也会按 `dst_shape` / `dst_valid_shape` 读取和 reshape `golden.bin`、`output.bin`。 + +每个 case 使用独立的随机 seed(`setup_case_rng` 基于 `hash(case["name"])`), +新增或调整 case 顺序不会影响已有 case 的测试数据。 + +### 7.7 `testcase//compare.py` + +比较脚本不再放在公共目录,而是每个 testcase 自己维护一份。 + +这样做的目的很直接: + +- 公共层只提供 `result_cmp(golden, output, eps)` 这种“比已经准备好的数据”的接口 +- 具体读取哪些 bin、reshape 成什么形状、裁哪一块 valid 区域,由 testcase 自己决定 + +以 `tadd` 为例,`compare.py` 的核心逻辑就是: + +```python +golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) +output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) +ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) +``` + +如果是 `trowsum`,则可以自己改成按 `dst_shape` reshape,并只比较 `rows x 1` 的有效区域。 + +这种拆法更接近 `pto-isa` 的 `ResultCmp` 思路:公共层只负责“怎么比”,不负责“该比哪块数据”。 + +## 8. 如果只是在已有 `tadd` 下新增一个 case + +如果 `tadd` testcase 已经存在,而你只是想加一个新 case,例如 `f32_8x128`,则通常只需要同步修改 4 个文件: + +| 文件 | 必须修改的内容 | +|---|---| +| `testcase/tadd/cases.py` | 在 `CASES` 中加入新条目(含 `name`/`dtype`/`shape`/`valid_shape`/`eps`) | +| `testcase/tadd/tadd.pto` | 新增一个 `func.func @TADD_f32_8x128(...)` | +| `testcase/tadd/launch.cpp` | 新增 `extern "C"` kernel 声明和 `LaunchTADD_f32_8x128` | +| `testcase/tadd/main.cpp` | 在 `kCases[]` 中加入 `{"f32_8x128", LaunchTADD_f32_8x128, 8, 128, 8, 128, sizeof(float)}` | + +不需要改: + +- `testcase/tadd/gen_data.py`(自动从 `cases.py` 读取) +- `testcase/tadd/compare.py`(自动从 `cases.py` 读取) +- `testcase/tadd/CMakeLists.txt` +- `testcase/CMakeLists.txt` +- `run_st.py` + +## 9. 文件之间必须保持一致的约束 + +这是新增 testcase 时最容易出错的地方。 + +### 9.1 命名一致 + +下面这几处名字必须严格一致: + +| 位置 | 示例 | +|---|---| +| `.pto` 中的 kernel 函数名 | `@TADD_f32_16x64` | +| `launch.cpp` 中的 kernel 声明 | `TADD_f32_16x64` | +| `launch.cpp` / `main.cpp` 中的 wrapper 名 | `LaunchTADD_f32_16x64` | +| `main.cpp` 的 case 名 | `f32_16x64` | +| `gen_data.py` / `compare.py` 的 case 名 | `f32_16x64` | +| 运行时目录名 | `build/testcase/tadd/f32_16x64/` | + +### 9.2 参数顺序一致 + +`.pto` 里 kernel 的参数顺序、`launch.cpp` 声明顺序、`main.cpp` 里 launch wrapper 的参数顺序必须一致。 +如果 `tadd` 的语义是 `(a, b) -> c`,那 host 侧和 compare 也都要按这个顺序组织。 + +### 9.3 shape、valid_shape、dst_shape 和 dtype 一致 + +`cases.py` 中的 shape 信息和 `dtype` 是 Python 侧的单一来源,`gen_data.py` 和 `compare.py` 自动从中读取。 + +- 对大多数 op,`shape`/`valid_shape` 就够了。 +- 对 `trowsum` 这类输出 shape 不同的 op,再额外维护 `dst_shape`/`dst_valid_shape`。 + +但 C++ 侧的 `main.cpp` `kCases[]` 和 `.pto` 中的 tensor/tile shape 仍需手动与 `cases.py` 保持一致。 +否则运行能成功,结果也可能是错误的,且定位会很耗时。 + +## 10. 建议的开发验证节奏 + +作为库开发者,建议用下面的节奏迭代: + +1. 先写一个最小 case,例如 `f32_16x64` +2. 在 simulator 上跑单 case: + +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 +``` + +3. 改 `.pto` 或 host 代码后,如果确认只是小修改,可以用: + +```bash +python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd -c f32_16x64 -w +``` + +4. 单 case 稳定后,再补更多 shape / dtype case +5. 再跑全量 `tadd` +6. 最后如果需要,再切到 `-r npu` + +## 11. 调试建议 + +### 11.1 编译失败看哪里 + +- `ptoas` 失败:优先看 `.pto` 本身、TileLang 模板实例化、是否缺少 `--enable-insert-sync` +- fatobj 生成失败:优先看 `ptoas` 的 stderr 和 `.pto` 语义是否完整 +- `launch.cpp` / `main.cpp` 链接失败:优先看共享库、ACL 运行时依赖和符号名一致性 + +### 11.2 运行失败看哪里 + +- `main.cpp` 报读文件失败:先确认 `build/testcase///input*.bin` 是否存在 +- kernel 能跑但 compare fail:先看 `output.bin` 与 `golden.bin` 的差异,再看 `.pto` 语义和 host 参数顺序 +- 某个 case 单独跑通过、全量跑失败:优先怀疑 case 目录隔离、host 资源释放、或者多 case 共用状态 + +### 11.3 典型排查文件 + +| 文件 | 作用 | +|---|---| +| `build/testcase//_kernel.o` | 看 `ptoas` 最终生成的 fatobj | +| `build/testcase///golden.bin` | 确认 Python 侧 oracle 是否正确 | +| `build/testcase///output.bin` | 确认运行时实际输出 | +| `testcase//main.cpp` | 确认 host 侧参数顺序、shape 和文件路径 | +| `testcase//compare.py` | 确认比较阈值是否合理 | + +## 12. 一句话总结 + +对于库开发者来说,TileLang ST 框架就是一条固定好的端到端验证流水线: + +```text +写 .pto -> 接入 testcase 六件套 -> run_st.py 编译运行 -> 查看 build/testcase// 下的 input/golden/output -> 判断库实现是否正确 +``` + +如果你想验证的是 `pto.tadd`,最重要的是把下面几处保持同步: + +- `cases.py` 中的 case 定义(name/dtype/shape/valid_shape/eps)—— Python 侧的单一来源 +- `tadd.pto` 中的 kernel 函数名和 tile shape +- `launch.cpp` 中的 kernel 声明与 wrapper +- `main.cpp` 中的 `kCases[]`(rows/cols/validRows/validCols 需与 `cases.py` 一致) +- `gen_data.py` 中的 golden 计算逻辑(op 语义相关,如加法/减法) + +`compare.py` 和 `gen_data.py` 的 case 列表、比较阈值均自动从 `cases.py` 读取,不需要单独维护。 + +这几处一致,框架就能帮助你把 TileLang 库实现的”端到端正确性”稳定地跑起来。 diff --git a/docs/designs/vpto-section-sugar.md b/docs/designs/vpto-section-sugar.md new file mode 100644 index 000000000..b65effaaa --- /dev/null +++ b/docs/designs/vpto-section-sugar.md @@ -0,0 +1,199 @@ +# VPTO section sugar + +## 背景 + +VPTO fatobj 工作流当前使用显式双 module 编程模型: + +```mlir +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(...) attributes {pto.kernel} { + ... + } + } + + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(...) attributes {pto.kernel} { + ... + } + } +} +``` + +这个模型适合作为后端规范输入,但对手写 mixed kernel 不够紧凑。我们希望增加一个语法糖:用户可以在一个 `pto.kernel` 函数内用 `pto.section.vector` 和 `pto.section.cube` 分别编写 vector / cube 代码,VPTO 路径入口处立刻把它解包为现有双 module 形式。 + +## 调研结论 + +1. `pto.section.cube` / `pto.section.vector` 已经存在于 PTO dialect。 + +2. 这两个 op 当前定义在 `include/PTO/IR/PTOOps.td`,是 `SingleBlock, NoTerminator` 的 region container。 + +3. `PTOWrapFunctionsInSectionsPass` 已经能把带 `pto.kernel_kind` 的 frontend function body 包进对应 section,但它服务的是旧 frontend section/EmitC 模型,不是 VPTO mixed module 解包。 + +4. 多个 verifier 已经认识 section 上下文,例如 tpush/tpop/tfree 允许出现在 `pto.section.cube/vector` 内。 + +5. VPTO fatobj 后端当前只认规范双 module: + - `vpto-normalize-container` 把单个带 `pto.kernel_kind` 的 module 包成外层 container,并要求外层只包含带 `pto.kernel_kind` 的子 module。 + - `VPTOLLVMEmitter` 按子 module 的 `pto.kernel_kind` 选择 cube/vector LLVM 目标,并给 `pto.kernel` 函数补 `_mix_aic` / `_mix_aiv` 后缀。 + - `VPTOHostStubEmission` 根据同名 `pto.kernel` 函数生成一个 host stub,并校验 mixed variants 的签名一致。 + +结论:新语法糖应复用现有 `pto.section.cube/vector` op,只新增 VPTO 入口解包 pass,不改变 LLVM emitter、host stub emission 和 fatobj emission 的核心模型。 + +## 输入形式 + +语法糖输入是一个普通 kernel module,module 上不需要 `pto.kernel_kind`。同一个 `pto.kernel` 函数内可以包含一个 vector section、一个 cube section,或者只包含其中一个。旧属性名 `pto.aicore` 仍被兼容识别,但新输入应使用 `pto.kernel`。 + +```mlir +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%src: !pto.ptr, %dst: !pto.ptr) + attributes {pto.kernel} { + %c0 = arith.constant 0 : i64 + + pto.section.vector { + // vector code + } + + pto.section.cube { + // cube code + } + + return + } +} +``` + +section 内的代码允许使用函数参数、函数内 section 外定义的 SSA 值,以及同一 section 内定义的值。解包 pass 不单独分析这些依赖,而是整体 clone 原函数,再按目标 core 删除另一类 section。 + +## 输出形式 + +解包后的 IR 必须是现有 VPTO fatobj 规范输入: + +```mlir +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%src: !pto.ptr, %dst: !pto.ptr) + attributes {pto.kernel} { + // original function body with cube sections removed + return + } + } + + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%src: !pto.ptr, %dst: !pto.ptr) + attributes {pto.kernel} { + // original function body with vector sections removed + return + } + } +} +``` + +后续 VPTO pipeline 不再感知 `pto.section.cube/vector`,只处理带 `pto.kernel_kind` 的子 module。 + +## 解包步骤 + +1. 识别 sugar module。 + + 如果顶层 module 已经是 container,且子 module 带 `pto.kernel_kind`,则认为输入已经是规范双 module,不做 section 解包。 + + 如果顶层 module 自身带 `pto.kernel_kind`,则由 `vpto-normalize-container` 包一层外层 container。 + + 如果顶层 module 不带 `pto.kernel_kind`,且含有 `pto.kernel` 函数内的 `pto.section.cube/vector`,则进入 section sugar 解包。 + +2. 为每个实际出现的 kernel kind 创建一个子 module。 + + vector section 生成 `module attributes {pto.kernel_kind = #pto.kernel_kind}`。 + + cube section 生成 `module attributes {pto.kernel_kind = #pto.kernel_kind}`。 + + 外层 module 保留 `pto.target_arch` 等 module 级公共属性。 + +3. 为每个带 `pto.kernel` 的函数生成同名函数 variant。 + + 输出函数保留原函数名、参数列表、结果类型和 `pto.kernel` 属性。后续 `VPTOLLVMEmitter` 仍负责补 `_mix_aiv` / `_mix_aic` 后缀。 + + vector module 中放原函数的一个 clone,然后删除其中所有 `pto.section.cube`。 + + cube module 中放原函数的一个 clone,然后删除其中所有 `pto.section.vector`。 + +4. 展开目标 section。 + + 在 vector module 中,把 `pto.section.vector` 替换为其 body 内的操作。 + + 在 cube module 中,把 `pto.section.cube` 替换为其 body 内的操作。 + + 因为每个目标函数是从原函数整体 clone 出来的,section 外公共前置代码会自然保留,不需要单独分析和克隆 section 依赖。 + +5. 依赖校验。 + + 解包 pass 不做复杂的跨 section 依赖分析。删除非目标 section 后,如果目标代码仍引用了被删除 section 产生的 SSA 值,后续 MLIR verifier 应直接报错。 + + 这也是期望行为:cube/vector section 之间不能通过普通 SSA 值直接传递数据,跨 core 通信必须用显式同步和搬移 op 表达。 + +## 约束 + +1. 一个 `pto.kernel` 函数内每种 section 最多出现一次。 + +2. `pto.section.cube` 和 `pto.section.vector` 不能嵌套。 + +3. section sugar 输入中,`pto.kernel` 函数 body 的顶层可包含公共前置定义、section op、同步/搬移等普通操作和 `return`。这些 section 外操作会被完整保留到每个目标函数中。 + +4. 同一个输入 module 中如果有多个 `pto.kernel` 函数,则每个函数只进入它实际包含的 section kind 对应子 module。后续 host stub 继续要求同名 mixed variants 的签名一致。 + +5. helper 函数随目标 module 一起复制。无用 helper 可以交给后续 DCE 或保持存在,不作为 section sugar 的语义问题。 + +6. 解包后不保留 section op。section op 只是源级 sugar,不进入 VPTO LLVM emission。 + +## 放置位置 + +新增 pass 命名为 `vpto-split-cv-module`。 + +它应当在 VPTO 路径最前面执行,位置早于: + +1. `vpto-normalize-container` + +2. `prepareVPTOForEmission` + +推荐把入口职责调整为: + +```text +VPTO input + -> expand section sugar to kernel_kind modules + -> normalize single kernel_kind module to outer container + -> verify normalized container + -> existing nested VPTO pipeline + -> LLVM modules + -> fatobj +``` + +这样后续 fatobj workflow 仍只有一种规范 IR 形态,不需要在 emitter 或 stub 生成阶段处理 section。 + +## 与现有 pass 的关系 + +`PTOWrapFunctionsInSectionsPass` 和本设计方向相反: + +```text +kernel_kind function -> section.cube/vector +``` + +新 pass 的方向是: + +```text +section.cube/vector -> kernel_kind module +``` + +因此不复用 `PTOWrapFunctionsInSectionsPass`,但可以复用它对 section op 的遍历经验和 verifier 约束。 + +## 测试计划 + +1. 添加一个 lit 测试,输入单 module + `pto.section.vector` / `pto.section.cube`,检查解包后出现两个 `pto.kernel_kind` 子 module。 + +2. 添加只含 vector section 的测试,检查它等价于单 vector module 输入。 + +3. 添加错误测试: + - 同一个函数里重复 vector section。 + - section 嵌套。 + - section 捕获不可克隆的外部 SSA 值。 + - `pto.kernel` 函数有返回值。 + +4. 把现有一个 mixed VPTO host validation case 改写成 section sugar 输入,确认 `ptoas --pto-backend=vpto` 仍能直接生成 fatobj 并通过 SIM。 diff --git a/docs/isa-legacy/02-dma-copy.md b/docs/isa-legacy/02-dma-copy.md new file mode 100644 index 000000000..d9e9b3b1b --- /dev/null +++ b/docs/isa-legacy/02-dma-copy.md @@ -0,0 +1,659 @@ +# 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](01-pipeline-sync.md)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +## Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +## Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +## Pad Value Configuration + +### `pto.set_mov_pad_val` + +- **syntax:** `pto.set_mov_pad_val %value : T` +- **supported `T`:** `i8`, `i16`, `i32`, `f16`, `bf16`, `f32` +- **semantics:** Configure the pad fill value used by GM→UB DMA when `data_select_bit = true`. + +This op programs the hardware pad register consumed by `pto.copy_gm_to_ubuf`. The operand is a typed scalar. Its raw bit pattern is encoded into the underlying hardware configuration payload: + +- integer inputs use their zero-extended bit pattern +- floating-point inputs use their bitcast-to-integer bit pattern, then zero-extend to `i64` + +This configuration affects only the GM→UB padding path. UB→GM DMA ignores the pad value. + +**Parameter Table:** + +| Parameter | Description | +|-----------|-------------| +| `%value` | Pad fill scalar. Must be one of `i8/i16/i32/f16/bf16/f32`. | + +**Example:** + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +``` + +--- + +## DMA Transfer Execution + +### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %ub_src, %ub_dst, + %sid, %n_burst, %len_burst, %src_gap, %dst_gap + : !pto.ptr, !pto.ptr, + i64, i64, i64, i64, i64 +``` +- **semantics:** Raw UB→UB copy within Unified Buffer. `pto.mte_ub_ub` uses the same operand contract. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | 16 bits | Stream ID | +| `%n_burst` | 16 bits | Number of bursts | +| `%len_burst` | 16 bits | Burst length in units of 32 bytes | +| `%src_gap` | 16 bits | Source gap between consecutive bursts, in units of 32 bytes | +| `%dst_gap` | 16 bits | Destination gap between consecutive bursts, in units of 32 bytes | + +--- + +## Burst / Stride / Gap / Pad Model + +The legacy DMA copy family uses two different innermost-burst contracts: + +- `pto.copy_gm_to_ubuf` / `pto.copy_ubuf_to_gm` are **stride-based**. Their + source and destination stride operands are start-to-start distances in bytes. +- `pto.copy_ubuf_to_ubuf` is **gap-based**. Its + `%len_burst`, `%src_gap`, and `%dst_gap` operands are encoded in units of + 32 bytes. + +### Key Terms + +``` +GM↔UB burst = lenBurst contiguous bytes transferred per row +GM↔UB stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +UB→UB burst = len_burst * 32 bytes +UB→UB next source start = previous source start + (len_burst + src_gap) * 32 bytes +UB→UB next destination start = previous destination start + (len_burst + dst_gap) * 32 bytes +``` + +### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +### UB→UB Raw Copy (`pto.copy_ubuf_to_ubuf`) + +For UB→UB raw copy, each burst copies `len_burst * 32` bytes. + +After burst `r`, the next burst starts at: + +```text +src_next = src_curr + (len_burst + src_gap) * 32 bytes +dst_next = dst_curr + (len_burst + dst_gap) * 32 bytes +``` + +So `src_gap` and `dst_gap` are not start-to-start strides. They are additional +gaps inserted after the copied 32B blocks. + +### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +## Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +## Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +## Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +## Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` diff --git a/docs/isa-legacy/16-cube-matmul.md b/docs/isa-legacy/16-cube-matmul.md new file mode 100644 index 000000000..b19befa47 --- /dev/null +++ b/docs/isa-legacy/16-cube-matmul.md @@ -0,0 +1,152 @@ +# 16A. Cube Matmul Raw Ops (MAT) + +> **Category:** Cube unit raw ops — low-level data movement and matrix-side configuration +> **Audience:** Backend / lowering / bridge-op implementers + +--- + +## Scope + +This document lists the raw cube matmul ops used by wrapper interfaces in +`16-cube-matmul.md`. + +If you are writing user-facing VPTO kernels, prefer wrapper ops such as +`pto.mte_gm_l1`, `pto.mte_l1_l0a`, `pto.mte_l1_l0b`, `pto.mte_l0c_*`, and +`pto.mte_gm_l1_frac`. + +--- + +## Raw Staging / Load Ops + +### `pto.copy_gm_to_cbuf` + +- **syntax:** +```mlir +pto.copy_gm_to_cbuf %src, %dst, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Copy matrix tile data from GM to L1 (`cbuf`). + +### `pto.load_cbuf_to_ca` + +- **syntax:** +```mlir +pto.load_cbuf_to_ca %src, %dst, %m_start, %k_start, %m_step, %k_step, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** Load L1 (`cbuf`) tile to L0A. + +### `pto.load_cbuf_to_cb` + +- **syntax:** +```mlir +pto.load_cbuf_to_cb %src, %dst, %m_start, %k_start, %m_step, %k_step, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** Load L1 (`cbuf`) tile to L0B. + +### `pto.load_cbuf_to_ca_mx` + +- **syntax:** +```mlir +pto.load_cbuf_to_ca_mx %src, %dst, %m, %k + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Load L1 (`cbuf`) tile to L0A using MX path. + +### `pto.load_cbuf_to_cb_mx` + +- **syntax:** +```mlir +pto.load_cbuf_to_cb_mx %src, %dst, %x_start_position, %y_start_position, %x_step, %y_step, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** Load L1 (`cbuf`) tile to L0B using MX path with explicit hardware control fields. + +--- + +## Raw L0C Writeback / Move Ops + +### `pto.copy_matrix_cc_to_gm` + +- **syntax:** +```mlir +pto.copy_matrix_cc_to_gm %src, %dst, %xm, %xt + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Write L0C (`acc`) tile back to GM. + +### `pto.copy_matrix_cc_to_cbuf` + +- **syntax:** +```mlir +pto.copy_matrix_cc_to_cbuf %src, %dst, %config0, %config1 + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Move L0C (`acc`) tile to L1 (`cbuf`). + +### `pto.copy_matrix_cc_to_ub` + +- **syntax:** +```mlir +pto.copy_matrix_cc_to_ub %src, %dst, %config0, %config1 + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Move L0C (`acc`) tile to UB. + +--- + +## Raw CBUF Outbound Ops + +### `pto.copy_cbuf_to_bt` + +- **syntax:** +```mlir +pto.copy_cbuf_to_bt %src, %dst, %len_burst, %n_burst, %src_gap, %dst_gap + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Move L1 (`cbuf`) data to BT buffer. + +### `pto.copy_cbuf_to_fbuf` + +- **syntax:** +```mlir +pto.copy_cbuf_to_fbuf %src, %dst, %n_burst, %len_burst, %src_gap, %dst_gap + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Move L1 (`cbuf`) data to FB-related destination path. + +### `pto.copy_gm_to_cbuf_multi_nd2nz` + +- **syntax:** +```mlir +pto.copy_gm_to_cbuf_multi_nd2nz %src, %dst, %sid, %loop1_src_stride, %l2_cache_ctrl, %n_value, %d_value, %loop4_src_stride, %smallc0_en + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i1 +``` +- **semantics:** Multi-fractal `ND2NZ` staging from GM to L1 (`cbuf`). + +### `pto.copy_gm_to_cbuf_multi_dn2nz` + +- **syntax:** +```mlir +pto.copy_gm_to_cbuf_multi_dn2nz %src, %dst, %sid, %loop1_src_stride, %l2_cache_ctrl, %n_value, %d_value, %loop4_src_stride, %smallc0_en + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i1 +``` +- **semantics:** Multi-fractal `DN2NZ` staging from GM to L1 (`cbuf`). + +--- + +## Raw vs Wrapper Mapping + +| Wrapper op | Typical raw op(s) | +|---|---| +| `pto.mte_gm_l1` | `pto.copy_gm_to_cbuf` + loop setup | +| `pto.mte_l1_l0a` | `pto.load_cbuf_to_ca` | +| `pto.mte_l1_l0b` | `pto.load_cbuf_to_cb` | +| `pto.mte_l1_l0a_mx` | `pto.load_cbuf_to_ca_mx` | +| `pto.mte_l1_l0b_mx` | `pto.load_cbuf_to_cb_mx` | +| `pto.mte_gm_l1_frac` | `pto.copy_gm_to_cbuf_multi_nd2nz` / `pto.copy_gm_to_cbuf_multi_dn2nz` + config setup | +| `pto.mte_l1_bt` | `pto.copy_cbuf_to_bt` | +| `pto.mte_l0c_l1` | `pto.copy_matrix_cc_to_cbuf` (+ related config) | +| `pto.mte_l0c_gm` | `pto.copy_matrix_cc_to_gm` (+ related config) | +| `pto.mte_l0c_ub` | `pto.copy_matrix_cc_to_ub` (+ related config) | diff --git a/docs/isa/micro-isa/01-pipeline-sync.md b/docs/isa/micro-isa/01-pipeline-sync.md new file mode 100644 index 000000000..ea1e7784e --- /dev/null +++ b/docs/isa/micro-isa/01-pipeline-sync.md @@ -0,0 +1,600 @@ +# 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +## Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `pto.mte_ub_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.mte_ub_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.mte_ub_gm %ub_partial_1, %gm_result, ... +``` + +--- + +### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Shared-memory (UB address space) memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between memory operations. The barrier type selects which classes of prior instructions must complete before which classes of subsequent instructions may proceed. + +```c +mem_bar(barrier_type); +``` + +**Barrier types** are organized into three families by the scope of prior vs. subsequent instructions: + +| Family | Barrier type | Prior instructions | Subsequent instructions | +|--------|-------------|-------------------|------------------------| +| **VV** (vector→vector) | `VV_ALL` | All vector load/store | All vector load/store | +| | `VST_VLD` | All vector store | All vector load | +| | `VLD_VST` | All vector load | All vector store | +| | `VST_VST` | All vector store | All vector store | +| **VS** (vector→scalar) | `VS_ALL` | All vector load/store | All scalar load/store | +| | `VST_LD` | All vector store | All scalar load | +| | `VLD_ST` | All vector load | All scalar store | +| | `VST_ST` | All vector store | All scalar store | +| **SV** (scalar→vector) | `SV_ALL` | All scalar load/store | All vector load/store | +| | `ST_VLD` | All scalar store | All vector load | +| | `LD_VST` | All scalar load | All vector store | +| | `ST_VST` | All scalar store | All vector store | + +**Example:** Ensure vector stores are visible before subsequent vector loads to the same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +## Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.mte_gm_ub %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.mte_gm_ub %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +## Intra-Core Sync Patterns & Examples + +### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.mte_gm_ub %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.mte_ub_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.mte_gm_ub %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.mte_ub_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +#### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +#### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.mte_gm_ub %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.mte_ub_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +#### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.mte_gm_ub %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.mte_ub_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +## Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +## Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` diff --git a/docs/isa/micro-isa/02-dma-copy.md b/docs/isa/micro-isa/02-dma-copy.md new file mode 100644 index 000000000..4747237c0 --- /dev/null +++ b/docs/isa/micro-isa/02-dma-copy.md @@ -0,0 +1,585 @@ +# 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](01-pipeline-sync.md)). + +This document describes the public grouped DMA interfaces: + +- `pto.mte_gm_ub` +- `pto.mte_ub_gm` +- `pto.mte_ub_ub` +- `pto.mte_ub_l1` + +--- + +## DMA Transfer Execution + +### `pto.mte_gm_ub` + +- **syntax:** +```mlir +pto.mte_gm_ub %gm_src, %ub_dst, %l2_cache_ctl, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + [loop(%loop_count, %loop_src_stride, %loop_dst_stride)]* + [pad(%pad_value[, %left_padding_count, %right_padding_count])] + : !pto.ptr, !pto.ptr, i64, i64, i64, + [loop i64, i64, i64,]* + [pad T[, i64, i64]] +``` +- **semantics:** Grouped GM→UB DMA transfer. `nburst(...)` defines the innermost repeated burst transfer, optional `loop(...)` groups add outer repetition levels, and `pad(...)` controls UB row padding. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%gm_src` | ptr | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%l2_cache_ctl` | 2 bits | L2 cache allocate control | +| `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 40 bits / 21 bits | Required innermost burst group: count, GM source stride, UB destination stride | +| `loop(%loop_count, %loop_src_stride, %loop_dst_stride)` | 21 bits / 40 bits / 21 bits | Optional outer repetition group: count, GM source stride, UB destination stride | +| `pad(%pad_value[, %left_padding_count, %right_padding_count])` | scalar / 8 bits / 8 bits | Optional padding: fill value, optional left padding count, optional right padding count | + +**Constraints:** + +- `nburst(...)` is always required. +- Each `loop(...)` group must be provided as a complete triple when present. +- `nburst(...)` is the innermost group. +- `loop(...)` groups are ordered from inner to outer. +- The first `loop(...)` group wraps `nburst(...)`. +- Each additional `loop(...)` group wraps all earlier groups. +- `pad(...)` may contain only `%pad_value`; omitted left and right padding counts default to 0. +- If either left or right padding count is provided, both counts must be provided. +- `pad(...)` is independent of the optional `loop(...)` groups. +- A DMA load may use `nburst(...) pad(...)` without any `loop(...)` group. + +**Example:** + +```mlir +pto.mte_gm_ub %gm_in, %ub_out, %cache, %len_burst + nburst(%rows, %gm_row_stride, %ub_row_stride) + loop(%tiles, %gm_tile_stride, %ub_tile_stride) + pad(%pad) + : !pto.ptr, !pto.ptr, i64, i64, + loop i64, i64, i64, pad f16 +``` + +--- + +### `pto.mte_ub_gm` + +- **syntax:** +```mlir +pto.mte_ub_gm %ub_src, %gm_dst, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + [loop(%loop_count, %loop_src_stride, %loop_dst_stride)]* + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + [loop i64, i64, i64,]* +``` +- **semantics:** Grouped UB→GM DMA transfer. `nburst(...)` defines the innermost repeated burst transfer, and optional `loop(...)` groups add outer repetition levels. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | ptr | GM destination pointer (`!pto.ptr`) | +| `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 21 bits / 40 bits | Required innermost burst group: count, UB source stride, GM destination stride | +| `loop(%loop_count, %loop_src_stride, %loop_dst_stride)` | 21 bits / 21 bits / 40 bits | Optional outer repetition group: count, UB source stride, GM destination stride | + +**Constraints:** + +- `nburst(...)` is always required. +- Each `loop(...)` group must be provided as a complete triple when present. +- `nburst(...)` is the innermost group. +- `loop(...)` groups are ordered from inner to outer. +- The first `loop(...)` group wraps `nburst(...)`. +- Each additional `loop(...)` group wraps all earlier groups. + +**Example:** + +```mlir +pto.mte_ub_gm %ub_in, %gm_out, %len_burst + nburst(%rows, %ub_row_stride, %gm_row_stride) + loop(%tiles, %ub_tile_stride, %gm_tile_stride) + loop(%batches, %ub_batch_stride, %gm_batch_stride) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + loop i64, i64, i64, loop i64, i64, i64 +``` + +--- + +### `pto.mte_ub_ub` + +- **syntax:** +```mlir +pto.mte_ub_ub %ub_src, %ub_dst, %len_burst + nburst(%n_burst, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Grouped UB→UB copy. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%len_burst` | 16 bits | Burst length in units of 32 bytes | +| `nburst(%n_burst, %src_gap, %dst_gap)` | 16 bits / 16 bits / 16 bits | Required copy burst group: count, source gap, destination gap | + +**Constraints:** + +- UB source and destination addresses must be 32B-aligned. +- `%len_burst`, `%src_gap`, and `%dst_gap` are encoded in units of 32 bytes. + +**Example:** + +```mlir +pto.mte_ub_ub %ub_src, %ub_dst, %len32b + nburst(%rows, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +### `pto.mte_ub_l1` + +- **syntax:** +```mlir +pto.mte_ub_l1 %ub_src, %l1_dst, %len_burst + nburst(%n_burst, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Grouped UB→L1/CBUF copy. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%l1_dst` | ptr | L1 destination pointer (`!pto.ptr`, 32B-aligned) | +| `%len_burst` | 16 bits | Burst length in units of 32 bytes | +| `nburst(%n_burst, %src_gap, %dst_gap)` | 16 bits / 16 bits / 16 bits | Required copy burst group: count, source gap, destination gap | + +**Constraints:** + +- UB source and L1 destination addresses must be 32B-aligned. +- `%len_burst`, `%src_gap`, and `%dst_gap` are encoded in units of 32 bytes. + +**Example:** + +```mlir +pto.mte_ub_l1 %ub_src, %l1_dst, %len32b + nburst(%rows, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +## Grouped DMA Burst / Stride / Pad Model + +This section describes the grouped DMA interfaces in this document: +`pto.mte_gm_ub` and `pto.mte_ub_gm`. + +For these grouped DMA ops, the innermost `nburst(...)` group is +**stride-based**: the source and destination stride operands are the +start-to-start byte distance from one burst row to the next row. + +### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `pad(...)` is present on `pto.mte_gm_ub`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val`. This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +--- + +## UB Copy Burst / Step Model + +This section describes the grouped UB-copy interface in this document: +`pto.mte_ub_ub` and `pto.mte_ub_l1`. + +For `pto.mte_ub_ub` and `pto.mte_ub_l1`, each burst copies `len_burst * 32` bytes. + +The next burst starts at: + +```text +src_next = src_curr + (len_burst + src_gap) * 32 bytes +dst_next = dst_curr + (len_burst + dst_gap) * 32 bytes +``` + +So `src_gap` and `dst_gap` are gap fields that advance to the next burst +after the copied 32B blocks. + +### 2D Diagram: GM→UB (`pto.mte_gm_ub`) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (`pad(...)` present) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (from `pad(...)`) +``` + +### 2D Diagram: UB→GM (`pto.mte_ub_gm` with GM destination) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +## Multi-Level Loop Semantics + +The full DMA transfer is a nested loop. `nburst(...)` is the innermost group. +If one or more `loop(...)` groups are present, they wrap `nburst(...)` in the +same order they appear in the op: the first `loop(...)` is the innermost outer +group, the second `loop(...)` wraps the first one, and so on. + +### GM→UB Full Loop + +For a form + +```mlir +pto.mte_gm_ub %gm_src, %ub_dst, %l2_cache_ctl, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + loop(%c0, %s0, %d0) + loop(%c1, %s1, %d1) + ... + loop(%cN, %sN, %dN) + [pad(%pad_value[, %left_padding_count, %right_padding_count])] +``` + +the transfer is equivalent to: + +```c +for (int lN = 0; lN < cN; ++lN) { + ... + for (int l1 = 0; l1 < c1; ++l1) { + for (int l0 = 0; l0 < c0; ++l0) { + uint8_t *gm_base = gm_src + l0 * s0 + l1 * s1 + ... + lN * sN; + uint8_t *ub_base = ub_dst + l0 * d0 + l1 * d1 + ... + lN * dN; + for (int r = 0; r < n_burst; ++r) { + memcpy(ub_base + r * dst_stride, + gm_base + r * src_stride, + len_burst); + if (pad_enabled) + memset(ub_base + r * dst_stride + len_burst, + pad_val, + dst_stride - len_burst); + } + } + } +} +``` + +If no `loop(...)` group is present, only the innermost `nburst(...)` loop +remains. + +### UB→Destination Full Loop + +For a form + +```mlir +pto.mte_ub_gm %ub_src, %dst, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + loop(%c0, %s0, %d0) + loop(%c1, %s1, %d1) + ... + loop(%cN, %sN, %dN) +``` + +the transfer is equivalent to: + +```c +for (int lN = 0; lN < cN; ++lN) { + ... + for (int l1 = 0; l1 < c1; ++l1) { + for (int l0 = 0; l0 < c0; ++l0) { + uint8_t *ub_base = ub_src + l0 * s0 + l1 * s1 + ... + lN * sN; + uint8_t *dst_base = dst + l0 * d0 + l1 * d1 + ... + lN * dN; + for (int r = 0; r < n_burst; ++r) { + memcpy(dst_base + r * dst_stride, + ub_base + r * src_stride, + len_burst); + } + } + } +} +``` + +If no `loop(...)` group is present, only the innermost `nburst(...)` loop +remains. + +--- + +## Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — only nburst(...) is needed +pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +## Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +pto.mte_gm_ub %gm_ptr, %ub_ptr, %c0_i64, %c256_i64 + nburst(%c64_i64, %c1024_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +## Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +%pad = arith.constant 0 : i16 +pto.mte_gm_ub %gm_ptr, %ub_ptr, %c0_i64, %c200_i64 + nburst(%c64_i64, %c200_i64, %c256_i64) + pad(%pad, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, pad i16, i64, i64 +``` + +--- + +## Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +## Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +pto.mte_ub_gm %ub_ptr, %gm_ptr, %c256_i64 + nburst(%c64_i64, %c256_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +## Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using one outer +`loop(...)` group. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes outer loop src_stride = 2048 bytes (8 × 256) + outer loop dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes outer loop count = 4 (iterate over batches) +``` + +```mlir +// One outer loop group over 4 batches +pto.mte_gm_ub %gm_ptr, %ub_ptr, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + loop(%c4_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, loop i64, i64, i64 +``` + +Execution trace: + +``` +loop iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` diff --git a/docs/isa/micro-isa/03-vector-load-store.md b/docs/isa/micro-isa/03-vector-load-store.md new file mode 100644 index 000000000..902287147 --- /dev/null +++ b/docs/isa/micro-isa/03-vector-load-store.md @@ -0,0 +1,600 @@ +# 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +## Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +## Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B8` / `NORM_B16` / `NORM_B32` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_B8` / `INTLV_B16` / `INTLV_B32`** on **`RV_VSTI`** are **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | `RV_VLDI` | **9** | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` | `RV_VSTI` | **12** | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B8` / `NORM_B16` / `NORM_B32` | `RV_VSTI` | **9** | +| `PK_B16` / `PK_B32` / `PK_B64` / `PK4_B32` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | **9** cycles | +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US_B8` / `US_B16`, `DS_B8` / `DS_B16`, `SPLT4CHN`, `SPLT2CHN_B8` / `SPLT2CHN_B16` | **9** cycles | + +### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM_B8` / `NORM_B16` / `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16` / `PK_B32` / `PK_B64` / `PK4_B32` | **9** cycles | +| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` (`pto.vstsx2`) | **12** cycles | +| `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | **9** cycles (surface retained; current A5 hardware still reports them unsupported at validation time) | + +### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +## Contiguous Loads + +### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US_B8` / `US_B16` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8` / `DS_B16` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B_B16` / `E2B_B32` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN_B8` / `SPLT2CHN_B16` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` are the element-width-sensitive +deinterleave forms. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +## Dual Loads (Deinterleave) + +### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` support only the + element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` are all + **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV_B32 family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +## Gather (Indexed) Loads + +### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%mask` selects the active requests. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only masked-on indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +## Contiguous Stores + +### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` is the predicate operand, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM_B8` / `NORM_B16` / `NORM_B32` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT_B8` / `1PT_B16` / `1PT_B32` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint; the predicate register is ignored. | **9** cycles | +| `PK_B16` | `b16` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 16-bit data. | **9** cycles | +| `PK_B32` | `b32` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 32-bit data. | **9** cycles | +| `PK_B64` | `b64` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 64-bit data. | **9** cycles | +| `PK4_B32` | `b32` | Pack the source vector, extract the lower 8 bits of all elements, and only store the active elements. The predicate is interpreted for 32-bit data. | **9** cycles | +| `MRG4CHN_B8` | `b8` | Merge 4 interleaved 8-bit channels within each 32B block; the predicate is interpreted for 32-bit data and applies after channel merge. | **9** cycles | +| `MRG2CHN_B8` / `MRG2CHN_B16` | `b8`, `b16` | Merge 2 interleaved channels within each 32B block; for `MRG2CHN_B8` the predicate is interpreted for 16-bit data, and for `MRG2CHN_B16` it is interpreted for 32-bit data; in both cases it applies after channel merge. | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +## Dual Stores (Interleave) + +### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` is the predicate operand. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. For all `INTLV_*` distributions, the predicate + register is ignored. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +## Scatter (Indexed) Stores + +### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %mask : !pto.vreg, !pto.ptr, !pto.vreg, !pto.mask` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%mask` selects the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +## Alignment State Stores + +### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +## Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. diff --git a/docs/isa/micro-isa/04-predicate-load-store.md b/docs/isa/micro-isa/04-predicate-load-store.md new file mode 100644 index 000000000..9c3bed11d --- /dev/null +++ b/docs/isa/micro-isa/04-predicate-load-store.md @@ -0,0 +1,135 @@ +# 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +## Predicate Loads + +### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +## Predicate Stores + +### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +## Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/micro-isa/05-materialization-predicate.md b/docs/isa/micro-isa/05-materialization-predicate.md new file mode 100644 index 000000000..e6ee34975 --- /dev/null +++ b/docs/isa/micro-isa/05-materialization-predicate.md @@ -0,0 +1,322 @@ +# 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +## Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +## Scalar Materialization + +### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +## Predicate Generation + +### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +## Predicate Pack/Unpack + +### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +## Predicate Logical Ops + +### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +## Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/micro-isa/06-unary-vector-ops.md b/docs/isa/micro-isa/06-unary-vector-ops.md new file mode 100644 index 000000000..8eff4fcec --- /dev/null +++ b/docs/isa/micro-isa/06-unary-vector-ops.md @@ -0,0 +1,173 @@ +# 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +## Arithmetic + +### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +## Transcendental + +### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +## Activation + +### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** si32, i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Signed or signless 32-bit integer and + floating-point element types are legal on the current A5 surface described + here. + +--- + +## Bitwise + +### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +## Movement + +## Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/micro-isa/07-binary-vector-ops.md b/docs/isa/micro-isa/07-binary-vector-ops.md new file mode 100644 index 000000000..0ab4ae554 --- /dev/null +++ b/docs/isa/micro-isa/07-binary-vector-ops.md @@ -0,0 +1,293 @@ +# 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +## Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +## Arithmetic + +### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +## Bitwise + +### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +## Shift + +### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +## Carry Operations + +### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +## Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/micro-isa/08-vec-scalar-ops.md b/docs/isa/micro-isa/08-vec-scalar-ops.md new file mode 100644 index 000000000..9ef60d3cb --- /dev/null +++ b/docs/isa/micro-isa/08-vec-scalar-ops.md @@ -0,0 +1,236 @@ +# 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +## Arithmetic + +### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +## Shift + +### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +## Carry Operations + +### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +## Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` diff --git a/docs/isa/micro-isa/09-conversion-ops.md b/docs/isa/micro-isa/09-conversion-ops.md new file mode 100644 index 000000000..090921a86 --- /dev/null +++ b/docs/isa/micro-isa/09-conversion-ops.md @@ -0,0 +1,349 @@ +# 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +## CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +## `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar base value. +- **inputs:** + `%index` is the scalar base value. Supported scalar types are `i8/i16/i32`, + `f16`, and `f32`. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `order` and + the result element type together determine whether lanes are generated as + `base + lane_id` or `base - lane_id`. Supported result types are + `!pto.vreg<256xsi8>`, `!pto.vreg<128xsi16>`, `!pto.vreg<64xsi32>`, + `!pto.vreg<128xf16>`, and `!pto.vreg<64xf32>`. `%index` must use the + matching scalar type for `f16`/`f32`; for integer results, `%index` must use + the same bit width and may be signless or signed. + +--- + +## `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. + +- `Part` (`PART_EVEN`, `PART_ODD`) + - Used by ordinary width-changing conversions. + - Typical cases include `32 -> 16`, `16 -> 32`, and other even/odd packing + or unpacking forms. +- `Part_T` (`PART_P0`, `PART_P1`, `PART_P2`, `PART_P3`) + - Used by lower-level packed placement forms. + - Typical cases include `32 -> 8`, packed fp8/fp4 conversion paths, and + other flows where the result is written into one of four sub-parts before a + later merge or compact step. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | +| `P0` | Output to sub-part 0 in 4-way packed placement forms | +| `P1` | Output to sub-part 1 in 4-way packed placement forms | +| `P2` | Output to sub-part 2 in 4-way packed placement forms | +| `P3` | Output to sub-part 3 in 4-way packed placement forms | + +--- + +### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +#### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +#### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +#### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +#### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | Y | Y | | + +--- + +### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +## `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +## Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + +## `pto.vbitcast` + +- **syntax:** `%result = pto.vbitcast %input : !pto.vreg -> !pto.vreg` +- **semantics:** Bitwise reinterpretation of a vreg vector without changing the underlying bit pattern. This operation performs a pure type cast that preserves the exact bits of each element, changing only their interpretation (e.g., from floating-point to integer). + +- **inputs:** + `%input` is the source vector register value. +- **outputs:** + `%result` is the reinterpreted vector register value. +- **constraints and limitations:** + 1. Both source and result must be `!pto.vreg<...>` types. + 2. Source and result vectors must have the same total bit width (currently 2048 bits). + 3. Only integer and floating-point element types are supported. + +**Element bit-width equality examples:** +- `f32<64>` → `i32<64>` (both 32-bit elements, total 2048 bits) +- `f16<128>` → `i16<128>` (both 16-bit elements, total 2048 bits) +- `bf16<128>` → `ui16<128>` (both 16-bit elements, total 2048 bits) +- `si32<64>` → `ui32<64>` (both 32-bit elements, total 2048 bits) +- `f32<64>` → `i16<128>` (32-bit/16-bit elements, total 2048 bits) + +**Verification:** The operation verifies that: +1. Both input and result are `!pto.vreg<...>` types. +2. Total bit width equals 2048 (the fixed vreg size). + +**Comparison with `pto.vcvt`:** +- `pto.vcvt` performs value conversion with rounding, saturation, and lane placement control. +- `pto.vbitcast` performs bitwise reinterpretation without changing the underlying bit pattern. + +**Example: Reinterpreting float as integer for bit manipulation** +```mlir +// Prepare a vector of float values +%fvec = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + +// Reinterpret as integer for bitwise operations +%ivec = pto.vbitcast %fvec : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + +// Extract sign bit (bit 31) +%sign_bits = pto.vand %ivec, %sign_mask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + +// Reinterpret back to float +%fvec_without_sign = pto.vbitcast %sign_bits : !pto.vreg<64xi32> -> !pto.vreg<64xf32> +``` + +**Example: Type punning between signed and unsigned integer** +```mlir +// Convert signed to unsigned without changing bits +%signed = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xsi32> +%unsigned = pto.vbitcast %signed : !pto.vreg<64xsi32> -> !pto.vreg<64xui32> +// Bits are identical; interpretation changes from signed to unsigned +``` + +## `pto.pbitcast` + +- **syntax:** `%result = pto.pbitcast %input : !pto.mask -> !pto.mask` +- **semantics:** Bitwise reinterpretation of a predicate register without + changing the underlying predicate-register image. This op makes mask-family + reinterpretation explicit in VPTO IR when a producer and consumer expect + different `!pto.mask<...>` views of the same hardware predicate state. + +- **inputs:** + `%input` is the source predicate register value. +- **outputs:** + `%result` is the reinterpreted predicate register value. +- **constraints and limitations:** + 1. Both source and result must be `!pto.mask<...>` types. + 2. `pto.pbitcast` does not materialize or normalize predicate contents; it + only changes which mask granularity the surrounding VPTO IR uses to + interpret the same predicate bits. + +**Example: Reinterpret a b16 predicate as b32 before a consumer** +```mlir +%m16 = pto.pintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask +%m32 = pto.pbitcast %m16#0 : !pto.mask -> !pto.mask +%result = pto.vsel %a, %b, %m32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/micro-isa/10-reduction-ops.md b/docs/isa/micro-isa/10-reduction-ops.md new file mode 100644 index 000000000..b264d6386 --- /dev/null +++ b/docs/isa/micro-isa/10-reduction-ops.md @@ -0,0 +1,246 @@ +# 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +## Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +## Full Vector Reductions + +### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** On A5, `i8/u8` inputs produce widened + `i16/u16` results with half as many lanes (`M = N / 2`), and `i16/u16` inputs + produce widened `i32/u32` results with half as many lanes. For + `i32/u32/f16/f32` inputs, `U = T` and `M = N`. If all predicate bits are + zero, the result is zero. + +--- + +### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +## Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +## Prefix Operations + +### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +## Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/micro-isa/11-compare-select.md b/docs/isa/micro-isa/11-compare-select.md new file mode 100644 index 000000000..bc28f2fd1 --- /dev/null +++ b/docs/isa/micro-isa/11-compare-select.md @@ -0,0 +1,182 @@ +# 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +## Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +## Comparison Operations + +### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +## Selection Operations + +### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +## Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +## Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` diff --git a/docs/isa/micro-isa/12-data-rearrangement.md b/docs/isa/micro-isa/12-data-rearrangement.md new file mode 100644 index 000000000..359e7c306 --- /dev/null +++ b/docs/isa/micro-isa/12-data-rearrangement.md @@ -0,0 +1,230 @@ +# 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +## Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +## Interleave / Deinterleave + +### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +## Compress / Expand + +### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +## Pack / Unpack + +### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +## Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +## V2 Interleave Forms + +### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. diff --git a/docs/isa/micro-isa/13-dsa-sfu-ops.md b/docs/isa/micro-isa/13-dsa-sfu-ops.md new file mode 100644 index 000000000..62ebfba40 --- /dev/null +++ b/docs/isa/micro-isa/13-dsa-sfu-ops.md @@ -0,0 +1,245 @@ +# 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +## Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +## Fused Activation Ops (vreg→vreg) + +### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the per-element + slope vector, and `%mask` selects active lanes. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +### `pto.vexpdif` + +- **syntax:** `%result = pto.vexpdif %input, %max, %mask, "EVEN|ODD" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector, `%max` is the broadcasted + subtraction term, `%mask` selects active source lanes, and `%part` selects + `EVEN` or `ODD` for the underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, the mask granularity must match the input + vector element width, and source/result storage width must match. + +--- + +## Fused Compute+Convert Ops + +### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha, %mask : !pto.vreg, !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, + `%alpha` is the scalar multiplier, and `%mask` selects active lanes. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + + +## Extended Arithmetic + +### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +## Index Generation + +### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- **semantics:** Generate a lane index vector from a scalar base value. + +```c +for (int i = 0; i < N; i++) + dst[i] = (order == ASC) ? (base_index + i) : (base_index - i); +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar base value. Supported scalar types are + `i8/i16/i32`, `f16`, and `f32`. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** `%result` element type determines both the + generated element type and the lane count. Supported result types are + `!pto.vreg<256xsi8>`, `!pto.vreg<128xsi16>`, `!pto.vreg<64xsi32>`, + `!pto.vreg<128xf16>`, and `!pto.vreg<64xf32>`. `%index` must use the + matching scalar type for `f16`/`f32`; for integer results, `%index` must use + the same bit width and may be signless or signed. + +--- + +## Sorting Operations + +### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +### `pto.get_vms4_sr` + +- **syntax:** `%list0, %list1, %list2, %list3 = pto.get_vms4_sr : i16, i16, i16, i16` +- **semantics:** Read `VMS4_SR` and return the finished counts for source + lists 0, 1, 2, and 3. After exhausted `pto.vmrgsort4`, the four results map + to `VMS4_SR[15:0]`, `VMS4_SR[31:16]`, `VMS4_SR[47:32]`, and + `VMS4_SR[63:48]`. + +--- + +## Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` +- `pto.get_vms4_sr : i16, i16, i16, i16` + +--- + +## Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate ascending si32 indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xsi32> +``` diff --git a/docs/isa/micro-isa/14-shared-arith.md b/docs/isa/micro-isa/14-shared-arith.md new file mode 100644 index 000000000..6c703dc55 --- /dev/null +++ b/docs/isa/micro-isa/14-shared-arith.md @@ -0,0 +1,99 @@ +# 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +## Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +## Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +## Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +## Typical Patterns + +### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +## Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` diff --git a/docs/isa/micro-isa/15-shared-scf.md b/docs/isa/micro-isa/15-shared-scf.md new file mode 100644 index 000000000..12a637fd7 --- /dev/null +++ b/docs/isa/micro-isa/15-shared-scf.md @@ -0,0 +1,97 @@ +# 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +## Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +## Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +## Typical Patterns + +### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +## Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value diff --git a/docs/isa/micro-isa/16-cube-matmul.md b/docs/isa/micro-isa/16-cube-matmul.md new file mode 100644 index 000000000..fd3eece32 --- /dev/null +++ b/docs/isa/micro-isa/16-cube-matmul.md @@ -0,0 +1,1168 @@ +# 16. Cube Matrix Multiply + +> **Category:** Cube unit ops — staged load/store, matrix multiply, and +> FIXPIPE MTE writeback + +This chapter documents the high-level Cube VPTO surface. It describes logical +data objects, operand units, layout contracts, numeric behavior, and writeback +effects from the user's point of view. + +--- + +## Common Cube Operand Model + +Cube ops use typed PTO pointers to name logical storage domains. The canonical +`!pto.ptr` address-space names are the hardware-domain names below. The legacy +names are accepted only as parser aliases and are printed back as canonical +names. + +| Canonical address space | Legacy alias | Logical role | +|-------------------------|--------------|--------------| +| `gm` | - | Global memory | +| `l1` | `mat` | L1 matrix staging buffer | +| `l0a` | `left` | Left matrix operand tile for Cube compute | +| `l0b` | `right` | Right matrix operand tile for Cube compute | +| `l0c` | `acc` | Accumulator/result tile produced by Cube compute | +| `bt` | `bias` | Bias vector payload consumed by bias matmul forms | +| `fb` | `scaling` | FIXPIPE parameter payloads consumed by vector quant/ReLU clauses | +| `ub` | `vec` | Unified Buffer destination/source for vector-side use | + +Unless an op says otherwise: + +- Shape operands such as `%m`, `%n`, `%k`, `shape(%n, %d)` are logical element + counts, not byte counts. +- Length operands named `%len_burst` in byte-copy surfaces are byte counts + unless the op explicitly states a different unit. +- Strides named `src_stride` or `dst_stride` are start-to-start distances in + the unit stated by the op. Do not infer byte units from the name alone. +- Pointer operands select the base address of the logical object. Sub-tile + selection is expressed by computing a different base pointer before calling + the op, unless the op exposes an explicit start or group operand. +- Cache/session hint operands may affect the memory path but do not change the + mathematical value written or read. + +--- + +## Cube Compute Ops + +The `pto.mad*` family computes logical matrix multiplication over tiles already +prepared in `l0a` and `l0b`: + +```text +lhs: M x K +rhs: K x N +dst: M x N +``` + +The matrix element types are inferred from `%lhs`, `%rhs`, and `%dst` pointer +element types. There is no separate type selector. Unsupported type +combinations are invalid programs. + +The current VPTO surface enforces the Cube storage roles through pointer +address spaces: `%lhs` is `l0a`, `%rhs` is `l0b`, and `%dst` is `l0c`. +Bias forms additionally require `%bias` in the `bt` address space with the +same element type as `%dst`. MX forms require MX element types on both `%lhs` +and `%rhs`; the current target-profile MX data type is `f8E4M3FN`. + +### MAD Common Clauses + +| Clause | Values | Effect | +|--------|--------|--------| +| `unit_flag(...)` | `check_only`, `check_and_set` | Participates in producer-side tile synchronization. `check_only` checks that the producer slot can be used. `check_and_set` also publishes the produced `%dst` tile for later consumers. Omit the clause when the schedule does not use unit flags for this tile. | +| `disable_gemv` | flag | Applies only when `%m = 1`. Omitted means GEMV A-vector consumption: `%lhs` must contain the logical `1 x K` row in the target GEMV left-tile organization. Present means normal matmul left-tile organization. The mathematical result is still `lhs @ rhs`; only the required `%lhs` organization changes. For `%m != 1`, normal matmul organization is used. | +| `sat` / `nosat` | flags | Floating exceptional-value mode for floating and MX MAD forms. With `sat`, exceptional multiply inputs are normalized before arithmetic (`+/-inf` to finite type extrema, `nan` to 0) and finite overflow saturates to the finite type range. With `nosat`, exceptional inputs are preserved and overflow may produce exceptional outputs. Omit both to use the execution mode selected outside this op. Integer MAD forms do not accept these flags. | +| `tf32_mode(...)` | `round_even`, `round_away` | Valid only for non-MX `f32 x f32 -> f32`. FP32 inputs are rounded to TF32 precision before multiplication; accumulation and output remain FP32. | +| `n_dir` | flag | Requests N-direction result production order for schedules that combine compute with unit flags and later layout movement. It does not change `dst[m, n]`. | + +Reference semantics for non-MX forms: + +```text +product[m, n] = sum k in 0 .. K-1: + numeric_lhs(lhs[m, k]) * numeric_rhs(rhs[k, n]) + +pto.mad: dst[m, n] = product[m, n] +pto.mad_acc: dst[m, n] = dst[m, n] + product[m, n] +pto.mad_bias: dst[m, n] = product[m, n] + bias[n] +``` + +For integer forms, the op multiplies the typed values already present in +`l0a` and `l0b`. Per-input offset correction for quantized integer +algorithms is not an operand of `pto.mad*`; apply such correction before +loading the Cube operands when the algorithm needs it. + +### MX Matmul Model + +`pto.mad_mx*` additionally applies microscaling. The scale payloads are loaded +with `pto.mte_l1_l0a_mx` / `pto.mte_l1_l0b_mx` and are associated with the +selected `%lhs` / `%rhs` tiles; they are not direct operands of `pto.mad_mx*`. + +The K dimension is partitioned into 32-element groups: + +```text +k_group = floor(k / 32) + +mx_product[m, n] = + sum k in 0 .. K-1: + (lhs[m, k] * lhs_scale[m, k_group]) * + (rhs[k, n] * rhs_scale[k_group, n]) +``` + +Current target-profile MX data tiles use `f8E4M3FN`. `%k` must be compatible +with MX grouping. On the current target profile, MX matmul consumes K in +64-element multiples, which contain two 32-element scale groups. + +### `pto.mad` + +- **syntax:** +```mlir +pto.mad %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` +- **semantics:** Zero-init matrix multiply, `dst[m, n] = sum_k(lhs[m, k] * rhs[k, n])`. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%lhs` | ptr | Left operand tile in `l0a`, interpreted as logical `M x K` | +| `%rhs` | ptr | Right operand tile in `l0b`, interpreted as logical `K x N` | +| `%dst` | ptr | Accumulator destination tile in `l0c`, interpreted as logical `M x N` | +| `%m` | i64 | Logical M element count | +| `%n` | i64 | Logical N element count | +| `%k` | i64 | Logical K element count | +| optional clauses | - | See [MAD Common Clauses](#mad-common-clauses) | + +**Constraints:** + +- `%lhs`, `%rhs`, and `%dst` must be in `l0a`, `l0b`, and `l0c`. +- `%m`, `%n`, and `%k` must be positive and satisfy the target shape limits + for the selected element-type combination. +- `tf32_mode(...)` requires `f32` lhs, rhs, and dst element types. +- `sat` / `nosat` requires a floating element-type combination. +- Packed 4-bit integer data requires `%k` to select an even number of K + elements. + +**Example:** + +```mlir +pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c32_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +### `pto.mad_acc` + +- **syntax:** +```mlir +pto.mad_acc %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` +- **semantics:** Accumulating matrix multiply, + `dst[m, n] = dst[m, n] + sum_k(lhs[m, k] * rhs[k, n])`. + +**Parameter Table:** same as `pto.mad`. + +**Constraints:** same as `pto.mad`. + +**Example:** + +```mlir +pto.mad_acc %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c32_i64 unit_flag(check_only) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +### `pto.mad_bias` + +- **syntax:** +```mlir +pto.mad_bias %lhs, %rhs, %dst, %bias, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + tf32_mode(round_even | round_away)? + n_dir? + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` +- **semantics:** Bias-init matrix multiply, + `dst[m, n] = sum_k(lhs[m, k] * rhs[k, n]) + bias[n]`. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%lhs`, `%rhs`, `%dst`, `%m`, `%n`, `%k` | - | Same as `pto.mad` | +| `%bias` | ptr | Bias vector in `bt`, interpreted as `N` values and broadcast across M | +| optional clauses | - | See [MAD Common Clauses](#mad-common-clauses) | + +**Constraints:** + +- `%bias` must be in `bt` address space. +- `%bias` element type must match `%dst` element type. +- Only `N` bias values are consumed; `%bias` is not an `M x N` matrix. +- Other constraints match `pto.mad`. + +**Example:** + +```mlir +pto.mad_bias %l0a, %l0b, %l0c, %bt, %c16_i64, %c16_i64, %c32_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +### `pto.mad_mx` + +- **syntax:** +```mlir +pto.mad_mx %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` +- **semantics:** Zero-init MX matrix multiply, `dst[m, n] = mx_product[m, n]`. + +**Parameter Table:** same as `pto.mad`; `%lhs` and `%rhs` must have matching +MX scale payloads prepared by the MX load ops. + +**Constraints:** + +- Operands must use a target-supported MX dtype combination. +- Matching left and right MX scale payloads must be loaded before this op. +- `%k` must satisfy the MX grouping rule described in [MX Matmul Model](#mx-matmul-model). +- `tf32_mode(...)` is not a clause of MX MAD. + +**Example:** + +```mlir +pto.mad_mx %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +### `pto.mad_mx_acc` + +- **syntax:** +```mlir +pto.mad_mx_acc %lhs, %rhs, %dst, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` +- **semantics:** Accumulating MX matrix multiply, + `dst[m, n] = dst[m, n] + mx_product[m, n]`. + +**Parameter Table:** same as `pto.mad_mx`. + +**Constraints:** same as `pto.mad_mx`. + +**Example:** + +```mlir +pto.mad_mx_acc %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +### `pto.mad_mx_bias` + +- **syntax:** +```mlir +pto.mad_mx_bias %lhs, %rhs, %dst, %bias, %m, %n, %k + unit_flag(check_only | check_and_set)? + disable_gemv? + (sat | nosat)? + n_dir? + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` +- **semantics:** Bias-init MX matrix multiply, + `dst[m, n] = mx_product[m, n] + bias[n]`. + +**Parameter Table:** same as `pto.mad_bias`, with MX `%lhs` / `%rhs` scale +payload requirements from `pto.mad_mx`. + +**Constraints:** same as `pto.mad_mx` plus `pto.mad_bias` bias constraints. + +**Example:** + +```mlir +pto.mad_mx_bias %l0a, %l0b, %l0c, %bt, %c16_i64, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +## Cube Data Movement Ops + +### Cube Burst / Loop Addressing Model + +`pto.mte_gm_l1` and `pto.mte_l1_ub` use the same grouped transfer model: + +```text +burst(row) = len_burst contiguous bytes +nburst = innermost repeated burst group +loop = optional outer repetition group +``` + +For each `nburst` row, the source and destination start addresses advance by +`src_stride` and `dst_stride` after a burst row. Optional `loop(...)` groups +wrap the full inner transfer pattern and advance by their own source and +destination strides between repetitions. All lengths and strides in this model +are bytes. + +### `pto.mte_gm_l1` + +- **syntax:** +```mlir +pto.mte_gm_l1 %src, %dst, %len_burst + nburst(%count, %src_stride, %dst_stride) + [loop(%count_i, %src_stride_i, %dst_stride_i)]* + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Structured GM-to-L1 copy. The op copies grouped byte ranges + from `%src` in `gm` to `%dst` in `l1`. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | GM source base pointer | +| `%dst` | ptr | L1 matrix destination base pointer in `l1` | +| `%len_burst` | i64 | Bytes copied per burst row | +| `nburst(%count, %src_stride, %dst_stride)` | i64 triple | Innermost burst count and byte strides between row starts | +| `loop(%count_i, %src_stride_i, %dst_stride_i)` | i64 triple | Optional outer repetition; strides are byte advances between enclosed patterns | + +**Constraints:** + +- `nburst(...)` is required. +- Each `loop(...)` group must provide all three operands. +- For a contiguous 16-element f16 vector, use `%len_burst = 32`. + +**Example:** + +```mlir +pto.mte_gm_l1 %bias_gm, %l1_bias, %c32_i64 + nburst(%c4_i64, %c64_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +### `pto.mte_l1_ub` + +- **syntax:** +```mlir +pto.mte_l1_ub %src, %dst, %len_burst + nburst(%count, %src_stride, %dst_stride) + [loop(%count_i, %src_stride_i, %dst_stride_i)]* + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Structured L1-to-UB copy. The grouped byte ranges are read + from `%src` in `l1` and written to `%dst` in `ub`. + +**Parameter Table:** same grouped byte model as `pto.mte_gm_l1`, with source +and destination address spaces reversed to `l1 -> ub`. + +**Constraints:** + +- `%src` must be in `l1`, `%dst` must be in `ub`. +- `nburst(...)` is required. +- Each `loop(...)` group must provide all three operands. + +**Example:** + +```mlir +pto.mte_l1_ub %l1_src, %ub_dst, %c64_i64 + nburst(%c2_i64, %c128_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +### `pto.mte_gm_l1_frac` + +- **syntax:** +```mlir +pto.mte_gm_l1_frac %src, %dst, nd2nz|dn2nz, + shape(%n_value, %d_value), + src_layout(%src_inner_stride[, %src_outer_stride]), + dst_group(%group_count, %dst_loop2_stride, %dst_loop3_stride, %dst_loop4_stride), + ctrl(%l2_cache_ctrl, %smallc0_en) + : !pto.ptr, !pto.ptr, ... +``` +- **semantics:** Load a logical 2-D GM region and write one or more L1 NZ + matrix groups. `nd2nz` reads a logical `src[n, d]` matrix. `dn2nz` reads a + logical `src[d, n]` matrix and writes the same logical `N x D` result into + NZ layout. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | GM source base pointer | +| `%dst` | ptr | L1 NZ destination base pointer in `l1` | +| `nd2nz` / `dn2nz` | keyword | Source logical layout mode | +| `shape(%n_value, %d_value)` | i64 pair | Logical output shape before NZ packing | +| `src_layout(%src_inner_stride[, %src_outer_stride])` | i64 / optional i64 | Source row/matrix byte strides | +| `dst_group(...)` | i64 tuple | Destination group count and placement strides in C0-size units | +| `ctrl(%l2_cache_ctrl, %smallc0_en)` | i64, i1 | Cache hint and small-C0 packing enable | + +`src_layout(%src_inner_stride)` describes one logical source matrix. For +`nd2nz`, `%src_inner_stride` is the byte distance from `src[n, 0]` to +`src[n + 1, 0]`. For `dn2nz`, it is the byte distance from `src[d, 0]` to +`src[d + 1, 0]`. When `%src_outer_stride` is present, it is the byte distance +between adjacent source matrices. When omitted, the outer source stride is 0. + +`dst_group(%group_count, %dst_loop2_stride, %dst_loop3_stride, +%dst_loop4_stride)` writes `%group_count` logical matrices. Destination strides +are measured in C0-size units; one C0-size unit is 32 bytes. These strides +place generated NZ blocks relative to `%dst`. They do not select a separate +memory block. + +Reference addressing: + +```text +for g in 0 .. group_count-1: + src_g = src + g * src_outer_stride + dst_g = dst + g * dst_loop4_stride * 32 + + for n in 0 .. n_value-1: + for d in 0 .. d_value-1: + if mode == nd2nz: + value = load(src_g + n * src_inner_stride + d * sizeof(T)) + else: + value = load(src_g + d * src_inner_stride + n * sizeof(T)) + store value into NZ position for logical [n, d] under dst_g + + invalid lanes in the final C0 group are written as zero +``` + +**Constraints:** + +- Source strides are bytes. For row-major `16 x 16` f16 input, + `src_layout(32)` describes consecutive rows. +- Destination strides are C0-size units, not bytes and not elements. +- `smallc0_en = true` is valid only for target-supported small-C0 cases. The + current contract rejects `d_value > 4` in small-C0 mode. +- In normal C0 mode, each destination C0 burst is padded to 32 bytes. In + small-C0 mode, each destination burst is padded to 4 logical channels, and + the generated inner-N and C0 destination placement is fixed by that + small-C0 packing rule. `%dst_loop4_stride` still places adjacent matrix + groups. +- In small-C0 mode, missing logical `N` rows and invalid `D` lanes are written + as zero, and the tail of a generated NZ matrix is padded to the 32-byte C0 + boundary. +- Destination regions selected by `%dst` and `dst_group(...)` must not overlap. + If two generated writes target the same bytes, the final value is not a + stable program result. + +**Example:** + +```mlir +pto.mte_gm_l1_frac %src, %dst, nd2nz, + shape(%c32_i64, %c16_i64), + src_layout(%c32_i64, %c1024_i64), + dst_group(%c2_i64, %c1_i64, %c16_i64, %c64_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, shape i64, i64, + src_layout(i64, i64), dst_group i64, i64, i64, i64, ctrl i64, i1 +``` + +--- + +### `pto.mte_l1_bt` + +- **syntax:** +```mlir +pto.mte_l1_bt %src, %dst, %len_burst + nburst(%count, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Load an L1 bias payload into the `bt` address space for + later `pto.mad_bias` / `pto.mad_mx_bias` consumption. The consumer interprets + the result as an `N`-element bias vector `bias[n]`. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | L1 source pointer in `l1` | +| `%dst` | ptr | Bias destination pointer in `bt` | +| `%len_burst` | i64 | Number of bias-load units per burst | +| `%count` | i64 | Burst count | +| `%src_gap` | i64 | Source gap between bursts, in bias-load units | +| `%dst_gap` | i64 | Destination gap between bursts, in bias-load units | + +One burst loads `%len_burst` units from `%src` and writes the corresponding +bias values to `%dst`. After each burst except the last, source and destination +advance by the burst length plus the corresponding gap. + +**Constraints:** + +- Supported type pairs: `f32->f32`, `i32->i32`, `f16->f32`, `bf16->f32`. +- For `bf16->f32`, compact bf16 source values are always widened to f32 bias + values. For `f16->f32`, compact f16 source values are widened when the load + is used as an f32 bias payload; otherwise the f16 payload is stored in the + 32-bit bias slot with unused high bits. +- Load exactly the channel bias values needed by the consumer tile; the bias + payload is not result-shaped. + +**Example:** + +```mlir +pto.mte_l1_bt %l1_bias, %bt, %c1_i64 nburst(%c4_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +### `pto.mte_l1_fb` + +- **syntax:** +```mlir +pto.mte_l1_fb %src, %dst, %len_burst + nburst(%count, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Load FIXPIPE parameter payloads from L1 into `fb`. + Vector `pre_quant(...)` and `pre_relu(...)` clauses in `pto.mte_l0c_l1*` + later consume these payloads through `fb` pointers. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | L1 source pointer in `l1` | +| `%dst` | ptr | Scaling destination pointer in `fb` | +| `%len_burst` | i64 | Number of parameter-load units per burst | +| `%count` | i64 | Burst count | +| `%src_gap` | i64 | Source gap between bursts, in parameter-load units | +| `%dst_gap` | i64 | Destination gap between bursts, in parameter-load units | + +The copy unit of `pto.mte_l1_fb` is the parameter-load unit of this op. It is +separate from the row size consumed by `mte_l0c_*` vector payloads. +`%len_burst` and the `nburst(...)` gaps are counted in these load units, not +in bytes and not in destination elements. After `pto.mte_l1_fb` materializes the +payload in `fb`, vector pre-ReLU consumers read it as 64B parameter rows +and vector pre-quant consumers read it as 128B parameter rows. The payload +pointer passed to `mte_l0c_*` must point at the first row for the logical +output tile, and rows must follow the same channel/NZ order consumed by that +store. + +**Constraints:** + +- `%src` must be in `l1`, `%dst` must be in `fb`. +- Vector `pre_quant` and `pre_relu` consumers require parameter data prepared + in the row order documented by [FIXPIPE MTE Ops](#fixpipe-mte-ops). + +**Example:** + +```mlir +pto.mte_l1_fb %l1_fp, %fb_fp, %c2_i64 nburst(%c4_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +### Left / Right Tile Load Model + +`pto.mte_l1_l0a` and `pto.mte_l1_l0b` move L1 cube-fractal tiles into the +compute operand domains. `%src` must already point to an L1 cube-fractal tile; +these ops do not convert arbitrary row-major matrices. Use +`pto.mte_gm_l1_frac` first when the original data is plain ND/DN layout. + +If `transpose = true`, the selected logical source tile is transposed before it +is placed in the destination operand domain. Omitting the attribute means +`transpose = false`. + +### `pto.mte_l1_l0a` + +- **syntax:** +```mlir +pto.mte_l1_l0a %src, %dst, %m, %k + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Load a logical `%m x %k` left tile from L1 `l1` into `l0a`. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | L1 cube-fractal source tile in `l1` | +| `%dst` | ptr | Left operand destination in `l0a` | +| `%m` | i64 | Logical M extent | +| `%k` | i64 | Logical K extent | +| `transpose` | attr | Optional boolean source-tile transpose before destination placement | + +**Constraints:** + +- `%src` must be in `l1`, `%dst` must be in `l0a`. +- `%src` and `%dst` must satisfy the target alignment for Cube tile loads. +- `transpose = true` requires a tile shape supported by the element-type + transpose granularity. + +**Example:** + +```mlir +pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c32_i64 + : !pto.ptr, !pto.ptr, i64, i64 +``` + +--- + +### `pto.mte_l1_l0b` + +- **syntax:** +```mlir +pto.mte_l1_l0b %src, %dst, %k, %n + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Load a logical `%k x %n` right tile from L1 `l1` into + `l0b`. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | L1 cube-fractal source tile in `l1` | +| `%dst` | ptr | Right operand destination in `l0b` | +| `%k` | i64 | Logical K extent | +| `%n` | i64 | Logical N extent | +| `transpose` | attr | Optional boolean source-tile transpose before destination placement | + +**Constraints:** + +- `%src` must be in `l1`, `%dst` must be in `l0b`. +- `%src` and `%dst` must satisfy the target alignment for Cube tile loads. +- `transpose = true` requires a tile shape supported by the element-type + transpose granularity. + +**Example:** + +```mlir +pto.mte_l1_l0b %l1_b, %l0b, %c32_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 +``` + +--- + +### MX Scale Load Model + +MX scale loads prepare the scale payloads consumed by `pto.mad_mx*`. Each scale +entry applies to one 32-element K group. + +- Left scale logical shape: `[M, ceil(K / 32)]`. +- Right scale logical shape: `[ceil(K / 32), N]`. +- L1 source data is organized as 32B scale fragments in the same logical order + as the associated data tile. + +### `pto.mte_l1_l0a_mx` + +- **syntax:** +```mlir +pto.mte_l1_l0a_mx %src, %dst, %m, %k + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Load left-side MX scale fragments for a logical `%m x %k` + left data tile. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | L1 MX scale source in `l1` | +| `%dst` | ptr | Left-side MX payload destination associated with `l0a` | +| `%m` | i64 | M extent of the associated left data tile | +| `%k` | i64 | K extent; scale grouping is by 32 K elements | + +**Constraints:** + +- `%src` must be in `l1`, `%dst` must be in `l0a`. +- `%src` and `%dst` must satisfy 32B MX scale-fragment alignment. + +**Example:** + +```mlir +pto.mte_l1_l0a_mx %l1_a_scale, %l0a_scale, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 +``` + +--- + +### `pto.mte_l1_l0b_mx` + +- **syntax:** +```mlir +pto.mte_l1_l0b_mx %src, %dst, %k, %n + : !pto.ptr, !pto.ptr, i64, i64 +``` +- **semantics:** Load right-side MX scale fragments for a logical `%k x %n` + right data tile. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | ptr | L1 MX scale source in `l1` | +| `%dst` | ptr | Right-side MX payload destination associated with `l0b` | +| `%k` | i64 | K extent; scale grouping is by 32 K elements | +| `%n` | i64 | N extent of the associated right data tile | + +**Constraints:** + +- `%src` must be in `l1`, `%dst` must be in `l0b`. +- `%src` and `%dst` must satisfy 32B MX scale-fragment alignment. + +**Example:** + +```mlir +pto.mte_l1_l0b_mx %l1_b_scale, %l0b_scale, %c64_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 +``` + +--- + +## FIXPIPE MTE Ops + +`pto.mte_l0c_l1*` writes logical accumulator results from `l0c` to `l1`, `gm`, +or `ub`. The family shares this pipeline order: + +```text +1. Read logical acc[m, n] from %src using the selected layout mode. +2. Optionally participate in consumer-side unit-flag synchronization. +3. Optionally apply pre_quant(payload, mode). +4. Optionally apply pre_relu(payload, mode), then optional clip. +5. Convert to the destination element type using sat/nosat behavior. +6. Write to the selected destination layout and address space. +7. Apply store-target effects such as GM atomic or UB dual destination. +``` + +Only the clauses documented here affect `pto.mte_l0c_l1*`. Other transforms +must be represented by separate PTO ops before producing `l0c` or after the +writeback destination is materialized. + +### FIXPIPE Common Clauses + +| Clause | Values | Effect | +|--------|--------|--------| +| `unit_flag(...)` | `check_only`, `check_and_clear` | Checks that the accumulator tile is ready for consumption. `check_and_clear` also clears the consumed tile state for later reuse. Omit when the schedule does not use unit flags. | +| `pre_quant(%payload, mode = ...)` | see below | Applies the selected pre-quantization or conversion before ReLU/clip and final store. | +| `pre_relu([%payload, ]mode = ...[, clip = %clip])` | `no_relu`, `normal_relu`, `scalar_relu`, `vector_relu` | Applies ReLU-family activation before final destination conversion. `clip` is part of this clause and applies after the selected ReLU mode. | +| `nz2nd` / `nz2dn(...)` / `nz2nz(...)` | layout modes | Selects how logical `acc[m, n]` is written to the destination layout. | +| `loop3(%count, %src_stride3, %dst_stride3)` | i64 triple | Repeats the whole selected `m x n` writeback pattern. | +| `sat` / `sat(preserve_nan)` / `nosat` | flags | Selects final conversion behavior for floating exceptional values and finite overflow where the destination type is affected. | + +`pre_quant` legal modes: + +```text +f32_f16, +qf322hif8_pre_vec, qf322hif8_pre_scalar, +qf322hif8_pre_hybrid_vec, qf322hif8_pre_hybrid_scalar, +deqs32_int_vec, deqs32_int_scalar, +req8_vec, req8_scalar, +deqf16_vec, deqf16_scalar, +qf322fp8_pre_vec, qf322fp8_pre_scalar, +qf322f32_pre_vec, qf322f32_pre_scalar, +f32_bf16, +qf162b8_pre_vec, qf162b8_pre_scalar, +qf162s4_pre_vec, qf162s4_pre_scalar, +req4_vec, req4_scalar, +qf322b8_pre_vec, qf322b8_pre_scalar, +qf322s4_pre_vec, qf322s4_pre_scalar, +deqs16_vec, deqs16_scalar, +qf162s16_pre_vec, qf162s16_pre_scalar, +qf322f16_pre_vec, qf322f16_pre_scalar, +qf322bf16_pre_vec, qf322bf16_pre_scalar, +qs322bf16_pre_vec, qs322bf16_pre_scalar +``` + +`_scalar` modes take one floating scalar payload (`f16`, `bf16`, or `f32`) +broadcast to the whole logical output tile. `f16` and `bf16` scalar payloads +are first interpreted as numeric values and widened to `f32`; `f32` payloads +are used directly. `_vec` modes take a `!pto.ptr` +payload pointer. The pointer element type is the logical parameter element +type, not a packed transport carrier. The pointer names the first parameter +row for this store; later rows +advance in the same channel/NZ order as the logical accumulator elements +consumed by the selected layout mode. Each vector pre-quant row is a 128B +parameter row prepared by `pto.mte_l1_fb`; each row supplies the per-channel +scale and any mode-specific offset/sign controls used by the selected +quantization family. Vector pre-ReLU rows are 64B parameter rows and supply +the per-channel alpha values consumed by `vector_relu`. + +`pre_quant` mode families: + +| Family | Acc source | Result meaning | Payload | +|--------|------------|----------------|---------| +| `f32_f16`, `f32_bf16` | `f32` | Convert f32 accumulator values to f16 or bf16; rounding is nearest, ties to even | Scalar payload is required by syntax but does not select per-channel scaling | +| `qf322hif8_pre_*`, `qf322fp8_pre_*` | `f32` | Scale and quantize f32 to hif8/fp8-style destination payloads | Scalar scale or vector scale rows; hybrid modes use the target hybrid rule | +| `qf322f32_pre_*` | `f32` | Apply quant scaling while keeping f32 destination values | Scalar scale or vector scale rows | +| `qf322f16_pre_*`, `qf322bf16_pre_*` | `f32` | Scale f32, then convert to f16 or bf16 destination values | Scalar scale or vector scale rows | +| `qf322b8_pre_*`, `qf322s4_pre_*` | `f32` | Scale, offset, round, and narrow f32 to 8-bit or signed 4-bit integer payloads | Scalar or vector scale/offset parameter set | +| `qf162b8_pre_*`, `qf162s4_pre_*` | `f32` | Convert through an f16-domain pre-stage, then scale/narrow to integer payloads | Scalar or vector scale/offset parameter set | +| `qf162s16_pre_*` | `i32` | Convert through an f16-domain pre-stage, then scale/narrow to signed 16-bit payloads | Scalar or vector scale/offset parameter set | +| `deqs32_int_*`, `deqs16_*` | `i32` | Rescale integer accumulator values in an integer destination family | Scalar or vector multiplier/offset parameter set | +| `req8_*`, `req4_*` | `i32` | Requantize i32 accumulator values to 8-bit or 4-bit integer payloads | Scalar or vector multiplier/offset/sign parameter set | +| `deqf16_*` | `i32` | Dequantize i32 accumulator values to f16 destination values | Scalar or vector multiplier/offset parameter set | +| `qs322bf16_pre_*` | `i32` | Scale i32 accumulator values and convert to bf16 destination values | Scalar or vector multiplier/offset parameter set | + +The mode name determines the accepted accumulator source family. `f32_f16`, +`f32_bf16`, `qf322hif8_pre_*`, `qf322fp8_pre_*`, `qf322f32_pre_*`, +`qf322f16_pre_*`, `qf322bf16_pre_*`, `qf322b8_pre_*`, +`qf322s4_pre_*`, `qf162b8_pre_*`, and `qf162s4_pre_*` consume `f32` +accumulator values. `deqs32_int_*`, `deqs16_*`, `req8_*`, `req4_*`, +`deqf16_*`, `qf162s16_pre_*`, and `qs322bf16_pre_*` consume `i32` +accumulator values. The final destination element type must match the result +family implied by the mode name; for example, `qf322f16_pre_*` writes an +f16-family result, while `req8_*` writes an 8-bit integer-family result. + +Integer quantization families with `b8` in the name can produce either signed +8-bit or unsigned 8-bit results according to the sign control carried by the +scalar or vector parameter set. Families with `s4` or `s16` produce signed +4-bit or signed 16-bit results. Offset fields are added after scaling and +before the final narrow/saturate step. When a family has no offset/sign in its +payload, the payload scale alone controls the conversion. + +`pre_relu` semantics: + +```text +no_relu: y = x +normal_relu: y = max(x, 0) +scalar_relu: y = x >= 0 ? x : alpha * x +vector_relu: y = x >= 0 ? x : alpha[channel] * x +``` + +`scalar_relu` takes a floating scalar payload (`f16`, `bf16`, or `f32`) and +broadcasts it to all negative values in the logical tile. `vector_relu` takes +a `!pto.ptr` pointer whose elements are per-channel +alpha values and whose 64B rows follow the same channel/NZ order as the store. +`no_relu` and `normal_relu` do not take a payload. If +`clip = %clip` is present: + +```text +y = min(y, clip) +``` + +`sat`, `sat(preserve_nan)`, and `nosat` control final conversion to destination +element types affected by FIXPIPE saturation: + +- `sat`: finite overflow clamps to the destination finite range; `+/-inf` + clamps to finite extrema; `nan` writes as 0. +- `sat(preserve_nan)`: same finite overflow and infinity handling as `sat`, + but NaN writes as NaN when the destination format can represent NaN. This is + intended for fp8 and hif8 destination families; for formats without a NaN + encoding it is equivalent to `sat`. +- `nosat`: finite overflow may produce destination exceptional values; + exceptional input values are preserved where the destination format supports + them. +- For fp8 and hif8 destination families, `nosat` preserves NaN; overflow + becomes the destination exceptional value when the destination encoding + supports it. +- For integer destination families, `sat`/`nosat` is not the integer overflow + policy; integer narrowing and clipping are determined by the selected + pre-quant mode, its payload, and any `clip` clause. +- For `f32` destinations, floating exceptional values are preserved; `sat` + does not force f32 `inf`/`nan` into finite values. + +### FIXPIPE Layout Model + +`%src` points to the base accumulator tile. `%m` and `%n` select the logical +result rectangle to write. If the physical accumulator tile contains dummy rows +or lanes outside that rectangle, they are not written to the destination. + +Layout modes: + +| Mode | Destination layout | Extra operand | +|------|--------------------|---------------| +| omitted | Normal target-profile writeback layout | none | +| `nz2nd` | Logical ND order | none | +| `nz2dn(%loop0_src_stride)` | Logical D/N-swapped order | `%loop0_src_stride` in C0-size units | +| `nz2nz(%split)` | NZ-style destination | `%split`, destination split point | + +`%src_stride` is measured in C0-size units and advances the accumulator source +between adjacent source groups selected by the layout mode. `%dst_stride` is +measured in destination elements and advances the destination row/group +selected by the layout mode. In `loop3`, `%src_stride3` is in C0-size units and +`%dst_stride3` is in destination elements. + +Reference semantics: + +```text +repeat_count = loop3.count if loop3 is present else 1 + +for r in 0 .. repeat_count-1: + src_r = src + r * loop3.src_stride * 32 + dst_r = dst + r * loop3.dst_stride * sizeof(dst_element) + + for m in 0 .. M-1: + for n in 0 .. N-1: + x = read_acc_logical(src_r, m, n, src_stride, layout_mode) + + if pre_quant: + x = apply_pre_quant(x, payload, mode) + + if pre_relu: + x = apply_pre_relu(x, payload, mode) + if clip: + x = min(x, clip) + + y = convert_to_destination_type(x, sat_or_nosat) + write_destination(dst_r, y, m, n, dst_stride, layout_mode) +``` + +When no layout clause is present, the store uses the target-profile normal +writeback layout for the destination address space. This mode performs no +explicit ND/DN/NZ layout transform; `%dst_stride` is still the destination +start-to-start stride in destination elements for the normal writeback rows or +groups. + +For `nz2nd`, `write_destination` stores logical `y[m, n]` in ND order. For +`nz2dn`, it stores the same logical result with the D/N dimensions swapped; the +extra `%loop0_src_stride` selects how the swapped source walk advances through +the accumulator tile. For `nz2nz`, it preserves NZ-style destination packing +and uses `%split` as the destination split point. + +### `pto.mte_l0c_l1` + +- **syntax:** +```mlir +pto.mte_l0c_l1 %src, %dst, %m, %n, %src_stride, %dst_stride + [, unit_flag(check_only | check_and_clear)]? + [, pre_quant(%payload, mode = )]? + [, pre_relu([%payload, ]mode = [, clip = %clip])]? + [, nz2nd | nz2dn(%loop0_src_stride) | nz2nz(%split)?] + [, loop3(%count, %src_stride3, %dst_stride3)]? + [, sat | sat(preserve_nan) | nosat]? + : ... +``` +- **semantics:** FIXPIPE writeback from `l0c` to L1 `l1`. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src` | buffer-like | Accumulator source in `l0c` | +| `%dst` | buffer-like | L1 destination in `l1` | +| `%m` | i64 | Logical M element count | +| `%n` | i64 | Logical N element count | +| `%src_stride` | i64 | Source stride in C0-size units | +| `%dst_stride` | i64 | Destination stride in destination elements | +| optional clauses | - | See [FIXPIPE Common Clauses](#fixpipe-common-clauses) and [FIXPIPE Layout Model](#fixpipe-layout-model) | + +**Constraints:** + +- Clauses must appear in canonical order: + `unit_flag` -> `pre_quant` -> `pre_relu` -> layout -> `loop3` -> `sat`/`nosat`. +- `pre_quant` requires payload and mode together. +- Vector `pre_quant` modes require a `fb` pointer with `f16`, `bf16`, or + `f32` element type. +- Scalar `pre_quant` modes require an `f16`, `bf16`, or `f32` scalar payload. +- `pre_quant` source element type must be `f32` or `i32`, and the selected + mode must be compatible with the source and destination element types. +- `no_relu` and `normal_relu` do not accept a payload. +- `scalar_relu` requires an `f16`, `bf16`, or `f32` scalar payload. +- `vector_relu` requires a `fb` pointer with `f16`, `bf16`, or `f32` + element type. +- `clip` can appear only inside `pre_relu(...)`. +- `clip` is supported for destination `f16`, `ui8`, and signed/signless + 4/8/16-bit integer destinations. The clip payload must match the destination + family: `f16` for f16, 16-bit unsigned/signless payload for `ui8`, and + signed/signless `i4/i8/i16` for signed integer destinations. +- `nz2dn` requires `%loop0_src_stride`; `nz2nd` and `nz2nz` do not accept it. +- `unit_flag` must be omitted when `nz2dn(%loop0_src_stride)` uses a value + other than 1. +- `nz2nz` requires `f32` destination element type and does not accept `loop3`. +- `sat`, `sat(preserve_nan)`, and `nosat` are mutually exclusive. + +**Example:** + +```mlir +pto.mte_l0c_l1 %l0c, %l1_out, %c16_i64, %c32_i64, %c16_i64, %c32_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, f32, f32 +``` + +--- + +### `pto.mte_l0c_gm` + +- **syntax:** +```mlir +pto.mte_l0c_gm %src, %dst, %m, %n, %src_stride, %dst_stride, %sid, %l2_cache_ctrl + [, unit_flag(check_only | check_and_clear)]? + [, pre_quant(%payload, mode = )]? + [, pre_relu([%payload, ]mode = [, clip = %clip])]? + [, nz2nd | nz2dn(%loop0_src_stride) | nz2nz(%split)?] + [, loop3(%count, %src_stride3, %dst_stride3)]? + [, sat | sat(preserve_nan) | nosat]? + [, atomic(type = , op = )]? + : ... +``` +- **semantics:** FIXPIPE writeback from `l0c` to GM. The data transform clauses + match `pto.mte_l0c_l1`; GM-specific operands select the GM write path and + optional atomic update behavior. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src`, `%m`, `%n`, `%src_stride` | - | Same as `pto.mte_l0c_l1` | +| `%dst` | buffer-like | GM destination | +| `%dst_stride` | i64 | GM destination stride in destination elements | +| `%sid` | i64 | GM stream/session hint for the OUT/GM path; does not change written values | +| `%l2_cache_ctrl` | i64 | GM store cache hint; does not change written values | +| `atomic(type = ..., op = ...)` | clause | Optional GM read-modify-write | +| other optional clauses | - | Same as `pto.mte_l0c_l1` | + +`%sid` and `%l2_cache_ctrl` affect the memory path only. They do not change +the logical result, destination layout, numeric conversion, or atomic +operation. For target-profile GM writeback, constant `%sid` values must be in +`[0, 3]`; use `0` unless the surrounding memory system deliberately assigns a +different stream/session hint. Constant `%l2_cache_ctrl` values must fit in the +target cache-control hint range `[0, 15]`. + +`atomic(type = T, op = add|max|min)` performs an atomic read-modify-write at +each GM destination element. `add` accumulates the converted value into the +existing GM value. `max` and `min` compare using `T` and write the selected +value. Supported atomic types are `f32`, `f16`, `bf16`, `s32`, `s16`, and `s8`. + +**Constraints:** + +- `atomic(...)` is valid only on `pto.mte_l0c_gm`. +- `atomic` requires both `type` and `op`. +- Atomic op values are `add`, `max`, and `min`. +- If `%sid` or `%l2_cache_ctrl` is a constant, it must be in the target range + described above. +- Other constraints match `pto.mte_l0c_l1`. + +**Example:** + +```mlir +pto.mte_l0c_gm %l0c, %out, %c16_i64, %c32_i64, %c16_i64, %c32_i64, + %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + nz2nd, + atomic(type = f16, op = add) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, f32 +``` + +--- + +### `pto.mte_l0c_ub` + +- **syntax:** +```mlir +pto.mte_l0c_ub %src, %dst, %m, %n, %src_stride, %dst_stride, + dst_mode(%sub_blockid | split_m | split_n) + [, unit_flag(check_only | check_and_clear)]? + [, pre_quant(%payload, mode = )]? + [, pre_relu([%payload, ]mode = [, clip = %clip])]? + [, nz2nd | nz2dn(%loop0_src_stride) | nz2nz(%split)?] + [, loop3(%count, %src_stride3, %dst_stride3)]? + [, sat | sat(preserve_nan) | nosat]? + : ... +``` +- **semantics:** FIXPIPE writeback from `l0c` to UB. The data transform clauses + match `pto.mte_l0c_l1`; UB-specific operands select single or dual destination + behavior. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src`, `%m`, `%n`, `%src_stride` | - | Same as `pto.mte_l0c_l1` | +| `%dst` | buffer-like | UB destination | +| `%dst_stride` | i64 | UB destination stride in destination elements | +| `dst_mode(%sub_blockid)` | i64 operand | Single-destination mode. `%sub_blockid` selects UB sub-block `0` or `1`; the value may be dynamic. | +| `dst_mode(split_m)` | keyword | Dual-destination mode that splits the logical tile along M. | +| `dst_mode(split_n)` | keyword | Dual-destination mode that splits the logical tile along N. | +| optional clauses | - | Same as `pto.mte_l0c_l1`; `atomic(...)` is not supported | + +In `dst_mode(%sub_blockid)`, the whole logical result tile is written to the +selected UB sub-block using the selected layout mode and `%dst` as that +sub-block's base destination pointer. + +In `dst_mode(split_m)`, the logical tile is split into two M ranges: +`[0, m/2)` and `[m/2, m)`. The first range is written to UB sub-block 0 and the +second range is written to UB sub-block 1. Each sub-block sees its own +destination origin at `%dst`; within each sub-block, the written logical tile +has shape `(m / 2) x n`. + +In `dst_mode(split_n)`, the logical tile is split into two N ranges: +`[0, n/2)` and `[n/2, n)`. The first range is written to UB sub-block 0 and the +second range is written to UB sub-block 1. Each sub-block sees its own +destination origin at `%dst`; within each sub-block, the written logical tile +has shape `m x (n / 2)`. + +**Constraints:** + +- `atomic(...)` is not supported. +- `dst_mode(%sub_blockid)` writes the whole logical tile to one UB sub-block. + Runtime `%sub_blockid` values must be `0` or `1`; constant values are checked + statically when available. +- `dst_mode(split_m)` splits the logical tile along M into two equal-height + sub-block regions. `%m` must be even; each sub-block receives an + `(m / 2) x n` tile. +- `dst_mode(split_n)` splits the logical tile along N into two equal-width + sub-block regions. `%n` must be a multiple of 32; each sub-block receives an + `m x (n / 2)` tile. +- Dual-destination split modes are valid only for target-supported normal or + `nz2nd` writeback cases with pre-quant, pre-ReLU/clip, and other transform + clauses omitted. +- Other constraints match `pto.mte_l0c_l1`. + +**Example:** + +```mlir +pto.mte_l0c_ub %l0c, %ub_out, %c16_i64, %c32_i64, %c16_i64, %c32_i64, + dst_mode(%c1_i64), + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 +``` + +--- + +## Typical Usage / Patterns + +A common Cube matmul flow is: + +```text +GM row/column-major data + -> pto.mte_gm_l1_frac or pto.mte_gm_l1 into L1 `l1` + -> pto.mte_l1_l0a / pto.mte_l1_l0b into `l0a`/`l0b` tiles + -> pto.mad* produces `l0c` tile + -> pto.mte_l0c_l1* writes L1, GM, or UB with optional FIXPIPE transforms +``` + +For MX matmul, load the data tiles and the matching MX scale payloads before +calling `pto.mad_mx*`: + +```text +left data tile + left scale payload +right data tile + right scale payload + -> pto.mad_mx* +``` + +For bias matmul, prepare the bias vector in `bt` with `pto.mte_l1_bt` before the +`pto.mad_bias` / `pto.mad_mx_bias` consumer. diff --git a/docs/isa/tile-op/01-tile-overview.md b/docs/isa/tile-op/01-tile-overview.md new file mode 100644 index 000000000..dcfc2e7f5 --- /dev/null +++ b/docs/isa/tile-op/01-tile-overview.md @@ -0,0 +1,343 @@ +# 1. Tile and PTO Tile Instruction Overview + +> **Category:** Foundational concepts + +This chapter introduces both the tile data model and the **Tile Instruction** surface that operates on it. Read this before any of the per-group Tile Instruction references. + +--- + +## 1.1 What is PTO Tile Instruction + +**PTO Tile Instruction** is a high-performance instruction library built on top of [PTO micro Instruction](../micro-isa/01-pipeline-sync.md). Each tile instruction encapsulates a tile-granular pattern — DMA between GM and on-chip buffers, vector arithmetic over a whole tile, reductions, broadcast / expansion, selection, padding — that internally expands to a sequence of micro-instruction primitives (`pto.vlds`, `pto.vsts`, `pto.vadd`, mask ops, sync flags, …). + +For the kernel author this means: + +- **Author at the tile level.** Use `pto.tload`, `pto.tadd`, `pto.trowsum`, etc., to express tile-granular DMA and compute without writing the underlying vector loop. +- **Drop down to micro instruction when needed.** Inside `pto.vecscope`, `pto.tile_buf_addr` lowers a tile handle to a UB pointer, so handwritten micro-instruction code can read and write the same on-chip data. The mixing pattern is documented in [§1.10](#110-mixing-pto-tile-instruction-and-pto-micro-instruction). +- **Predictable lowering.** Because every Tile Instruction is templated against micro instruction, a kernel that mixes Tile and micro can share scratch tiles, masks, and pipeline events with no representation gap. + +The remaining chapters in this document cover the tile data types, pointer / view ops, DMA, compute families, and op-by-op syntax. The semantics below define the storage contract those ops share. + +## 1.2 Tile Buffer Model + +A **tile** is a bounded, rectangular 2-D sub-region of data that lives in **local on-chip memory** (UB, L0A, L0B, L0C, bias, or scaling buffer) and is consumed or produced by tile-level instructions. A tile is a storage object with an explicit lifetime and an explicit on-chip placement. + +Tile Instruction models tiles as **tile buffers** of type `!pto.tile_buf<...>`. A tile buffer records: + +- the **memory domain** (`loc`) — where the tile lives on chip; +- the **element type** (`dtype`) — how bits are interpreted; +- the **physical shape** (`rows`, `cols`) — how much storage the tile occupies; +- the **valid region** (`v_row`, `v_col`) — the populated sub-rectangle within the physical tile (may be `?` for runtime-dynamic); +- **layout and fractal** metadata (`blayout`, `slayout`, `fractal`, `pad`) — how elements are arranged in storage. + +This differs from a global tensor: + +- A `!pto.tensor_view` is a logical descriptor over **global memory (GM)** — shape information, no on-chip residency. +- A `!pto.partition_tensor_view` is a logical sub-window of a tensor view, still in GM. +- A `!pto.tile_buf` is the **local, on-chip** materialization of a partition — data placed in UB / L0 / bias / scaling buffers. + +Data flow between these is explicit: + +``` +!pto.tensor_view --partition_view--> !pto.partition_tensor_view --tload--> !pto.tile_buf + (GM) (GM slice) (on-chip tile) +``` + +Placement, lifetime, and reuse affect both correctness and performance. `pto.alloc_tile` makes allocation explicit, and pipeline ordering is expressed through the synchronization primitives described in [`01-pipeline-sync.md`](../micro-isa/01-pipeline-sync.md). + +**Explicit buffer lifetime example:** + +```mlir +%a0 = pto.alloc_tile : !pto.tile_buf +%a1 = pto.alloc_tile : !pto.tile_buf + +pto.tload ins(%pv0 : !pto.partition_tensor_view<16x16xf16>) + outs(%a0 : !pto.tile_buf) +pto.tload ins(%pv1 : !pto.partition_tensor_view<16x16xf16>) + outs(%a1 : !pto.tile_buf) +``` + +## 1.3 Hardware Memory Hierarchy + +The Ascend NPU on-chip memory layout that tile buffers map onto: + +``` +GM (Global Memory) +|- MAT (L1 Cache) +| |- LEFT (L0A — left matrix buffer) +| |- RIGHT (L0B — right matrix buffer) +| |- ACC (L0C — accumulator) +| `- BIAS (bias buffer) +`- VEC (UB — unified buffer) +``` + +`loc` on a tile buffer selects one of these domains. The full enum (with mnemonics) is defined in [§2.6 AddressSpace](02-types-and-attributes.md#26-addressspace); each tile ISA chapter calls out which `loc` domains are legal for the ops it covers. + +## 1.4 Instruction Form + +Most Tile Instruction ops use an explicit source/destination form. The destination tile buffer is named in `outs(...)` and is updated in place: + +```mlir +pto. ins(, , ... : , , ...) + outs( : ) + [ {optional-attrs} ] +``` + +- Inputs appear inside `ins(...)` with their types. +- The output tile buffer appears inside `outs(...)`. +- Scalar operands (where applicable) are listed inside `ins(...)` alongside tile operands. +- Optional attributes follow as a trailing `{ ... }` block. + +Synchronization, sub-view, and allocation ops may diverge from this pattern (for example `pto.alloc_tile` yields a tile-buffer handle, and `pto.subset` returns a view). Each chapter states the assembly format for its ops. + +```mlir +pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +## 1.5 Physical Shape vs Valid Region + +Every tile buffer has two shape concepts: + +- **Physical shape** `(rows, cols)` — the extent of backing storage; static and known when the tile buffer type is declared. +- **Valid region** `(v_row, v_col)` — the populated sub-rectangle; either static or dynamic (`?`). + +The physical shape drives layout, fractal alignment, and buffer-size accounting. The valid region drives the iteration domain of compute and DMA ops. **Undefined behavior:** elements outside the valid region are padding — their contents must not be read. + +When the valid region is dynamic (`v_row = ?` or `v_col = ?`), it is provided at `pto.alloc_tile` time (or updated later with `pto.set_validshape`). Most Tile Instruction ops use the destination valid region as the iteration domain; a few ops require all operands to share the same valid region. + +## 1.6 Pipeline Association + +Every Tile Instruction op is associated with a hardware pipeline in the Decoupled Access-Execute architecture: + +| Pipeline | Symbol | Typical ops | +|----------|--------|------------| +| DMA inbound | `PIPE_MTE2` | `pto.tload` | +| DMA outbound | `PIPE_MTE3` | `pto.tstore` | +| Vector | `PIPE_V` | `pto.tadd`, `pto.tadds`, `pto.texp`, `pto.tcvt`, and the rest of the vector arithmetic set | +| Scalar | `PIPE_S` | scalar `arith`/`scf` ops interleaved with tile code | + +Cross-pipeline data dependencies are ordered explicitly, either via the **Flag/Event** mechanism (`pto.set_flag`/`pto.wait_flag`) or the **Buffer-ID** mechanism (`pto.get_buf`/`pto.rls_buf`). See [`01-pipeline-sync.md`](../micro-isa/01-pipeline-sync.md) for the full semantics. + +## 1.7 Scratch Operands and A2/A3 Compatibility + +Some Tile Instruction ops carry an extra `%tmp` tile operand whose only purpose is to keep the operand list aligned with the corresponding A2/A3 PTO instruction interface. Examples include `pto.txor` / `pto.txors` ([Chapter 8](08-bitwise-shift-ops.md)) and `pto.tsel` / `pto.tsels` ([Chapter 11](11-selection-ops.md)). + +`%tmp` exists for cross-arch interface compatibility — A5 templates may not materially use it, but it remains in the public op signature so the same Tile IR can be reused across A2/A3 and A5. Treat it as a required operand whose dtype/shape constraints are stated by the individual op page. + +## 1.8 Conventions for Chapters 5–12 + +Unless an op page states otherwise, the chapters that follow assume: + +- tile operands use `loc=vec`; +- tile layouts use `blayout=row_major` and `slayout=none_box`; +- valid bounds satisfy `v_row <= rows` and `v_col <= cols`; +- examples use the compact `!pto.tile_buf` form. Omitted attributes carry their default values: `valid` = physical shape, `blayout=row_major`, `slayout=none_box`, `fractal=512`, `pad=0`. + +The op pages call out any deviation from these conventions explicitly. + +## 1.9 Minimal End-to-End Example + +A minimal tile-level "load, add, store" kernel: + +```mlir +// Build the GM view and partition it +%tv = pto.make_tensor_view %gm_ptr, shape = [%m, %n], strides = [%s0, %s1] + : !pto.tensor_view +%pv = pto.partition_view %tv, offsets = [%c0, %c0], sizes = [%c16, %c16] + : !pto.tensor_view -> !pto.partition_tensor_view<16x16xf16> + +// Allocate on-chip tile buffers +%a = pto.alloc_tile : !pto.tile_buf +%b = pto.alloc_tile : !pto.tile_buf +%c = pto.alloc_tile : !pto.tile_buf + +// DMA-in, compute, DMA-out +pto.tload ins(%pv : !pto.partition_tensor_view<16x16xf16>) outs(%a : !pto.tile_buf) +pto.tload ins(%pv2 : !pto.partition_tensor_view<16x16xf16>) outs(%b : !pto.tile_buf) +pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +pto.tstore ins(%c : !pto.tile_buf) + outs(%pv_out : !pto.partition_tensor_view<16x16xf16>) +``` + +Synchronization is omitted for clarity; for the real ordering contracts (`pto.set_flag`/`pto.wait_flag`, `pto.get_buf`/`pto.rls_buf`, `pto.pipe_barrier`) see [`01-pipeline-sync.md`](../micro-isa/01-pipeline-sync.md). + + + +## 1.10 Mixing PTO Tile Instruction and PTO micro Instruction + +PTO Tile Instruction and PTO micro Instruction can be authored side-by-side in the same kernel. The Tile Instruction surface owns tile placement and GM ↔ on-chip DMA; the micro surface owns vector-register compute inside `pto.vecscope`. The two surfaces meet through `pto.tile_buf_addr`, which converts a tile handle into a UB pointer that vector ops can consume. + +This section presents a softmax kernel that uses both surfaces together, then walks through it. + +### Kernel Structure + +The kernel follows a fixed shape that all mixed Tile + micro programs share: + +1. Build `tensor_view` / `partition_view` descriptors for each GM operand. +2. Use `pto.alloc_tile` to allocate UB tiles with explicit static **size** and **address**. +3. Use `pto.tload` to move data from GM partitions into tiles. +4. Cross the **MTE2 → V** synchronization edge with `pto.set_flag` / `pto.wait_flag`. +5. Open a `pto.vecscope` region. Inside the scope: + - Use `pto.tile_buf_addr` to lower each tile handle into a `!pto.ptr<..., ub>`. + - Use `pto.vlds` / `pto.vsts` and the rest of the micro vector ops to read, compute, and write UB. +6. Cross the **V → MTE3** synchronization edge with `pto.set_flag` / `pto.wait_flag`. +7. Use `pto.tstore` to move tiles back to GM. + +Two boundary rules govern this layout: + +- Tile-domain ops (`pto.tload`, `pto.tstore`, `pto.tadd`, …) **must not appear inside** `pto.vecscope`. +- `pto.tile_buf_addr` is **only legal inside** `pto.vecscope` / `pto.strict_vecscope`. + +The kernel also manually drives address allocation (`alloc_tile addr = ...`) and pipeline synchronization. Lowering with `--enable-insert-sync` is therefore disabled, and `--pto-level=level3` is used so that `alloc_tile` accepts an explicit address operand. + +### Kernel Listing + +The listing below is an online softmax-update kernel reduced to the structurally interesting parts. Repeated descriptors and the deep online-softmax math are abbreviated with `// ...` so that the Tile / micro / sync boundaries stay visible. + +```mlir +module attributes {pto.target_arch = "a5"} { + func.func @online_softmax_update_kernel_2d( + %arg0: !pto.ptr, // oldmax (rows x 1) + %arg1: !pto.ptr, // oldsum (rows x 1) + %arg2: !pto.ptr, // qk (rows x 128) + %arg3: !pto.ptr, // newmax (rows x 1) + %arg4: !pto.ptr, // newsum (rows x 1) + %arg5: !pto.ptr, // expmax (rows x 1) + %arg6: !pto.ptr, // out (rows x 128) + %arg7: i32, %arg8: i32) { // %arg7 = seq_len, %arg8 = total_rows + // -------- (1) GM views and partitions -------- + // Eight rows of the qk and out tensors are processed per block. + %qk_view = pto.make_tensor_view %arg2, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + %qk_part = pto.partition_view %qk_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view + -> !pto.partition_tensor_view + // ... oldmax/oldsum/newmax/newsum/expmax/out views/partitions analogous ... + + // -------- (2) Tile allocation with static size and explicit UB address -------- + %qk_tile = pto.alloc_tile addr = %c256_i64 + valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %out_tile = pto.alloc_tile addr = %c8448_i64 + valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count + : !pto.tile_buf + // ... oldsum/newmax/newsum/expmax tiles analogous (each at its own UB addr) ... + + // -------- (3) GM → tile DMA -------- + pto.tload ins(%qk_part : !pto.partition_tensor_view) + outs(%qk_tile : !pto.tile_buf) + pto.tload ins(%oldmax_part : !pto.partition_tensor_view) + outs(%oldmax_tile: !pto.tile_buf) + // ... oldsum tload analogous ... + + // -------- (4) MTE2 → V synchronization -------- + pto.set_flag ["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + // -------- (5) Vector region: tile_buf_addr + micro compute -------- + pto.vecscope { + // Lower tile handles to UB pointers. + %ub_qk = pto.tile_buf_addr %qk_tile + : !pto.tile_buf + -> !pto.ptr + %ub_out = pto.tile_buf_addr %out_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newmax = pto.tile_buf_addr %newmax_tile + : !pto.tile_buf + -> !pto.ptr + // ... ub_oldmax / ub_oldsum / ub_newsum / ub_expmax analogous ... + + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %_ = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + + scf.for %row = %c0 to %row_count step %c1 { + // Online-softmax max/sum reduction (one row at a time). + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] + {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%base] + : !pto.ptr -> !pto.vreg<64xf32> + // ... running_max / running_sum update via vcmax / vexpdif / vmul / vadd ... + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + // Persist the row-local results back to UB. + pto.vsts %final_max, %ub_newmax[%row], %one_mask + {dist = "1PT_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + // Second pass: write softmax output back into the qk tile's UB region. + scf.for %chunk = %c0 to %c128 step %c64 { + %base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%base] + : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, %chunk_mask, "ODD" + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask + -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask + -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%base], %chunk_mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + + // -------- (6) V → MTE3 synchronization -------- + pto.set_flag ["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + // -------- (7) Tile → GM DMA -------- + pto.tstore ins(%out_tile : !pto.tile_buf) + outs(%out_part : !pto.partition_tensor_view) + pto.tstore ins(%newmax_tile : !pto.tile_buf) + outs(%newmax_part: !pto.partition_tensor_view) + // ... newsum/expmax tstore analogous ... + + pto.barrier #pto.pipe + return + } +} +``` + +### Code Walkthrough + +The seven numbered comments in the listing above mark the seven steps from §Kernel Structure. The notes below highlight what each step contributes to the Tile / micro split. + +**(1) GM views and partitions** — pure metadata. `pto.make_tensor_view` records the GM tensor's shape and strides; `pto.partition_view` carves out the per-block sub-window. Neither op moves data, and both stay outside `pto.vecscope`. The 5-D shape is a quirk of this kernel's layout convention; the boundary rules don't depend on rank. + +**(2) `pto.alloc_tile` with static size and address** — declares the UB tile handles. The result type fixes the static physical shape (e.g. `8x128xf32`); `addr = %c256_i64` pins the tile to a specific UB byte offset; `valid_row = ...` / `valid_col = ...` carry the runtime valid extents (the `?` markers in `valid=?x?`). Because addresses are hand-assigned, this kernel compiles with `--pto-level=level3` and disables `--enable-insert-sync`. + +**(3) `pto.tload`** — copies a GM partition into the UB tile. Runs on `PIPE_MTE2`. Stays in the Tile domain; it cannot appear inside `pto.vecscope`. + +**(4) MTE2 → V flag handshake** — DMA inbound and the vector pipeline run asynchronously. The producer/consumer edge between `tload` and the upcoming `vecscope` must be made explicit with `pto.set_flag` / `pto.wait_flag`. + +**(5) Vector region** — `pto.vecscope` opens a vector-execution region. The first thing inside is a series of `pto.tile_buf_addr` ops, each lowering a tile handle into a `!pto.ptr`. From that point on the body is pure micro: `pto.vlds` reads UB into vregs, vector arithmetic / SFU / mask ops compute on vregs, and `pto.vsts` writes vregs back to UB. Tile ops are forbidden inside this region; `pto.tile_buf_addr` is forbidden outside. + +**(6) V → MTE3 flag handshake** — mirror of step (4), this time gating the vector results visible to the outbound DMA. + +**(7) `pto.tstore`** — writes each UB tile back to its GM partition, completing the round trip. Same Tile-domain rules as `tload`. + +### Where the Tile and Micro Boundaries Sit + +| Op | Where it must live | Why | +|----|-------------------|-----| +| `pto.alloc_tile`, `pto.tload`, `pto.tstore`, `pto.tadd`, … (Tile domain) | **Outside** `pto.vecscope` | Tile ops describe tile residency and tile-granular DMA / compute; they have no meaning inside a vector-register region. | +| `pto.vlds`, `pto.vsts`, `pto.vmax`, `pto.vexpdif`, … (micro domain) | **Inside** `pto.vecscope` | These ops produce/consume `!pto.vreg` and `!pto.mask` values that only exist inside a vector region. | +| `pto.tile_buf_addr` | **Inside** `pto.vecscope` only | This is the single sanctioned bridge from a tile handle to a UB pointer; outside vecscope, tile handles must be consumed by Tile ops, not by address extraction. | +| `pto.set_flag` / `pto.wait_flag` (and other sync primitives) | Either side | Sync ops belong to whichever pipeline edge they coordinate; in this kernel they appear at the MTE2 → V and V → MTE3 boundaries. | + +In short: keep DMA and tile shape management in Tile-land, keep vreg/mask compute in vecscope, and use `pto.tile_buf_addr` exactly at the boundary. diff --git a/docs/isa/tile-op/02-types-and-attributes.md b/docs/isa/tile-op/02-types-and-attributes.md new file mode 100644 index 000000000..a7f849c48 --- /dev/null +++ b/docs/isa/tile-op/02-types-and-attributes.md @@ -0,0 +1,175 @@ +# 2. Types & Attributes + +> **Category:** Type system and attribute vocabulary + +This chapter defines the types and attributes used across the Tile Instruction chapters. + +--- + +## 2.1 Element Types + +Element types describe the primitive scalar values stored in tiles; by themselves they do not form a value. Common element categories: + +- **Integers:** signless — `i1`, `i8`, `i16`, `i32`, `i64`. Signedness is not encoded in the type; it is selected by operation semantics or attributes. +- **Floating-point:** `f16`, `bf16`, `f32`. +- **Index-like:** `index` values appear as scalar operands (offsets, sizes, scalar compares). + +Operation-specific constraints: + +- Elementwise ops typically require operand and result element types to match. +- Reductions, math ops, and division typically restrict to floating-point or a subset of integer types. +- Bitwise ops require integer element types. +- `pto.tcvt` defines explicit element-type changes under an explicit rounding mode. + +Memory layout and address space do not change element-type semantics; they only affect placement and access patterns. + +## 2.2 `!pto.ptr` + +A typed pointer. `memorySpace` is optional and defaults to `gm`. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `elementType` | element type | Element type pointed to. | +| `memorySpace` | `gm` \| `vec` | Pointer address space (`gm` → global memory, `vec` → UB / vector memory). | + +**Syntax:** `!pto.ptr` or `!pto.ptr` + +Pointer conversions are modeled explicitly with `pto.castptr`. Between two `!pto.ptr` types, casts are only legal when both pointers stay in the same PTO memory space. + +## 2.3 `!pto.tensor_view` + +A descriptor for a global-memory tensor. Holds shape information; strides are supplied at `pto.make_tensor_view` construction time. Does not own data. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `shape` | `ArrayRef` | Tensor shape `[d0, d1]` (each dim may be `?`). | +| `elementType` | element type | Element data type. | + +**Syntax:** `!pto.tensor_view<1024x512xf16>` + +## 2.4 `!pto.partition_tensor_view` + +A logical partition (slice) of a `tensor_view`. Holds shape information for a tile-sized region; strides are inherited from the parent `tensor_view`. Does not own data. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `shape` | `ArrayRef` | Partition shape `[d0, d1]`. | +| `elementType` | element type | Element data type. | + +**Syntax:** `!pto.partition_tensor_view<16x16xf16>` + +## 2.5 `!pto.tile_buf` + +`pto.tile_buf` represents a local on-chip tile buffer with explicit placement, shape, valid region, and layout/fractal metadata. The textual form is **compact**: only the leading `` triple is mandatory; everything else is omitted when it equals its default. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `loc` | keyword | — | Local memory domain (`vec` / `mat` / `left` / `right` / `acc` / `bias` / `scaling`). | +| `R` × `C` × `dtype` | shape × element type | — | Physical row/column count and element type. | +| `valid` | `v_row x v_col` (each `int64` or `?`) | `R x C` | Valid region. Omitted when equal to physical shape. | +| `blayout` | `BLayout` | `row_major` | Base layout. | +| `slayout` | `SLayout` | `none_box` | Secondary layout. | +| `fractal` | `int32` | `512` | Fractal size. | +| `pad` | `PadValue` enum int | `0` (`null`) | Padding policy/value selector. | + +**Examples:** + +```mlir +// Default config, valid == physical +!pto.tile_buf + +// Dynamic valid region +!pto.tile_buf + +// Non-default config +!pto.tile_buf +``` + +`?` denotes a dynamic symbol resolved at runtime (via `pto.alloc_tile` operands or `pto.set_validshape`). + +## 2.6 AddressSpace + +Defines the physical storage location of a buffer in the Ascend NPU memory hierarchy. + +| Value | Int | Mnemonic | Hardware Mapping | +|-------|-----|----------|------------------| +| `Zero` | 0 | `zero` | Default (unspecified). | +| `GM` | 1 | `gm` | Global Memory. | +| `MAT` | 2 | `mat` | L1 Cache. | +| `LEFT` | 3 | `left` | L0A (left matrix buffer). | +| `RIGHT` | 4 | `right` | L0B (right matrix buffer). | +| `ACC` | 5 | `acc` | L0C (accumulator). | +| `VEC` | 6 | `vec` | UB (unified buffer). | +| `BIAS` | 7 | `bias` | Bias buffer. | +| `SCALING` | 8 | `scaling` | Scaling buffer. | + +**Attribute syntax:** `loc=` (for example `loc=vec`). + +## 2.7 Tile Buf Config + +Composite attribute for tile-buffer layout/fractal/pad. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `bLayout` | `BLayoutAttr` | Base layout (RowMajor / ColMajor). | +| `sLayout` | `SLayoutAttr` | Secondary layout (NoneBox / RowMajor / ColMajor). | +| `sFractalSize` | `IntegerAttr (i32)` | Secondary fractal size. | +| `pad` | `PadValueAttr` | Pad value policy. | + +**Syntax:** `#pto.tile_buf_config` + +**BLayout:** + +| Value | Int | Mnemonic | +|-------|-----|----------| +| `RowMajor` | 0 | `row_major` | +| `ColMajor` | 1 | `col_major` | + +**SLayout:** + +| Value | Int | Mnemonic | +|-------|-----|----------| +| `NoneBox` | 0 | `none_box` | +| `RowMajor` | 1 | `row_major` | +| `ColMajor` | 2 | `col_major` | + +**PadValue:** + +| Value | Int | Mnemonic | +|-------|-----|----------| +| `Null` | 0 | `null` | +| `Zero` | 1 | `zero` | +| `Max` | 2 | `max` | +| `Min` | 3 | `min` | + +## 2.8 Layout + +Global tensor layout attribute for `tensor_view` and `partition_tensor_view`. Tile buffers additionally use **Tile Buf Config** (§2.7) to describe physical/fractal layout. + +| Value | Int | Mnemonic | Description | +|-------|-----|----------|-------------| +| `ND` | 0 | `nd` | Row-major (Normal-Dimension). | +| `DN` | 1 | `dn` | Column-major (Dimension-Normal). | +| `NZ` | 2 | `nz` | Fractal / blocked layout. | + +**Attribute syntax:** `#pto.layout` + +## 2.9 PadMode (for loads) + +Padding mode for `pto.tload`. + +| Value | Int | Description | +|-------|-----|-------------| +| `PadNull` | 0 | No padding. | +| `PadFirstElem` | 1 | Pad using the first element. | +| `PadValue` | 2 | Pad using a specified value. | + +## 2.10 Shared Scalar and Control-Flow Ops + +Tile programs commonly interleave `pto` instructions with a small set of supporting ops: + +- **`func`** — `func.func`, `func.return`, `func.call`. +- **`arith`** — scalar constants/casts (`arith.constant`, `arith.index_cast`, `arith.bitcast`, `arith.extf`/`truncf`/…), integer/float arithmetic, bitwise/shift, compares/select, extended and min/max ops. +- **`scf`** — `scf.for`, `scf.if`, `scf.yield`; several other structured control-flow forms are lowered through `cf`. + +These supporting ops are included here only insofar as tile programs need function structure, scalar computation, and structured control flow; full coverage of those surfaces is out of scope for this reference. diff --git a/docs/isa/tile-op/03-pointer-and-view.md b/docs/isa/tile-op/03-pointer-and-view.md new file mode 100644 index 000000000..d00ab0a06 --- /dev/null +++ b/docs/isa/tile-op/03-pointer-and-view.md @@ -0,0 +1,349 @@ +# 3. Pointer & View Operations + +> **Category:** Address arithmetic, tensor-view construction, tile-buffer allocation +> **Pipeline:** None (all ops are metadata / view construction; no HW side effect) + +These instructions build the address, view, and tile-buffer metadata that later DMA and compute instructions consume. None of them moves data. + +--- + +## `pto.addptr` + +- **syntax:** +```mlir +%result = pto.addptr %base, %offset : !pto.ptr -> !pto.ptr +``` +- **semantics:** `result = ptr + offset`, with `offset` counted in **elements** (not bytes). + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%base` | `!pto.ptr` | Base pointer. | +| `%offset` | `index` | Element offset. | + +**Constraints:** + +- Result type must match the input pointer type. +- The op is pure (no side effects). + +**Example:** + +```mlir +%ptr_off = pto.addptr %base, %offset : !pto.ptr -> !pto.ptr +``` + +--- + +## `pto.castptr` + +- **syntax:** +```mlir +%p_ptr = pto.castptr %addr : i64 -> !pto.ptr +%p_ptr2 = pto.castptr %p_ptr : !pto.ptr -> !pto.ptr +%addr2 = pto.castptr %p_ptr : !pto.ptr -> i64 +``` +- **semantics:** Explicit cast between integer addresses and `!pto.ptr`, or between two `!pto.ptr` types. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `input` | integer \| `!pto.ptr<...>` | Source value. | + +**Constraints:** + +- Integer-to-integer casts are rejected; use normal integer cast ops. +- Descriptor types (`!pto.tensor_view<...>`, `!pto.partition_tensor_view<...>`) are not legal direct inputs; extract an address first. +- Pointer-to-pointer casts are only legal when source and destination stay in the same PTO memory space (`gm` or `vec`). +- The op is pure. + +**Example:** + +```mlir +%p0 = pto.castptr %addr : i64 -> !pto.ptr +%p1 = pto.castptr %p0 : !pto.ptr -> !pto.ptr +%a2 = pto.castptr %p1 : !pto.ptr -> i64 +``` + +--- + +## `pto.make_tensor_view` + +- **syntax:** +```mlir +%tv = pto.make_tensor_view %ptr, shape = [%m, %n], strides = [%s0, %s1] + : !pto.tensor_view +``` +- **semantics:** Construct a global tensor view from a pointer, declaring the physical base and strides. No allocation, no data movement. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%ptr` | `AnyType` | Source pointer (must be `!pto.ptr` with element type matching the result). | +| `shape` | `Variadic` | Dynamic shape dimensions. | +| `strides` | `Variadic` | Dynamic strides. | +| `layout` | `LayoutAttr` (optional) | `nd` / `dn` / `nz` hint. | + +**Constraints:** + +- `ptr` element type must match the result element type. +- `shape` and `strides` operand counts must match the tensor_view rank. +- If `layout` is provided with static shapes/strides, it must be consistent with the inferred layout. + +**Example:** + +```mlir +%tv = pto.make_tensor_view %ptr, shape = [%m, %n], strides = [%s0, %s1] + : !pto.tensor_view +``` + +--- + +## `pto.get_tensor_view_dim` + +- **syntax:** +```mlir +%dim = pto.get_tensor_view_dim %tv, %idx : !pto.tensor_view<...> -> index +``` +- **semantics:** Return the runtime size of dimension `%idx` from a `tensor_view`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tv` | `!pto.tensor_view<...>` | Logical tensor view. | +| `%idx` | `index` | Dimension index (0-based). | + +**Example:** + +```mlir +%h = pto.get_tensor_view_dim %tv, %c0 : !pto.tensor_view -> index +``` + +--- + +## `pto.get_tensor_view_stride` + +- **syntax:** +```mlir +%stride = pto.get_tensor_view_stride %tv, %idx : !pto.tensor_view<...> -> index +``` +- **semantics:** Return the logical stride of dimension `%idx`, measured in **elements** (not bytes). + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tv` | `!pto.tensor_view<...>` or memref form | Tensor view or its lowered memory-reference form. | +| `%idx` | `index` | Dimension index (0-based). | + +**Example:** + +```mlir +%s0 = pto.get_tensor_view_stride %tv, %c0 : !pto.tensor_view -> index +``` + +--- + +## `pto.tensor_view_addr` + +- **syntax:** +```mlir +%result = pto.tensor_view_addr %src : !pto.tensor_view<...> -> memref<...> +%result = pto.tensor_view_addr %src : !pto.tensor_view<...> -> !pto.ptr +``` +- **semantics:** Extract the underlying address view from a `tensor_view` or `partition_tensor_view`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%src` | `!pto.tensor_view<...>` or `!pto.partition_tensor_view<...>` | Source view descriptor. | + +**Constraints:** + +- The result type must be either the lowered memref view or a GM pointer `!pto.ptr` to the same underlying storage. +- The op is pure and does not move data. + +**Example:** + +```mlir +%base = pto.tensor_view_addr %tv : !pto.tensor_view -> !pto.ptr +``` + +`pto.tensor_view_addr` exposes the underlying address represented by the view descriptor. When the result type is a memref, it exposes the lowered view directly. When the result type is `!pto.ptr<..., gm>`, it exposes the same address in pointer form. During compiler-internal lowering, the operand may already be rewritten to a memref form; in that case this op is folded away or rewritten to an equivalent memref-to-ptr cast. + +--- + +## `pto.partition_view` + +- **syntax:** +```mlir +%pv = pto.partition_view %tv, offsets = [%o0, %o1], sizes = [%s0, %s1] + : !pto.tensor_view<...> -> !pto.partition_tensor_view<...> +``` +- **semantics:** `result = source[offsets, sizes]` — a logical window on a `tensor_view`. Captures both static and dynamic shapes; does not move data. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tv` | `TensorViewType` | Input tensor view. | +| `offsets` | `Variadic` | Dynamic offsets. | +| `sizes` | `Variadic` | Dynamic sizes. | + +**Constraints:** + +- `offsets`/`sizes` counts must match the rank of `source`. + +**Example:** + +```mlir +%pv = pto.partition_view %tv, offsets = [%off0, %off1], sizes = [%s0, %s1] + : !pto.tensor_view<1024x512xf16> -> !pto.partition_tensor_view<16x16xf16> +``` + +--- + +## `pto.alloc_tile` + +- **syntax:** +```mlir +%tb = pto.alloc_tile : !pto.tile_buf<...> +%tb2 = pto.alloc_tile valid_row = %vr valid_col = %vc : !pto.tile_buf +%tb3 = pto.alloc_tile addr = %ad : !pto.tile_buf<...> +``` +- **semantics:** Declare the lifetime of a tile buffer. Each call produces an **independent** tile-buffer instance. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `addr` | `Optional` | Optional start address. If omitted, assigned by the implementation. | +| `valid_row` | `Optional` | Dynamic valid-row count (required when result `v_row = ?`). | +| `valid_col` | `Optional` | Dynamic valid-col count (required when result `v_col = ?`). | + +**Constraints:** + +- If result `v_row`/`v_col` are dynamic (`?`), the corresponding operands must be present. +- If result `v_row`/`v_col` are static, the corresponding operands must be absent. + +**Example:** + +```mlir +%tb = pto.alloc_tile : !pto.tile_buf +``` + +--- + +## `pto.subset` + +- **syntax:** +```mlir +%sub = pto.subset %src[%i, %j] sizes [rows, cols] : !pto.tile_buf<...> +``` +- **semantics:** `result = source[offsets]` with static `sizes`. Creates a strided view of a parent tile. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%src` | `pto.tile_buf` | Parent tile buffer. | +| `offsets` | `Variadic` | Runtime offsets `[i, j]`. | +| `sizes` | `I64ArrayAttr` | Static shape `[rows, cols]`. | + +**Constraints:** + +- Boxed-vs-non-boxed behavior is derived from the source's tile config (`blayout`, `slayout`, `fractal`) and element type. +- For non-boxed layouts (`slayout=none_box`), no additional subset-specific structural checks are enforced. +- For boxed layouts: + - `sizes` must have length 2 and both subset sizes must be positive. + - Subset sizes must be multiples of the inferred inner boxed shape. + - `offsets` must have length 2; constant offsets must be non-negative and multiples of the inferred inner boxed shape. + - Source tile shape must be statically known. + - For boxed row-major tiles: subset must keep the full source column extent, and the column offset must be the constant `0`. + - For boxed col-major tiles: subset must keep the full source row extent, and the row offset must be the constant `0`. +- The inferred result reuses the source's element type, address space, and tile config. `valid_shape` is derived from the parent valid shape and constant offsets, or dynamic when offsets are dynamic. + +**Example:** + +```mlir +%sub = pto.subset %src[%i, %j] sizes [32, 32] + : !pto.tile_buf +``` + +--- + +## `pto.set_validshape` + +- **syntax:** +```mlir +pto.set_validshape %src, %valid_row, %valid_col : !pto.tile_buf +``` +- **semantics:** Update the runtime `v_row`/`v_col` metadata on an existing **dynamic** tile buffer. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%src` | `pto.tile_buf` | Dynamic rank-2 tile buffer. | +| `%valid_row` | `index` | Runtime valid row count. | +| `%valid_col` | `index` | Runtime valid column count. | + +**Constraints:** + +- `%src` must be rank-2 and use `v_row = ?` and `v_col = ?` on both dimensions. +- Tile programs use `pto.tile_buf`; memref forms are a lowering artifact and are not part of this surface. +- Constant `valid_row`/`valid_col` must be non-negative and `<=` the tile's static shape bounds. + +**Example:** + +```mlir +%src = pto.alloc_tile : !pto.tile_buf +pto.set_validshape %src, %vr, %vc : !pto.tile_buf +``` + +--- + +## `pto.tile_buf_addr` + +- **syntax:** +```mlir +%ub_ptr = pto.tile_buf_addr %tile : !pto.tile_buf<...> -> !pto.ptr +%ub_ref = pto.tile_buf_addr %tile : !pto.tile_buf<...> -> memref<...> +``` +- **semantics:** Extract the address of a `pto.tile_buf`'s data region. Returns either a typed PTO pointer (`!pto.ptr`) or a memref view, depending on the requested result type. Pure op: no data movement, no pipeline activity. + +This op is the **boundary between tile-buffer instructions and pointer-based vector instructions**. Inside a `pto.vecscope` body, use `pto.tile_buf_addr` to materialize a vec-space pointer from a tile handle allocated outside the scope; vector load/store ops such as `pto.vlds` and `pto.vsts` then consume that pointer. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tile` | `pto.tile_buf` or tile-bound memref | Tile handle whose data-region address is taken. | + +**Results:** `!pto.ptr` or `memref<...>`. Memref results use the tile's static shape and address space; pointer results use the tile's element type and memory space (e.g. `vec`). + +**Constraints:** + +- Result must be either a typed PTO pointer or a memref view; no other result types are accepted. +- When a memref result is requested, the lowered form uses the tile's static shape and address space. +- `pto.tile_buf_addr` is **only legal inside `pto.vecscope` / `pto.strict_vecscope`**. Outside a vector scope, tile handles must be consumed by tile-level ops (`pto.tload`, `pto.tstore`, `pto.tadd`, …) rather than by address extraction. Conversely, tile-level ops must **not** appear inside `pto.vecscope`. + +**Example (inside `pto.vecscope`):** + +```mlir +%tile = pto.alloc_tile addr = %c0_i64 valid_row = %r + : !pto.tile_buf + +pto.vecscope { + %ub = pto.tile_buf_addr %tile + : !pto.tile_buf -> !pto.ptr + // ... vector-scope loads/stores on %ub ... +} +``` + +See [`03-vector-load-store.md`](../micro-isa/03-vector-load-store.md) for the pointer-based +vector load/store side of this handoff. diff --git a/docs/isa/tile-op/04-dma-data-movement.md b/docs/isa/tile-op/04-dma-data-movement.md new file mode 100644 index 000000000..6646b7fdf --- /dev/null +++ b/docs/isa/tile-op/04-dma-data-movement.md @@ -0,0 +1,80 @@ +# 4. DMA Data Movement + +> **Category:** GM↔on-chip DMA for tile buffers +> **Pipelines:** PIPE_MTE2 (GM→UB), PIPE_MTE3 (UB→GM), PIPE_FIX (when source is `loc=acc`) + +This chapter documents the public tile DMA instructions `pto.tload` and `pto.tstore`. Other raw scalar load/store helpers are outside the current tile-instruction subset and are not covered here. + +--- + +## `pto.tload` + +- **syntax:** +```mlir +pto.tload ins(%src : !pto.partition_tensor_view<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** Physical DMA transfer from a global partition view into a local tile buffer. For each element `(i, j)` in the destination valid region: `dst[i, j] = src[i, j]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PartitionTensorViewType` | Source partition view. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- Tile element type ∈ `{i8, i16, i32, i64, f16, bf16, f32}`. +- Destination tile must use `loc=vec`. +- Destination tile element type and source partition element type must have the same bitwidth. +- Runtime: source partition extents and destination valid region must be positive. + +**Pipeline:** `PIPE_MTE2`. + +**Example:** + +```mlir +pto.tload ins(%pv : !pto.partition_tensor_view<16x16xf16>) + outs(%tb : !pto.tile_buf) +``` + +--- + +## `pto.tstore` + +- **syntax:** +```mlir +pto.tstore ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.partition_tensor_view<...>) +``` +- **semantics:** Store a 2-D tile buffer back to a 2-D partition view. For each element `(i, j)` in the source valid region: `dst[i, j] = src[i, j]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer. | +| `dst` | `PartitionTensorViewType` | Destination partition view. | + +**Constraints:** + +- `src` must be `!pto.tile_buf`, `dst` must be `!pto.partition_tensor_view`. +- Static dst shape dims and static src valid-shape dims must be positive. +- `src.loc ∈ {vec, mat, acc}`. +- For `loc=vec` / `loc=mat`: src element type ∈ `{i8, i16, i32, i64, f16, bf16, f32}`; src/dst element bitwidth must match. +- For `loc=acc`: + - src element type must be `i32` or `f32`. + - dst element type ∈ `{i32, f32, f16, bf16}`. + +**Pipeline:** + +- `src.loc=acc` uses **PIPE_FIX**. +- `src.loc=vec` / `src.loc=mat` uses **PIPE_MTE3**. + +**Example:** + +```mlir +pto.tstore ins(%tb : !pto.tile_buf) + outs(%pv : !pto.partition_tensor_view<16x16xf16>) +``` diff --git a/docs/isa/tile-op/05-vector-arithmetic.md b/docs/isa/tile-op/05-vector-arithmetic.md new file mode 100644 index 000000000..01624f580 --- /dev/null +++ b/docs/isa/tile-op/05-vector-arithmetic.md @@ -0,0 +1,207 @@ +# 5. Vector Arithmetic and Activation Operations + +> **Category:** Base tile-local VEC arithmetic +> **Pipeline:** PIPE_V + +This chapter documents the TileLib arithmetic families that keep the same output tile shape as their source tiles. These instructions operate on `!pto.tile_buf` values in `loc=vec` and cover tile-tile arithmetic, tile-scalar arithmetic, unary math, and activation ops. + +Reduction, partial, bitwise, conversion, broadcast / expansion, selection, and fill / padding families are documented in Chapters 6-12. + +--- + +## 5.1 Binary Tile-Tile Arithmetic + +Tile-tile arithmetic families: + +| Op | Semantics | +|----|-----------| +| `pto.tadd` | `dst[i, j] = src0[i, j] + src1[i, j]` | +| `pto.tsub` | `dst[i, j] = src0[i, j] - src1[i, j]` | +| `pto.tmul` | `dst[i, j] = src0[i, j] * src1[i, j]` | +| `pto.tdiv` | `dst[i, j] = src0[i, j] / src1[i, j]` | +| `pto.tmax` | `dst[i, j] = max(src0[i, j], src1[i, j])` | +| `pto.tmin` | `dst[i, j] = min(src0[i, j], src1[i, j])` | + +### Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.tile_buf` | First source tile buffer. | +| `src1` | `pto.tile_buf` | Second source tile buffer. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0`, `src1`, and `dst` must be shape-compatible tile buffers on `loc=vec`. +- The valid region must match across all three tiles. +- Element type legality is target-defined; ops specialize over the tile dtype selected at expansion time. +- `pto.tdiv` uses element-wise division; **undefined behavior** on divide-by-zero. +- `pto.tdiv` additionally accepts `precision_mode = #pto`. + Omitted means `DEFAULT`. + `HIGH_PRECISION` is currently legal only when the tile element type is `f16` or `f32`. + +**Example:** + +```mlir +pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +## 5.2 Tile-Scalar Arithmetic + +Tile-scalar families: + +| Op | Supported operand form(s) | Semantics | +|----|---------------------------|-----------| +| `pto.tadds` | `tile, scalar` | `dst[i, j] = src[i, j] + scalar` | +| `pto.tsubs` | `tile, scalar` | `dst[i, j] = src[i, j] - scalar` | +| `pto.tmuls` | `tile, scalar` | `dst[i, j] = src[i, j] * scalar` | +| `pto.tdivs` | `tile, scalar` and `scalar, tile` | `dst = src / scalar` or `dst = scalar / src` | +| `pto.tmaxs` | `tile, scalar` | `dst[i, j] = max(src[i, j], scalar)` | +| `pto.tmins` | `tile, scalar` | `dst[i, j] = min(src[i, j], scalar)` | + +### Common Syntax + +For `pto.tadds`, `pto.tsubs`, `pto.tmuls`, `pto.tmaxs`, and `pto.tmins`: + +```mlir +pto. ins(%src, %scalar : !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) +``` + +For `pto.tdivs`: + +```mlir +pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) + {precision_mode = #pto} + +pto.tdivs ins(%scalar, %src : , !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) + {precision_mode = #pto} +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer. | +| `scalar` | signless integer / floating-point scalar | Scalar broadcast across the tile. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must be shape-compatible `loc=vec` tile buffers. +- The scalar element type must be compatible with the tile element type. +- `pto.tdivs` is the only scalar family with two public operand orders. **Undefined behavior** on divide-by-zero (either `scalar==0` or any `src[i,j]==0` in the `scalar/src` form). +- `pto.tdivs` additionally accepts `precision_mode = #pto`. + Omitted means `DEFAULT`. + `HIGH_PRECISION` is currently legal only when the tile element type is `f16` or `f32`. + +**Example:** + +```mlir +pto.tadds ins(%a, %s : !pto.tile_buf, f32) + outs(%c : !pto.tile_buf) +``` + +```mlir +pto.tdivs ins(%s, %a : f32, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +## 5.3 Unary Math + +All ops below share the common form: + +```mlir +pto. ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics | +|----|-----------| +| `pto.tabs` | `dst = abs(src)` | +| `pto.tneg` | `dst = -src` | +| `pto.texp` | `dst = exp(src)` | +| `pto.tlog` | `dst = ln(src)` | +| `pto.tsqrt` | `dst = sqrt(src)` | +| `pto.trsqrt` | `dst = 1 / sqrt(src)` | +| `pto.trecip` | `dst = 1 / src` | + +**Constraints:** + +- `src` and `dst` must have the same valid region. +- These ops are numeric Tile Instruction ops on `loc=vec`. +- **Undefined behavior** on out-of-domain inputs: `tlog(<=0)`, `tsqrt(<0)`, `trsqrt(<=0)`, `trecip(0)`. +- `pto.texp`, `pto.tlog`, `pto.tsqrt`, `pto.trsqrt`, and `pto.trecip` additionally accept + `precision_mode = #pto`. + Omitted means `DEFAULT`. + For these five unary ops, `HIGH_PRECISION` is currently legal on their supported floating-point element types. + +**Precision-Mode Form:** + +```mlir +pto. ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) + {precision_mode = #pto} +``` + +**Example:** + +```mlir +pto.tabs ins(%a : !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +## 5.4 Activation Operations + +Activation family: + +| Op | Semantics | +|----|-----------| +| `pto.trelu` | `dst[i, j] = max(0, src[i, j])` | +| `pto.tlrelu` | `dst[i, j] = src[i, j] > 0 ? src[i, j] : slope * src[i, j]` | + +### Common Forms + +ReLU: + +```mlir +pto.trelu ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +Leaky ReLU: + +```mlir +pto.tlrelu ins(%src, %slope : !pto.tile_buf<...>, f32) + outs(%dst : !pto.tile_buf<...>) +``` + +**Constraints:** + +- `src` and `dst` must have the same valid region. +- `pto.trelu` supports `f16`, `f32`, and `i32`. +- `pto.tlrelu` supports `f16` and `f32`, with the slope passed as an `f32` scalar operand. +- Both ops execute on `loc=vec` tiles via the vector pipeline. + +**Example:** + +```mlir +pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` diff --git a/docs/isa/tile-op/06-reduction-ops.md b/docs/isa/tile-op/06-reduction-ops.md new file mode 100644 index 000000000..770cd3490 --- /dev/null +++ b/docs/isa/tile-op/06-reduction-ops.md @@ -0,0 +1,83 @@ +# 6. Reduction Operations + +> **Category:** Tile-local VEC reductions +> **Pipeline:** PIPE_V + +This chapter documents the TileLib reduction families. These ops reduce one or more source dimensions into smaller destination tiles and are organized into row-reduction and column-reduction groups. + +--- + +## 6.1 Row Reductions + +Row reductions reduce each row of `%src` into one element stored at `%dst[row, 0]`. The op shape carries a scratch tile operand `%tmp` to keep the operand list aligned with the A2/A3 PTO instruction interface (see [§1.7](01-tile-overview.md#17-scratch-operands-and-a2a3-compatibility)). + +### Common Syntax + +```mlir +pto. ins(%src, %tmp : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics | +|----|-----------| +| `pto.trowsum` | `dst[i, 0] = sum_j src[i, j]` | +| `pto.trowprod` | `dst[i, 0] = prod_j src[i, j]` | +| `pto.trowmax` | `dst[i, 0] = max_j src[i, j]` | +| `pto.trowmin` | `dst[i, 0] = min_j src[i, j]` | +| `pto.trowargmax` | `dst[i, 0] = argmax_j src[i, j]` | +| `pto.trowargmin` | `dst[i, 0] = argmin_j src[i, j]` | + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer. | +| `tmp` | `pto.tile_buf` | Scratch tile for A2/A3 interface compatibility. | +| `dst` | `pto.tile_buf` | Destination tile storing one result per source row. | + +**Constraints:** + +- `dst.v_row` should match `src.v_row`. +- `dst.v_col` should be `1`. +- `pto.trowargmax` and `pto.trowargmin` require an integer destination element type for the row-local index result. +- Numeric widening / narrowing inside the reduction is target-defined by the selected template (e.g. `trowsum` may widen `i16` accumulation internally before storing to `dst`). + +**Example:** + +```mlir +pto.trowsum ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +## 6.2 Column Reductions + +Column reductions reduce each column of `%src` into one element stored at `%dst[0, col]`. + +### Common Syntax + +```mlir +pto. ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics | +|----|-----------| +| `pto.tcolsum` | `dst[0, j] = sum_i src[i, j]` | +| `pto.tcolprod` | `dst[0, j] = prod_i src[i, j]` | +| `pto.tcolmax` | `dst[0, j] = max_i src[i, j]` | +| `pto.tcolmin` | `dst[0, j] = min_i src[i, j]` | + +**Constraints:** + +- `dst.v_row` should be `1`. +- `dst.v_col` should match `src.v_col`. +- Templates assume prefix-aligned valid regions and row-major VEC tiles. + +**Example:** + +```mlir +pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` diff --git a/docs/isa/tile-op/07-partial-elementwise.md b/docs/isa/tile-op/07-partial-elementwise.md new file mode 100644 index 000000000..22a5d4e9d --- /dev/null +++ b/docs/isa/tile-op/07-partial-elementwise.md @@ -0,0 +1,37 @@ +# 7. Partial Elementwise Operations + +> **Category:** Tile-local VEC partial-shape compute +> **Pipeline:** PIPE_V + +This chapter documents the TileLib partial elementwise families. These ops combine two tiles whose valid regions may differ, but whose overlap starts at `[0, 0]`. + +--- + +## Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics on the overlap region | +|----|----------------------------------| +| `pto.tpartadd` | `dst = src0 + src1` | +| `pto.tpartmul` | `dst = src0 * src1` | +| `pto.tpartmax` | `dst = max(src0, src1)` | +| `pto.tpartmin` | `dst = min(src0, src1)` | + +**Constraints:** + +- Let `big` ∈ {`src0`, `src1`} be the operand whose valid shape equals `dst.valid_shape`, and `small` be the other operand. Exactly one operand plays each role. +- `small.valid_shape` must be a prefix-aligned sub-rectangle of `dst.valid_shape` (i.e. starting at `[0, 0]`). +- For `pto.tpartadd` and `pto.tpartmul`: outside the overlap (where only `big` covers `dst`), `dst` takes `big`'s value. +- For `pto.tpartmax` and `pto.tpartmin`: A5 templates initialize `dst` with the dtype extremum before merging the operands, so uncovered regions follow the template's pad-extremum behavior. + +**Example:** + +```mlir +pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` diff --git a/docs/isa/tile-op/08-bitwise-shift-ops.md b/docs/isa/tile-op/08-bitwise-shift-ops.md new file mode 100644 index 000000000..4d32359e1 --- /dev/null +++ b/docs/isa/tile-op/08-bitwise-shift-ops.md @@ -0,0 +1,116 @@ +# 8. Bitwise and Shift Operations + +> **Category:** Tile-local integer VEC compute +> **Pipeline:** PIPE_V + +This chapter documents the integer-only TileLib bitwise and shift families. + +--- + +## 8.1 Unary Bitwise NOT: `pto.tnot` + +- **syntax:** +```mlir +pto.tnot ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst = ~src`. + +**Constraints:** + +- Tile element types must be integer types. +- `src` and `dst` must have the same valid region. + +**Example:** + +```mlir +pto.tnot ins(%a : !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +## 8.2 Binary Tile-Tile Bitwise and Shift Families + +Tile-tile bitwise and shift families: + +| Op | Semantics | +|----|-----------| +| `pto.tand` | `dst = src0 & src1` | +| `pto.tor` | `dst = src0 \| src1` | +| `pto.txor` | `dst = src0 ^ src1` | +| `pto.tshl` | `dst = src0 << src1` | +| `pto.tshr` | `dst = src0 >> src1` | + +### Common Forms + +For `pto.tand`, `pto.tor`, `pto.tshl`, and `pto.tshr`: + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +`pto.txor` carries an extra scratch tile `%tmp` for A2/A3 interface compatibility (see [§1.7](01-tile-overview.md#17-scratch-operands-and-a2a3-compatibility)): + +```mlir +pto.txor ins(%src0, %src1, %tmp : !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Constraints:** + +- Tile element types must be integer types. +- `src0`, `src1`, and `dst` must have the same valid region. + +**Example:** + +```mlir +pto.tand ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +## 8.3 Tile-Scalar Bitwise and Shift Families + +Tile-scalar bitwise and shift families: + +| Op | Semantics | +|----|-----------| +| `pto.tands` | `dst = src & scalar` | +| `pto.tors` | `dst = src \| scalar` | +| `pto.txors` | `dst = src ^ scalar` | +| `pto.tshls` | `dst = src << scalar` | +| `pto.tshrs` | `dst = src >> scalar` | + +### Common Forms + +For `pto.tands`, `pto.tors`, `pto.tshls`, and `pto.tshrs`: + +```mlir +pto. ins(%src, %scalar : !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) +``` + +`pto.txors` carries an extra scratch tile `%tmp` for A2/A3 interface compatibility: + +```mlir +pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf<...>, , + !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Constraints:** + +- Tile element types must be integer types. +- `src` and `dst` must have the same valid region. +- The scalar operand must be an integer-compatible shift / bitwise scalar. + +**Example:** + +```mlir +pto.tands ins(%a, %s : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) +``` diff --git a/docs/isa/tile-op/09-type-conversion.md b/docs/isa/tile-op/09-type-conversion.md new file mode 100644 index 000000000..82e402ac8 --- /dev/null +++ b/docs/isa/tile-op/09-type-conversion.md @@ -0,0 +1,58 @@ +# 9. Type Conversion + +> **Category:** Element-wise type conversion +> **Pipeline:** PIPE_V + +This chapter documents the element-wise tile conversion instruction `pto.tcvt` and the rounding modes it uses. + +--- + +## `RoundMode` + +Rounding modes for `pto.tcvt`. + +| Value | Int | Description | +|-------|-----|-------------| +| `NONE` | 0 | No rounding. | +| `RINT` | 1 | Round to nearest integer. | +| `ROUND` | 2 | Round `f16` away from zero. | +| `FLOOR` | 3 | Round toward negative infinity. | +| `CEIL` | 4 | Round toward positive infinity. | +| `TRUNC` | 5 | Truncate toward zero. | +| `ODD` | 6 | Round to odd. | +| `CAST_RINT` | 7 | Cast with round-to-nearest (default). | + +**Attribute syntax:** `#pto` + +--- + +## `pto.tcvt` + +- **syntax:** +```mlir +pto.tcvt ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) + {rmode = #pto} +``` +- **semantics:** `dst[i, j] = cast(src[i, j], rmode)` element-wise. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile. | +| `dst` | `pto.tile_buf` | Destination tile (different element type). | +| `rmode` | `RoundModeAttr` | Default `CAST_RINT`. | + +**Constraints:** + +- `src`/`dst` must be shape/valid-region compatible. +- This reference does not define extra legality rules for the `(src, dst)` type pair. **Undefined behavior** on conversion pairs not supported by the target hardware; consult the A2/A3 and A5 hardware specs for legal pairs. + +**Example:** + +```mlir +pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {rmode = #pto} +``` diff --git a/docs/isa/tile-op/10-broadcast-and-expansion-ops.md b/docs/isa/tile-op/10-broadcast-and-expansion-ops.md new file mode 100644 index 000000000..3e2133c93 --- /dev/null +++ b/docs/isa/tile-op/10-broadcast-and-expansion-ops.md @@ -0,0 +1,210 @@ +# 10. Broadcast and Expansion Operations + +> **Category:** Tile-local VEC broadcast and expansion compute +> **Pipeline:** PIPE_V + +This chapter documents the TileLib broadcast, row-expansion, and column-expansion families. These ops populate destination tiles by broadcasting one logical scalar across a larger region — either from a standalone scalar operand, one source value per destination row, or one source value per destination column. + +--- + +## 10.1 Scalar Broadcast: `pto.texpands` + +- **syntax:** +```mlir +pto.texpands ins(%scalar : ) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[i, j] = scalar` for every element inside `dst`'s valid region. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | signless integer / floating-point scalar | Scalar value broadcast into the destination tile. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- The TileLib template is VEC-oriented and fills `dst.valid_shape`. +- The scalar type must be compatible with `dst.dtype`. + +**Example:** + +```mlir +pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) +``` + +--- + +## 10.2 Row-Wise Broadcast: `pto.trowexpand` + +- **syntax:** +```mlir +pto.trowexpand ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[row, col] = src[row, 0]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile carrying one logical scalar per destination row. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must have the same number of valid rows. +- `src` must encode exactly one logical source value per destination row. +- Templates target row-major VEC layouts. + +**Example:** + +```mlir +pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +## 10.3 Row-Wise Broadcast Arithmetic and Transform Families + +The row-expansion family combines a full tile `%src0` with a per-row scalar carrier `%src1`: + +| Op | Semantics | +|----|-----------| +| `pto.trowexpandadd` | `dst[row, col] = src0[row, col] + src1[row, 0]` | +| `pto.trowexpandsub` | `dst[row, col] = src0[row, col] - src1[row, 0]` | +| `pto.trowexpandmul` | `dst[row, col] = src0[row, col] * src1[row, 0]` | +| `pto.trowexpanddiv` | `dst[row, col] = src0[row, col] / src1[row, 0]` | +| `pto.trowexpandmax` | `dst[row, col] = max(src0[row, col], src1[row, 0])` | +| `pto.trowexpandmin` | `dst[row, col] = min(src0[row, col], src1[row, 0])` | +| `pto.trowexpandexpdif` | `dst[row, col] = exp(src0[row, col] - src1[row, 0])` | + +### Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.tile_buf` | Main source tile. | +| `src1` | `pto.tile_buf` | Tile carrying one logical scalar per destination row. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0` and `dst` must be shape/valid-region compatible. +- `src1` must provide one logical scalar per destination row. +- Templates target row-major VEC layouts. +- `pto.trowexpanddiv` and `pto.trowexpandexpdif` are floating-point-only. +- `pto.trowexpanddiv` additionally accepts `precision_mode = #pto`. + Omitted means `DEFAULT`. + `HIGH_PRECISION` is currently legal only when the tile element type is `f16` or `f32`. + +**Example:** + +```mlir +pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +```mlir +pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precision_mode = #pto} +``` + +--- + +## 10.4 Column-Wise Broadcast: `pto.tcolexpand` + +- **syntax:** +```mlir +pto.tcolexpand ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[row, col] = src[0, col]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile carrying one logical scalar per destination column. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must have the same number of valid columns. +- `src` must encode exactly one logical source value per destination column. +- Templates target row-major VEC layouts. + +**Example:** + +```mlir +pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +## 10.5 Column-Wise Broadcast Arithmetic and Transform Families + +The column-expansion family combines a full tile `%src0` with a per-column scalar carrier `%src1`: + +| Op | Semantics | +|----|-----------| +| `pto.tcolexpandadd` | `dst[row, col] = src0[row, col] + src1[0, col]` | +| `pto.tcolexpandsub` | `dst[row, col] = src0[row, col] - src1[0, col]` | +| `pto.tcolexpandmul` | `dst[row, col] = src0[row, col] * src1[0, col]` | +| `pto.tcolexpanddiv` | `dst[row, col] = src0[row, col] / src1[0, col]` | +| `pto.tcolexpandmax` | `dst[row, col] = max(src0[row, col], src1[0, col])` | +| `pto.tcolexpandmin` | `dst[row, col] = min(src0[row, col], src1[0, col])` | +| `pto.tcolexpandexpdif` | `dst[row, col] = exp(src0[row, col] - src1[0, col])` | + +### Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.tile_buf` | Main source tile. | +| `src1` | `pto.tile_buf` | Tile carrying one logical scalar per destination column. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0` and `dst` must be shape/valid-region compatible. +- `src1` must provide one logical scalar per destination column. +- Templates target row-major VEC layouts. +- `pto.tcolexpanddiv` and `pto.tcolexpandexpdif` are floating-point-only. +- `pto.tcolexpanddiv` additionally accepts `precision_mode = #pto`. + Omitted means `DEFAULT`. + `HIGH_PRECISION` is currently legal only when the tile element type is `f16` or `f32`. + +**Example:** + +```mlir +pto.tcolexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +```mlir +pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {precision_mode = #pto} +``` diff --git a/docs/isa/tile-op/11-selection-ops.md b/docs/isa/tile-op/11-selection-ops.md new file mode 100644 index 000000000..231c8169f --- /dev/null +++ b/docs/isa/tile-op/11-selection-ops.md @@ -0,0 +1,84 @@ +# 11. Selection Operations + +> **Category:** Tile-local VEC selection compute +> **Pipeline:** PIPE_V + +This chapter documents the TileLib selection families. These ops select between data sources under control of a packed predicate-mask tile. + +The mask tile carries packed predicate bytes in UB. Templates load predicate bits directly with predicate-load helpers such as `plds`, then use `vsel` to choose the data path. + +`pto.tsel` and `pto.tsels` carry an extra `%tmp` operand for A2/A3 interface compatibility (see [§1.7](01-tile-overview.md#17-scratch-operands-and-a2a3-compatibility)). + +--- + +## 11.1 `pto.tsel` + +- **syntax:** +```mlir +pto.tsel ins(%mask, %src0, %src1, %tmp : + !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[i, j] = mask[i, j] ? src0[i, j] : src1[i, j]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.tile_buf` | Packed predicate-mask carrier. | +| `src0` | `pto.tile_buf` | Value selected when the predicate bit is true. | +| `src1` | `pto.tile_buf` | Value selected when the predicate bit is false. | +| `tmp` | `pto.tile_buf` | Scratch tile for A2/A3 interface compatibility. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0`, `src1`, and `dst` must have the same shape and valid region. +- The `tsel` template specializes the mask carrier as an `i8` tile with packed predicate bytes. + +**Example:** + +```mlir +pto.tsel ins(%mask, %a, %b, %tmp : + !pto.tile_buf, !pto.tile_buf, + !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +## 11.2 `pto.tsels` + +- **syntax:** +```mlir +pto.tsels ins(%mask, %src, %tmp, %scalar : + !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[i, j] = mask[i, j] ? src[i, j] : scalar`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.tile_buf` | Packed predicate-mask carrier. | +| `src` | `pto.tile_buf` | Source tile selected when the predicate bit is true. | +| `tmp` | `pto.tile_buf` | Scratch tile for A2/A3 interface compatibility. | +| `scalar` | signless integer / floating-point scalar | Scalar selected when the predicate bit is false. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must have the same shape and valid region. +- `tsels` accepts packed-mask carrier tiles with `i8`, `i16`, or `i32` element types. + +**Example:** + +```mlir +pto.tsels ins(%mask, %src, %tmp, %scalar : + !pto.tile_buf, !pto.tile_buf, + !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) +``` diff --git a/docs/isa/tile-op/12-fill-and-padding-ops.md b/docs/isa/tile-op/12-fill-and-padding-ops.md new file mode 100644 index 000000000..0734b2f55 --- /dev/null +++ b/docs/isa/tile-op/12-fill-and-padding-ops.md @@ -0,0 +1,101 @@ +# 12. Fill and Padding Operations + +> **Category:** Tile-local fill, pad, and expansion materialization +> **Pipeline:** PIPE_V + +This chapter documents the TileLib fill / padding families. These ops preserve or materialize valid data and then synthesize the remaining destination region from the destination tile's padding policy. + +The destination tile's `pad` / `pad_value` configuration determines which value is written into the synthesized padding or expansion region. + +--- + +## 12.1 `pto.tfillpad` + +- **syntax:** +```mlir +pto.tfillpad ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** copy valid data from `src` into `dst`, then fill the remaining destination region according to `dst`'s pad policy. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile. | +| `dst` | `pto.tile_buf` | Destination tile carrying the pad configuration. | + +**Constraints:** + +- Source and destination element types must be compatible. +- The destination tile must carry a meaningful pad configuration. +- This family is VEC-oriented. + +**Example:** + +```mlir +pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +## 12.2 `pto.tfillpad_expand` + +- **syntax:** +```mlir +pto.tfillpad_expand ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** copy valid data from `src` into `dst`, then fill row/column expansion according to `dst`'s pad policy when the destination valid region or backing shape is larger than the source. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile. | +| `dst` | `pto.tile_buf` | Larger destination tile carrying the pad configuration. | + +**Constraints:** + +- `dst` may be larger than `src` in valid region or physical shape. +- The fill value is derived from `dst.pad_value`. +- A unified VEC-oriented template handles the supported element families. + +**Example:** + +```mlir +pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +## 12.3 `pto.tfillpad_inplace` + +- **syntax:** +```mlir +pto.tfillpad_inplace ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** update the padding / expansion region of an already materialized tile without a separate copy-in phase. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer. | +| `dst` | `pto.tile_buf` | Destination tile buffer, typically aliasing the same physical tile. | + +**Constraints:** + +- PTOAS exposes `pto.tfillpad_inplace` as a dedicated Tile op. +- In typical use, `src` and `dst` refer to the same underlying tile buffer. +- The fill value is derived from `dst.pad_value`. + +**Example:** + +```mlir +pto.tfillpad_inplace ins(%tile : !pto.tile_buf) + outs(%tile : !pto.tile_buf) +``` diff --git a/docs/release/PTO-micro-Instruction-SPEC-v0.4.md b/docs/release/PTO-micro-Instruction-SPEC-v0.4.md new file mode 100644 index 000000000..d13c2df40 --- /dev/null +++ b/docs/release/PTO-micro-Instruction-SPEC-v0.4.md @@ -0,0 +1,5373 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.4: Update DMA instruction docs and add PTO Tile Instruction SPEC +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of tile instructions, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.dma_load`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.dma_store`) + +The grouped DMA surface in this specification covers GM↔UB transfer only. +Low-level raw copy families such as UB→UB copy use separate operand contracts +and are outside this grouped DMA interface. + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#micro-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.dma_load %7, %2, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.dma_store %8, %14, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg, !pto.mask -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +For A5 reduction result types: + +- `pto.vcadd` widens `i8 -> i16`, `u8 -> u16`, `i16 -> i32`, and `u16 -> u32`, + with the lane count halved in each widening case. +- `pto.vcadd` keeps the same result type for `f16`, `f32`, `i32`, and `u32`. + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is available in the linked files. + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#micro-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#micro-02-dma-copy) | Public DMA transfer interface between GM↔UB and UB↔UB | 3 | `pto.dma_load`, `pto.dma_store`, `pto.dma_copy` | +| 3 | [Vector Load/Store](#micro-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#micro-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#micro-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#micro-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#micro-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#micro-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#micro-09-conversion-ops) | Type conversion with rounding/saturation control | 4 | `pto.vcvt`, `pto.vtrc`, `pto.vbitcast`, `pto.pbitcast` | +| 10 | [Reduction Ops](#micro-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#micro-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#micro-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#micro-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#micro-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#micro-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `dma_store` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.dma_store %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.dma_store %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Shared-memory (UB address space) memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between memory operations. The barrier type selects which classes of prior instructions must complete before which classes of subsequent instructions may proceed. + +```c +mem_bar(barrier_type); +``` + +**Barrier types** are organized into three families by the scope of prior vs. subsequent instructions: + +| Family | Barrier type | Prior instructions | Subsequent instructions | +|--------|-------------|-------------------|------------------------| +| **VV** (vector→vector) | `VV_ALL` | All vector load/store | All vector load/store | +| | `VST_VLD` | All vector store | All vector load | +| | `VLD_VST` | All vector load | All vector store | +| | `VST_VST` | All vector store | All vector store | +| **VS** (vector→scalar) | `VS_ALL` | All vector load/store | All scalar load/store | +| | `VST_LD` | All vector store | All scalar load | +| | `VLD_ST` | All vector load | All scalar store | +| | `VST_ST` | All vector store | All scalar store | +| **SV** (scalar→vector) | `SV_ALL` | All scalar load/store | All vector load/store | +| | `ST_VLD` | All scalar store | All vector load | +| | `LD_VST` | All scalar load | All vector store | +| | `ST_VST` | All scalar store | All vector store | + +**Example:** Ensure vector stores are visible before subsequent vector loads to the same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.dma_load %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.dma_load %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.dma_load %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.dma_store %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.dma_load %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.dma_store %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.dma_load %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.dma_store %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.dma_load %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.dma_store %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#micro-01-pipeline-sync)). + +This document describes the public grouped DMA interfaces: + +- `pto.dma_load` +- `pto.dma_store` +- `pto.dma_copy` + +This chapter covers the public grouped DMA interfaces. The legacy raw copy +family remains documented separately; in particular, `pto.copy_ubuf_to_ubuf` +shares the same UB→UB copy contract as `pto.dma_copy` but remains a legacy +surface op. + +--- + +#### DMA Transfer Execution + +##### `pto.dma_load` + +- **syntax:** +```mlir +pto.dma_load %gm_src, %ub_dst, %l2_cache_ctl, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + [loop(%loop_count, %loop_src_stride, %loop_dst_stride)]* + [pad(%pad_value[, %left_padding_count, %right_padding_count])] + : !pto.ptr, !pto.ptr, i64, i64, i64, + [loop i64, i64, i64,]* + [pad T[, i64, i64]] +``` +- **semantics:** Grouped GM→UB DMA transfer. `nburst(...)` defines the innermost repeated burst transfer, optional `loop(...)` groups add outer repetition levels, and `pad(...)` controls UB row padding. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%gm_src` | ptr | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%l2_cache_ctl` | 2 bits | L2 cache allocate control | +| `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 40 bits / 21 bits | Required innermost burst group: count, GM source stride, UB destination stride | +| `loop(%loop_count, %loop_src_stride, %loop_dst_stride)` | 21 bits / 40 bits / 21 bits | Optional outer repetition group: count, GM source stride, UB destination stride | +| `pad(%pad_value[, %left_padding_count, %right_padding_count])` | scalar / 8 bits / 8 bits | Optional padding: fill value, optional left padding count, optional right padding count | + +**Constraints:** + +- `nburst(...)` is always required. +- Each `loop(...)` group must be provided as a complete triple when present. +- `nburst(...)` is the innermost group. +- `loop(...)` groups are ordered from inner to outer. +- The first `loop(...)` group wraps `nburst(...)`. +- Each additional `loop(...)` group wraps all earlier groups. +- `pad(...)` may contain only `%pad_value`; omitted left and right padding counts default to 0. +- If either left or right padding count is provided, both counts must be provided. +- `pad(...)` is independent of the optional `loop(...)` groups. +- A DMA load may use `nburst(...) pad(...)` without any `loop(...)` group. + +**Example:** + +```mlir +pto.dma_load %gm_in, %ub_out, %cache, %len_burst + nburst(%rows, %gm_row_stride, %ub_row_stride) + loop(%tiles, %gm_tile_stride, %ub_tile_stride) + pad(%pad) + : !pto.ptr, !pto.ptr, i64, i64, + loop i64, i64, i64, pad f16 +``` + +--- + +##### `pto.dma_store` + +- **syntax:** +```mlir +pto.dma_store %ub_src, %gm_dst, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + [loop(%loop_count, %loop_src_stride, %loop_dst_stride)]* + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + [loop i64, i64, i64,]* +``` +- **semantics:** Grouped UB→GM DMA transfer. `nburst(...)` defines the innermost repeated burst transfer, and optional `loop(...)` groups add outer repetition levels. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | ptr | GM destination pointer (`!pto.ptr`) | +| `%len_burst` | 16 bits | Contiguous bytes transferred per burst row | +| `nburst(%n_burst, %src_stride, %dst_stride)` | 16 bits / 21 bits / 40 bits | Required innermost burst group: count, UB source stride, GM destination stride | +| `loop(%loop_count, %loop_src_stride, %loop_dst_stride)` | 21 bits / 21 bits / 40 bits | Optional outer repetition group: count, UB source stride, GM destination stride | + +**Constraints:** + +- `nburst(...)` is always required. +- Each `loop(...)` group must be provided as a complete triple when present. +- `nburst(...)` is the innermost group. +- `loop(...)` groups are ordered from inner to outer. +- The first `loop(...)` group wraps `nburst(...)`. +- Each additional `loop(...)` group wraps all earlier groups. + +**Example:** + +```mlir +pto.dma_store %ub_in, %gm_out, %len_burst + nburst(%rows, %ub_row_stride, %gm_row_stride) + loop(%tiles, %ub_tile_stride, %gm_tile_stride) + loop(%batches, %ub_batch_stride, %gm_batch_stride) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + loop i64, i64, i64, loop i64, i64, i64 +``` + +--- + +##### `pto.dma_copy` + +- **syntax:** +```mlir +pto.dma_copy %ub_src, %ub_dst, %len_burst + nburst(%n_burst, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` +- **semantics:** Grouped UB→UB raw copy.. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%ub_src` | ptr | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%ub_dst` | ptr | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%len_burst` | 16 bits | Burst length in units of 32 bytes | +| `nburst(%n_burst, %src_gap, %dst_gap)` | 16 bits / 16 bits / 16 bits | Required UB→UB outer burst group: count, source gap, destination gap | + +**Constraints:** + +- UB source and destination addresses must be 32B-aligned. +- `%len_burst`, `%src_gap`, and `%dst_gap` are encoded in units of 32 bytes. + +**Example:** + +```mlir +pto.dma_copy %ub_src, %ub_dst, %len32b + nburst(%rows, %src_gap, %dst_gap) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +#### GM↔UB Burst / Stride / Pad Model + +This section describes the grouped GM↔UB DMA interfaces in this document: +`pto.dma_load` and `pto.dma_store`. + +For these grouped GM↔UB DMA ops, the innermost `nburst(...)` group is +**stride-based**: the source and destination stride operands are the +start-to-start byte distance from one burst row to the next row. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `pad(...)` is present on `pto.dma_load`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val`. This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +--- + +#### UB→UB Burst / Gap Model + +This section describes the grouped UB→UB DMA interface in this document: +`pto.dma_copy`. + +For `pto.dma_copy`, each burst copies `len_burst * 32` bytes. + +The next burst starts at: + +```text +src_next = src_curr + (len_burst + src_gap) * 32 bytes +dst_next = dst_curr + (len_burst + dst_gap) * 32 bytes +``` + +So `src_gap` and `dst_gap` are additional gaps after the copied 32B blocks. +They are not start-to-start strides. + +##### 2D Diagram: GM→UB (`pto.dma_load`) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (`pad(...)` present) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (from `pad(...)`) +``` + +##### 2D Diagram: UB→GM (`pto.dma_store`) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics + +The full DMA transfer is a nested loop. `nburst(...)` is the innermost group. +If one or more `loop(...)` groups are present, they wrap `nburst(...)` in the +same order they appear in the op: the first `loop(...)` is the innermost outer +group, the second `loop(...)` wraps the first one, and so on. + +##### GM→UB Full Loop + +For a form + +```mlir +pto.dma_load %gm_src, %ub_dst, %l2_cache_ctl, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + loop(%c0, %s0, %d0) + loop(%c1, %s1, %d1) + ... + loop(%cN, %sN, %dN) + [pad(%pad_value[, %left_padding_count, %right_padding_count])] +``` + +the transfer is equivalent to: + +```c +for (int lN = 0; lN < cN; ++lN) { + ... + for (int l1 = 0; l1 < c1; ++l1) { + for (int l0 = 0; l0 < c0; ++l0) { + uint8_t *gm_base = gm_src + l0 * s0 + l1 * s1 + ... + lN * sN; + uint8_t *ub_base = ub_dst + l0 * d0 + l1 * d1 + ... + lN * dN; + for (int r = 0; r < n_burst; ++r) { + memcpy(ub_base + r * dst_stride, + gm_base + r * src_stride, + len_burst); + if (pad_enabled) + memset(ub_base + r * dst_stride + len_burst, + pad_val, + dst_stride - len_burst); + } + } + } +} +``` + +If no `loop(...)` group is present, only the innermost `nburst(...)` loop +remains. + +##### UB→GM Full Loop + +For a form + +```mlir +pto.dma_store %ub_src, %gm_dst, %len_burst + nburst(%n_burst, %src_stride, %dst_stride) + loop(%c0, %s0, %d0) + loop(%c1, %s1, %d1) + ... + loop(%cN, %sN, %dN) +``` + +the transfer is equivalent to: + +```c +for (int lN = 0; lN < cN; ++lN) { + ... + for (int l1 = 0; l1 < c1; ++l1) { + for (int l0 = 0; l0 < c0; ++l0) { + uint8_t *ub_base = ub_src + l0 * s0 + l1 * s1 + ... + lN * sN; + uint8_t *gm_base = gm_dst + l0 * d0 + l1 * d1 + ... + lN * dN; + for (int r = 0; r < n_burst; ++r) { + memcpy(gm_base + r * dst_stride, + ub_base + r * src_stride, + len_burst); + } + } + } +} +``` + +If no `loop(...)` group is present, only the innermost `nburst(...)` loop +remains. + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — only nburst(...) is needed +pto.dma_load %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c256_i64 + nburst(%c64_i64, %c1024_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +%pad = arith.constant 0 : i16 +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c200_i64 + nburst(%c64_i64, %c200_i64, %c256_i64) + pad(%pad, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, pad i16, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +pto.dma_store %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +pto.dma_store %ub_ptr, %gm_ptr, %c256_i64 + nburst(%c64_i64, %c256_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using one outer +`loop(...)` group. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes outer loop src_stride = 2048 bytes (8 × 256) + outer loop dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes outer loop count = 4 (iterate over batches) +``` + +```mlir +// One outer loop group over 4 batches +pto.dma_load %gm_ptr, %ub_ptr, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + loop(%c4_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, loop i64, i64, i64 +``` + +Execution trace: + +``` +loop iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B8` / `NORM_B16` / `NORM_B32` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_B8` / `INTLV_B16` / `INTLV_B32`** on **`RV_VSTI`** are **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | `RV_VLDI` | **9** | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` | `RV_VSTI` | **12** | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B8` / `NORM_B16` / `NORM_B32` | `RV_VSTI` | **9** | +| `PK_B16` / `PK_B32` / `PK_B64` / `PK4_B32` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | **9** cycles | +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US_B8` / `US_B16`, `DS_B8` / `DS_B16`, `SPLT4CHN`, `SPLT2CHN_B8` / `SPLT2CHN_B16` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM_B8` / `NORM_B16` / `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16` / `PK_B32` / `PK_B64` / `PK4_B32` | **9** cycles | +| `INTLV_B8` / `INTLV_B16` / `INTLV_B32` (`pto.vstsx2`) | **12** cycles | +| `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | **9** cycles (surface retained; current A5 hardware still reports them unsupported at validation time) | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC_B8` / `BRC_B16` / `BRC_B32` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US_B8` / `US_B16` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8` / `DS_B16` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8` / `UNPK_B16` / `UNPK_B32` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B_B16` / `E2B_B32` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN_B8` / `SPLT2CHN_B16` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` are the element-width-sensitive +deinterleave forms. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` support only the + element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` are all + **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV_B32 family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%mask` selects the active requests. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only masked-on indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` is the predicate operand, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM_B8` / `NORM_B16` / `NORM_B32` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT_B8` / `1PT_B16` / `1PT_B32` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint; the predicate register is ignored. | **9** cycles | +| `PK_B16` | `b16` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 16-bit data. | **9** cycles | +| `PK_B32` | `b32` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 32-bit data. | **9** cycles | +| `PK_B64` | `b64` | Pack the source vector, extract the lower half bits of all elements, and only store the active elements. The predicate is interpreted for 64-bit data. | **9** cycles | +| `PK4_B32` | `b32` | Pack the source vector, extract the lower 8 bits of all elements, and only store the active elements. The predicate is interpreted for 32-bit data. | **9** cycles | +| `MRG4CHN_B8` | `b8` | Merge 4 interleaved 8-bit channels within each 32B block; the predicate is interpreted for 32-bit data and applies after channel merge. | **9** cycles | +| `MRG2CHN_B8` / `MRG2CHN_B16` | `b8`, `b16` | Merge 2 interleaved channels within each 32B block; for `MRG2CHN_B8` the predicate is interpreted for 16-bit data, and for `MRG2CHN_B16` it is interpreted for 32-bit data; in both cases it applies after channel merge. | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` is the predicate operand. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. For all `INTLV_*` distributions, the predicate + register is ignored. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %mask : !pto.vreg, !pto.ptr, !pto.vreg, !pto.mask` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%mask` selects the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** si32, i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Signed or signless 32-bit integer and + floating-point element types are legal on the current A5 surface described + here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +#### Movement + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar base value. +- **inputs:** + `%index` is the scalar base value. Supported scalar types are `i8/i16/i32`, + `f16`, and `f32`. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `order` and + the result element type together determine whether lanes are generated as + `base + lane_id` or `base - lane_id`. Supported result types are + `!pto.vreg<256xsi8>`, `!pto.vreg<128xsi16>`, `!pto.vreg<64xsi32>`, + `!pto.vreg<128xf16>`, and `!pto.vreg<64xf32>`. `%index` must use the + matching scalar type for `f16`/`f32`; for integer results, `%index` must use + the same bit width and may be signless or signed. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. + +- `Part` (`PART_EVEN`, `PART_ODD`) + - Used by ordinary width-changing conversions. + - Typical cases include `32 -> 16`, `16 -> 32`, and other even/odd packing + or unpacking forms. +- `Part_T` (`PART_P0`, `PART_P1`, `PART_P2`, `PART_P3`) + - Used by lower-level packed placement forms. + - Typical cases include `32 -> 8`, packed fp8/fp4 conversion paths, and + other flows where the result is written into one of four sub-parts before a + later merge or compact step. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | +| `P0` | Output to sub-part 0 in 4-way packed placement forms | +| `P1` | Output to sub-part 1 in 4-way packed placement forms | +| `P2` | Output to sub-part 2 in 4-way packed placement forms | +| `P3` | Output to sub-part 3 in 4-way packed placement forms | + +--- + +##### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | Y | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + +#### `pto.vbitcast` + +- **syntax:** `%result = pto.vbitcast %input : !pto.vreg -> !pto.vreg` +- **semantics:** Bitwise reinterpretation of a vreg vector without changing the underlying bit pattern. This operation performs a pure type cast that preserves the exact bits of each element, changing only their interpretation (e.g., from floating-point to integer). + +- **inputs:** + `%input` is the source vector register value. +- **outputs:** + `%result` is the reinterpreted vector register value. +- **constraints and limitations:** + 1. Both source and result must be `!pto.vreg<...>` types. + 2. Source and result vectors must have the same total bit width (currently 2048 bits). + 3. Only integer and floating-point element types are supported. + +**Element bit-width equality examples:** +- `f32<64>` → `i32<64>` (both 32-bit elements, total 2048 bits) +- `f16<128>` → `i16<128>` (both 16-bit elements, total 2048 bits) +- `bf16<128>` → `ui16<128>` (both 16-bit elements, total 2048 bits) +- `si32<64>` → `ui32<64>` (both 32-bit elements, total 2048 bits) +- `f32<64>` → `i16<128>` (32-bit/16-bit elements, total 2048 bits) + +**Verification:** The operation verifies that: +1. Both input and result are `!pto.vreg<...>` types. +2. Total bit width equals 2048 (the fixed vreg size). + +**Comparison with `pto.vcvt`:** +- `pto.vcvt` performs value conversion with rounding, saturation, and lane placement control. +- `pto.vbitcast` performs bitwise reinterpretation without changing the underlying bit pattern. + +**Example: Reinterpreting float as integer for bit manipulation** +```mlir +// Prepare a vector of float values +%fvec = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + +// Reinterpret as integer for bitwise operations +%ivec = pto.vbitcast %fvec : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + +// Extract sign bit (bit 31) +%sign_bits = pto.vand %ivec, %sign_mask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + +// Reinterpret back to float +%fvec_without_sign = pto.vbitcast %sign_bits : !pto.vreg<64xi32> -> !pto.vreg<64xf32> +``` + +**Example: Type punning between signed and unsigned integer** +```mlir +// Convert signed to unsigned without changing bits +%signed = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xsi32> +%unsigned = pto.vbitcast %signed : !pto.vreg<64xsi32> -> !pto.vreg<64xui32> +// Bits are identical; interpretation changes from signed to unsigned +``` + +#### `pto.pbitcast` + +- **syntax:** `%result = pto.pbitcast %input : !pto.mask -> !pto.mask` +- **semantics:** Bitwise reinterpretation of a predicate register without + changing the underlying predicate-register image. This op makes mask-family + reinterpretation explicit in VPTO IR when a producer and consumer expect + different `!pto.mask<...>` views of the same hardware predicate state. + +- **inputs:** + `%input` is the source predicate register value. +- **outputs:** + `%result` is the reinterpreted predicate register value. +- **constraints and limitations:** + 1. Both source and result must be `!pto.mask<...>` types. + 2. `pto.pbitcast` does not materialize or normalize predicate contents; it + only changes which mask granularity the surrounding VPTO IR uses to + interpret the same predicate bits. + +**Example: Reinterpret a b16 predicate as b32 before a consumer** +```mlir +%m16 = pto.pintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask +%m32 = pto.pbitcast %m16#0 : !pto.mask -> !pto.mask +%result = pto.vsel %a, %b, %m32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** On A5, `i8/u8` inputs produce widened + `i16/u16` results with half as many lanes (`M = N / 2`), and `i16/u16` inputs + produce widened `i32/u32` results with half as many lanes. For + `i32/u32/f16/f32` inputs, `U = T` and `M = N`. If all predicate bits are + zero, the result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the per-element + slope vector, and `%mask` selects active lanes. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdif` + +- **syntax:** `%result = pto.vexpdif %input, %max, %mask, "EVEN|ODD" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector, `%max` is the broadcasted + subtraction term, `%mask` selects active source lanes, and `%part` selects + `EVEN` or `ODD` for the underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, the mask granularity must match the input + vector element width, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha, %mask : !pto.vreg, !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, + `%alpha` is the scalar multiplier, and `%mask` selects active lanes. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- **semantics:** Generate a lane index vector from a scalar base value. + +```c +for (int i = 0; i < N; i++) + dst[i] = (order == ASC) ? (base_index + i) : (base_index - i); +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar base value. Supported scalar types are + `i8/i16/i32`, `f16`, and `f32`. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** `%result` element type determines both the + generated element type and the lane count. Supported result types are + `!pto.vreg<256xsi8>`, `!pto.vreg<128xsi16>`, `!pto.vreg<64xsi32>`, + `!pto.vreg<128xf16>`, and `!pto.vreg<64xf32>`. `%index` must use the + matching scalar type for `f16`/`f32`; for integer results, `%index` must use + the same bit width and may be signless or signed. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ASC|DESC"} : T -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate ascending si32 indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xsi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdif %logits, %max_bc, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.dma_load` | +| UB→GM DMA | 2 | `pto.dma_store` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_B8` / `NORM_B16` / `NORM_B32` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt`, `pto.vbitcast`, `pto.pbitcast` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | + +## Part IV: PTO Tile Instruction + +PTO Tile Instruction is a high-performance instruction surface built on top of PTO micro Instruction. Each tile instruction encapsulates a tile-granular pattern — DMA between GM and on-chip buffers, vector arithmetic over a whole tile, reductions, broadcast / expansion, selection, padding — and internally expands to a sequence of micro-instruction primitives (`pto.vlds`, `pto.vsts`, `pto.vadd`, mask ops, sync flags, …). + +The full PTO Tile Instruction reference starts from [Tile and PTO Tile Instruction overview](PTO-tile-Instruction-SPEC-v0.4.md#tile-01-tile-overview). It covers: + +- [Tile and PTO Tile Instruction overview](PTO-tile-Instruction-SPEC-v0.4.md#tile-01-tile-overview) — tile concept, on-chip placement, physical shape vs valid region, conventions +- [Types & Attributes](PTO-tile-Instruction-SPEC-v0.4.md#tile-02-types-and-attributes) — `!pto.tile_buf`, `!pto.tensor_view`, address spaces, layout, pad +- [Pointer & View](PTO-tile-Instruction-SPEC-v0.4.md#tile-03-pointer-and-view) — tensor views, partitions, tile allocation, valid-shape updates +- [DMA Data Movement](PTO-tile-Instruction-SPEC-v0.4.md#tile-04-dma-data-movement) — `pto.tload` / `pto.tstore` +- [Vector Arithmetic](PTO-tile-Instruction-SPEC-v0.4.md#tile-05-vector-arithmetic) — `pto.tadd / tsub / tmul / tdiv / tmax / tmin`, tile-scalar forms, unary math, activations +- [Reductions](PTO-tile-Instruction-SPEC-v0.4.md#tile-06-reduction-ops), [Partial Elementwise](PTO-tile-Instruction-SPEC-v0.4.md#tile-07-partial-elementwise), [Bitwise & Shift](PTO-tile-Instruction-SPEC-v0.4.md#tile-08-bitwise-shift-ops), [Type Conversion](PTO-tile-Instruction-SPEC-v0.4.md#tile-09-type-conversion), [Broadcast & Expansion](PTO-tile-Instruction-SPEC-v0.4.md#tile-10-broadcast-and-expansion-ops), [Selection](PTO-tile-Instruction-SPEC-v0.4.md#tile-11-selection-ops), [Fill & Padding](PTO-tile-Instruction-SPEC-v0.4.md#tile-12-fill-and-padding-ops) + +For the boundary between Tile Instruction and the micro instruction surface (when to drop into `pto.vecscope` and how `pto.tile_buf_addr` bridges the two), see [Tile and PTO Tile Instruction overview §1.10](PTO-tile-Instruction-SPEC-v0.4.md#110-mixing-pto-tile-instruction-and-pto-micro-instruction). diff --git a/docs/release/PTO-tile-Instruction-SPEC-v0.4.md b/docs/release/PTO-tile-Instruction-SPEC-v0.4.md new file mode 100644 index 000000000..7bae9ddb7 --- /dev/null +++ b/docs/release/PTO-tile-Instruction-SPEC-v0.4.md @@ -0,0 +1,1815 @@ +# PTO Tile Instruction SPEC (A5) + +- v0.4: Initial PTO Tile Instruction SPEC covering core TileOps + +[toc] + +--- + + + +## 1. Tile and PTO Tile Instruction Overview + +> **Category:** Foundational concepts + +This chapter introduces both the tile data model and the **Tile Instruction** surface that operates on it. Read this before any of the per-group Tile Instruction references. + +--- + +### 1.1 What is PTO Tile Instruction + +**PTO Tile Instruction** is a high-performance instruction library built on top of [PTO micro Instruction](PTO-micro-Instruction-SPEC-v0.4.md#micro-01-pipeline-sync). Each tile instruction encapsulates a tile-granular pattern — DMA between GM and on-chip buffers, vector arithmetic over a whole tile, reductions, broadcast / expansion, selection, padding — that internally expands to a sequence of micro-instruction primitives (`pto.vlds`, `pto.vsts`, `pto.vadd`, mask ops, sync flags, …). + +For the kernel author this means: + +- **Author at the tile level.** Use `pto.tload`, `pto.tadd`, `pto.trowsum`, etc., to express tile-granular DMA and compute without writing the underlying vector loop. +- **Drop down to micro instruction when needed.** Inside `pto.vecscope`, `pto.tile_buf_addr` lowers a tile handle to a UB pointer, so handwritten micro-instruction code can read and write the same on-chip data. The mixing pattern is documented in [§1.10](#110-mixing-pto-tile-instruction-and-pto-micro-instruction). +- **Predictable lowering.** Because every Tile Instruction is templated against micro instruction, a kernel that mixes Tile and micro can share scratch tiles, masks, and pipeline events with no representation gap. + +The remaining chapters in this document cover the tile data types, pointer / view ops, DMA, compute families, and op-by-op syntax. The semantics below define the storage contract those ops share. + +### 1.2 Tile Buffer Model + +A **tile** is a bounded, rectangular 2-D sub-region of data that lives in **local on-chip memory** (UB, L0A, L0B, L0C, bias, or scaling buffer) and is consumed or produced by tile-level instructions. A tile is a storage object with an explicit lifetime and an explicit on-chip placement. + +Tile Instruction models tiles as **tile buffers** of type `!pto.tile_buf<...>`. A tile buffer records: + +- the **memory domain** (`loc`) — where the tile lives on chip; +- the **element type** (`dtype`) — how bits are interpreted; +- the **physical shape** (`rows`, `cols`) — how much storage the tile occupies; +- the **valid region** (`v_row`, `v_col`) — the populated sub-rectangle within the physical tile (may be `?` for runtime-dynamic); +- **layout and fractal** metadata (`blayout`, `slayout`, `fractal`, `pad`) — how elements are arranged in storage. + +This differs from a global tensor: + +- A `!pto.tensor_view` is a logical descriptor over **global memory (GM)** — shape information, no on-chip residency. +- A `!pto.partition_tensor_view` is a logical sub-window of a tensor view, still in GM. +- A `!pto.tile_buf` is the **local, on-chip** materialization of a partition — data placed in UB / L0 / bias / scaling buffers. + +Data flow between these is explicit: + +``` +!pto.tensor_view --partition_view--> !pto.partition_tensor_view --tload--> !pto.tile_buf + (GM) (GM slice) (on-chip tile) +``` + +Placement, lifetime, and reuse affect both correctness and performance. `pto.alloc_tile` makes allocation explicit, and pipeline ordering is expressed through the synchronization primitives described in [`01-pipeline-sync.md`](PTO-micro-Instruction-SPEC-v0.4.md#micro-01-pipeline-sync). + +**Explicit buffer lifetime example:** + +```mlir +%a0 = pto.alloc_tile : !pto.tile_buf +%a1 = pto.alloc_tile : !pto.tile_buf + +pto.tload ins(%pv0 : !pto.partition_tensor_view<16x16xf16>) + outs(%a0 : !pto.tile_buf) +pto.tload ins(%pv1 : !pto.partition_tensor_view<16x16xf16>) + outs(%a1 : !pto.tile_buf) +``` + +### 1.3 Hardware Memory Hierarchy + +The Ascend NPU on-chip memory layout that tile buffers map onto: + +``` +GM (Global Memory) +|- MAT (L1 Cache) +| |- LEFT (L0A — left matrix buffer) +| |- RIGHT (L0B — right matrix buffer) +| |- ACC (L0C — accumulator) +| `- BIAS (bias buffer) +`- VEC (UB — unified buffer) +``` + +`loc` on a tile buffer selects one of these domains. The full enum (with mnemonics) is defined in [§2.6 AddressSpace](#26-addressspace); each tile ISA chapter calls out which `loc` domains are legal for the ops it covers. + +### 1.4 Instruction Form + +Most Tile Instruction ops use an explicit source/destination form. The destination tile buffer is named in `outs(...)` and is updated in place: + +```mlir +pto. ins(, , ... : , , ...) + outs( : ) + [ {optional-attrs} ] +``` + +- Inputs appear inside `ins(...)` with their types. +- The output tile buffer appears inside `outs(...)`. +- Scalar operands (where applicable) are listed inside `ins(...)` alongside tile operands. +- Optional attributes follow as a trailing `{ ... }` block. + +Synchronization, sub-view, and allocation ops may diverge from this pattern (for example `pto.alloc_tile` yields a tile-buffer handle, and `pto.subset` returns a view). Each chapter states the assembly format for its ops. + +```mlir +pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +### 1.5 Physical Shape vs Valid Region + +Every tile buffer has two shape concepts: + +- **Physical shape** `(rows, cols)` — the extent of backing storage; static and known when the tile buffer type is declared. +- **Valid region** `(v_row, v_col)` — the populated sub-rectangle; either static or dynamic (`?`). + +The physical shape drives layout, fractal alignment, and buffer-size accounting. The valid region drives the iteration domain of compute and DMA ops. **Undefined behavior:** elements outside the valid region are padding — their contents must not be read. + +When the valid region is dynamic (`v_row = ?` or `v_col = ?`), it is provided at `pto.alloc_tile` time (or updated later with `pto.set_validshape`). Most Tile Instruction ops use the destination valid region as the iteration domain; a few ops require all operands to share the same valid region. + +### 1.6 Pipeline Association + +Every Tile Instruction op is associated with a hardware pipeline in the Decoupled Access-Execute architecture: + +| Pipeline | Symbol | Typical ops | +|----------|--------|------------| +| DMA inbound | `PIPE_MTE2` | `pto.tload` | +| DMA outbound | `PIPE_MTE3` | `pto.tstore` | +| Vector | `PIPE_V` | `pto.tadd`, `pto.tadds`, `pto.texp`, `pto.tcvt`, and the rest of the vector arithmetic set | +| Scalar | `PIPE_S` | scalar `arith`/`scf` ops interleaved with tile code | + +Cross-pipeline data dependencies are ordered explicitly, either via the **Flag/Event** mechanism (`pto.set_flag`/`pto.wait_flag`) or the **Buffer-ID** mechanism (`pto.get_buf`/`pto.rls_buf`). See [`01-pipeline-sync.md`](PTO-micro-Instruction-SPEC-v0.4.md#micro-01-pipeline-sync) for the full semantics. + +### 1.7 Scratch Operands and A2/A3 Compatibility + +Some Tile Instruction ops carry an extra `%tmp` tile operand whose only purpose is to keep the operand list aligned with the corresponding A2/A3 PTO instruction interface. Examples include `pto.txor` / `pto.txors` ([Chapter 8](#tile-08-bitwise-shift-ops)) and `pto.tsel` / `pto.tsels` ([Chapter 11](#tile-11-selection-ops)). + +`%tmp` exists for cross-arch interface compatibility — A5 templates may not materially use it, but it remains in the public op signature so the same Tile IR can be reused across A2/A3 and A5. Treat it as a required operand whose dtype/shape constraints are stated by the individual op page. + +### 1.8 Conventions for Chapters 5–12 + +Unless an op page states otherwise, the chapters that follow assume: + +- tile operands use `loc=vec`; +- tile layouts use `blayout=row_major` and `slayout=none_box`; +- valid bounds satisfy `v_row <= rows` and `v_col <= cols`; +- examples use the compact `!pto.tile_buf` form. Omitted attributes carry their default values: `valid` = physical shape, `blayout=row_major`, `slayout=none_box`, `fractal=512`, `pad=0`. + +The op pages call out any deviation from these conventions explicitly. + +### 1.9 Minimal End-to-End Example + +A minimal tile-level "load, add, store" kernel: + +```mlir +// Build the GM view and partition it +%tv = pto.make_tensor_view %gm_ptr, shape = [%m, %n], strides = [%s0, %s1] + : !pto.tensor_view +%pv = pto.partition_view %tv, offsets = [%c0, %c0], sizes = [%c16, %c16] + : !pto.tensor_view -> !pto.partition_tensor_view<16x16xf16> + +// Allocate on-chip tile buffers +%a = pto.alloc_tile : !pto.tile_buf +%b = pto.alloc_tile : !pto.tile_buf +%c = pto.alloc_tile : !pto.tile_buf + +// DMA-in, compute, DMA-out +pto.tload ins(%pv : !pto.partition_tensor_view<16x16xf16>) outs(%a : !pto.tile_buf) +pto.tload ins(%pv2 : !pto.partition_tensor_view<16x16xf16>) outs(%b : !pto.tile_buf) +pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +pto.tstore ins(%c : !pto.tile_buf) + outs(%pv_out : !pto.partition_tensor_view<16x16xf16>) +``` + +Synchronization is omitted for clarity; for the real ordering contracts (`pto.set_flag`/`pto.wait_flag`, `pto.get_buf`/`pto.rls_buf`, `pto.pipe_barrier`) see [`01-pipeline-sync.md`](PTO-micro-Instruction-SPEC-v0.4.md#micro-01-pipeline-sync). + + + +### 1.10 Mixing PTO Tile Instruction and PTO micro Instruction + +PTO Tile Instruction and PTO micro Instruction can be authored side-by-side in the same kernel. The Tile Instruction surface owns tile placement and GM ↔ on-chip DMA; the micro surface owns vector-register compute inside `pto.vecscope`. The two surfaces meet through `pto.tile_buf_addr`, which converts a tile handle into a UB pointer that vector ops can consume. + +This section presents a softmax kernel that uses both surfaces together, then walks through it. + +#### Kernel Structure + +The kernel follows a fixed shape that all mixed Tile + micro programs share: + +1. Build `tensor_view` / `partition_view` descriptors for each GM operand. +2. Use `pto.alloc_tile` to allocate UB tiles with explicit static **size** and **address**. +3. Use `pto.tload` to move data from GM partitions into tiles. +4. Cross the **MTE2 → V** synchronization edge with `pto.set_flag` / `pto.wait_flag`. +5. Open a `pto.vecscope` region. Inside the scope: + - Use `pto.tile_buf_addr` to lower each tile handle into a `!pto.ptr<..., ub>`. + - Use `pto.vlds` / `pto.vsts` and the rest of the micro vector ops to read, compute, and write UB. +6. Cross the **V → MTE3** synchronization edge with `pto.set_flag` / `pto.wait_flag`. +7. Use `pto.tstore` to move tiles back to GM. + +Two boundary rules govern this layout: + +- Tile-domain ops (`pto.tload`, `pto.tstore`, `pto.tadd`, …) **must not appear inside** `pto.vecscope`. +- `pto.tile_buf_addr` is **only legal inside** `pto.vecscope` / `pto.strict_vecscope`. + +The kernel also manually drives address allocation (`alloc_tile addr = ...`) and pipeline synchronization. Lowering with `--enable-insert-sync` is therefore disabled, and `--pto-level=level3` is used so that `alloc_tile` accepts an explicit address operand. + +#### Kernel Listing + +The listing below is an online softmax-update kernel reduced to the structurally interesting parts. Repeated descriptors and the deep online-softmax math are abbreviated with `// ...` so that the Tile / micro / sync boundaries stay visible. + +```mlir +module attributes {pto.target_arch = "a5"} { + func.func @online_softmax_update_kernel_2d( + %arg0: !pto.ptr, // oldmax (rows x 1) + %arg1: !pto.ptr, // oldsum (rows x 1) + %arg2: !pto.ptr, // qk (rows x 128) + %arg3: !pto.ptr, // newmax (rows x 1) + %arg4: !pto.ptr, // newsum (rows x 1) + %arg5: !pto.ptr, // expmax (rows x 1) + %arg6: !pto.ptr, // out (rows x 128) + %arg7: i32, %arg8: i32) { // %arg7 = seq_len, %arg8 = total_rows + // -------- (1) GM views and partitions -------- + // Eight rows of the qk and out tensors are processed per block. + %qk_view = pto.make_tensor_view %arg2, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + %qk_part = pto.partition_view %qk_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view + -> !pto.partition_tensor_view + // ... oldmax/oldsum/newmax/newsum/expmax/out views/partitions analogous ... + + // -------- (2) Tile allocation with static size and explicit UB address -------- + %qk_tile = pto.alloc_tile addr = %c256_i64 + valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %out_tile = pto.alloc_tile addr = %c8448_i64 + valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count + : !pto.tile_buf + // ... oldsum/newmax/newsum/expmax tiles analogous (each at its own UB addr) ... + + // -------- (3) GM → tile DMA -------- + pto.tload ins(%qk_part : !pto.partition_tensor_view) + outs(%qk_tile : !pto.tile_buf) + pto.tload ins(%oldmax_part : !pto.partition_tensor_view) + outs(%oldmax_tile: !pto.tile_buf) + // ... oldsum tload analogous ... + + // -------- (4) MTE2 → V synchronization -------- + pto.set_flag ["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + // -------- (5) Vector region: tile_buf_addr + micro compute -------- + pto.vecscope { + // Lower tile handles to UB pointers. + %ub_qk = pto.tile_buf_addr %qk_tile + : !pto.tile_buf + -> !pto.ptr + %ub_out = pto.tile_buf_addr %out_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newmax = pto.tile_buf_addr %newmax_tile + : !pto.tile_buf + -> !pto.ptr + // ... ub_oldmax / ub_oldsum / ub_newsum / ub_expmax analogous ... + + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %_ = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + + scf.for %row = %c0 to %row_count step %c1 { + // Online-softmax max/sum reduction (one row at a time). + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] + {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%base] + : !pto.ptr -> !pto.vreg<64xf32> + // ... running_max / running_sum update via vcmax / vexpdif / vmul / vadd ... + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + // Persist the row-local results back to UB. + pto.vsts %final_max, %ub_newmax[%row], %one_mask + {dist = "1PT_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + // Second pass: write softmax output back into the qk tile's UB region. + scf.for %chunk = %c0 to %c128 step %c64 { + %base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%base] + : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, %chunk_mask, "ODD" + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask + -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask + -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%base], %chunk_mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + + // -------- (6) V → MTE3 synchronization -------- + pto.set_flag ["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + // -------- (7) Tile → GM DMA -------- + pto.tstore ins(%out_tile : !pto.tile_buf) + outs(%out_part : !pto.partition_tensor_view) + pto.tstore ins(%newmax_tile : !pto.tile_buf) + outs(%newmax_part: !pto.partition_tensor_view) + // ... newsum/expmax tstore analogous ... + + pto.barrier #pto.pipe + return + } +} +``` + +#### Code Walkthrough + +The seven numbered comments in the listing above mark the seven steps from §Kernel Structure. The notes below highlight what each step contributes to the Tile / micro split. + +**(1) GM views and partitions** — pure metadata. `pto.make_tensor_view` records the GM tensor's shape and strides; `pto.partition_view` carves out the per-block sub-window. Neither op moves data, and both stay outside `pto.vecscope`. The 5-D shape is a quirk of this kernel's layout convention; the boundary rules don't depend on rank. + +**(2) `pto.alloc_tile` with static size and address** — declares the UB tile handles. The result type fixes the static physical shape (e.g. `8x128xf32`); `addr = %c256_i64` pins the tile to a specific UB byte offset; `valid_row = ...` / `valid_col = ...` carry the runtime valid extents (the `?` markers in `valid=?x?`). Because addresses are hand-assigned, this kernel compiles with `--pto-level=level3` and disables `--enable-insert-sync`. + +**(3) `pto.tload`** — copies a GM partition into the UB tile. Runs on `PIPE_MTE2`. Stays in the Tile domain; it cannot appear inside `pto.vecscope`. + +**(4) MTE2 → V flag handshake** — DMA inbound and the vector pipeline run asynchronously. The producer/consumer edge between `tload` and the upcoming `vecscope` must be made explicit with `pto.set_flag` / `pto.wait_flag`. + +**(5) Vector region** — `pto.vecscope` opens a vector-execution region. The first thing inside is a series of `pto.tile_buf_addr` ops, each lowering a tile handle into a `!pto.ptr`. From that point on the body is pure micro: `pto.vlds` reads UB into vregs, vector arithmetic / SFU / mask ops compute on vregs, and `pto.vsts` writes vregs back to UB. Tile ops are forbidden inside this region; `pto.tile_buf_addr` is forbidden outside. + +**(6) V → MTE3 flag handshake** — mirror of step (4), this time gating the vector results visible to the outbound DMA. + +**(7) `pto.tstore`** — writes each UB tile back to its GM partition, completing the round trip. Same Tile-domain rules as `tload`. + +#### Where the Tile and Micro Boundaries Sit + +| Op | Where it must live | Why | +|----|-------------------|-----| +| `pto.alloc_tile`, `pto.tload`, `pto.tstore`, `pto.tadd`, … (Tile domain) | **Outside** `pto.vecscope` | Tile ops describe tile residency and tile-granular DMA / compute; they have no meaning inside a vector-register region. | +| `pto.vlds`, `pto.vsts`, `pto.vmax`, `pto.vexpdif`, … (micro domain) | **Inside** `pto.vecscope` | These ops produce/consume `!pto.vreg` and `!pto.mask` values that only exist inside a vector region. | +| `pto.tile_buf_addr` | **Inside** `pto.vecscope` only | This is the single sanctioned bridge from a tile handle to a UB pointer; outside vecscope, tile handles must be consumed by Tile ops, not by address extraction. | +| `pto.set_flag` / `pto.wait_flag` (and other sync primitives) | Either side | Sync ops belong to whichever pipeline edge they coordinate; in this kernel they appear at the MTE2 → V and V → MTE3 boundaries. | + +In short: keep DMA and tile shape management in Tile-land, keep vreg/mask compute in vecscope, and use `pto.tile_buf_addr` exactly at the boundary. + + + +## 2. Types & Attributes + +> **Category:** Type system and attribute vocabulary + +This chapter defines the types and attributes used across the Tile Instruction chapters. + +--- + +### 2.1 Element Types + +Element types describe the primitive scalar values stored in tiles; by themselves they do not form a value. Common element categories: + +- **Integers:** signless — `i1`, `i8`, `i16`, `i32`, `i64`. Signedness is not encoded in the type; it is selected by operation semantics or attributes. +- **Floating-point:** `f16`, `bf16`, `f32`. +- **Index-like:** `index` values appear as scalar operands (offsets, sizes, scalar compares). + +Operation-specific constraints: + +- Elementwise ops typically require operand and result element types to match. +- Reductions, math ops, and division typically restrict to floating-point or a subset of integer types. +- Bitwise ops require integer element types. +- `pto.tcvt` defines explicit element-type changes under an explicit rounding mode. + +Memory layout and address space do not change element-type semantics; they only affect placement and access patterns. + +### 2.2 `!pto.ptr` + +A typed pointer. `memorySpace` is optional and defaults to `gm`. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `elementType` | element type | Element type pointed to. | +| `memorySpace` | `gm` \| `vec` | Pointer address space (`gm` → global memory, `vec` → UB / vector memory). | + +**Syntax:** `!pto.ptr` or `!pto.ptr` + +Pointer conversions are modeled explicitly with `pto.castptr`. Between two `!pto.ptr` types, casts are only legal when both pointers stay in the same PTO memory space. + +### 2.3 `!pto.tensor_view` + +A descriptor for a global-memory tensor. Holds shape information; strides are supplied at `pto.make_tensor_view` construction time. Does not own data. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `shape` | `ArrayRef` | Tensor shape `[d0, d1]` (each dim may be `?`). | +| `elementType` | element type | Element data type. | + +**Syntax:** `!pto.tensor_view<1024x512xf16>` + +### 2.4 `!pto.partition_tensor_view` + +A logical partition (slice) of a `tensor_view`. Holds shape information for a tile-sized region; strides are inherited from the parent `tensor_view`. Does not own data. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `shape` | `ArrayRef` | Partition shape `[d0, d1]`. | +| `elementType` | element type | Element data type. | + +**Syntax:** `!pto.partition_tensor_view<16x16xf16>` + +### 2.5 `!pto.tile_buf` + +`pto.tile_buf` represents a local on-chip tile buffer with explicit placement, shape, valid region, and layout/fractal metadata. The textual form is **compact**: only the leading `` triple is mandatory; everything else is omitted when it equals its default. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `loc` | keyword | — | Local memory domain (`vec` / `mat` / `left` / `right` / `acc` / `bias` / `scaling`). | +| `R` × `C` × `dtype` | shape × element type | — | Physical row/column count and element type. | +| `valid` | `v_row x v_col` (each `int64` or `?`) | `R x C` | Valid region. Omitted when equal to physical shape. | +| `blayout` | `BLayout` | `row_major` | Base layout. | +| `slayout` | `SLayout` | `none_box` | Secondary layout. | +| `fractal` | `int32` | `512` | Fractal size. | +| `pad` | `PadValue` enum int | `0` (`null`) | Padding policy/value selector. | + +**Examples:** + +```mlir +// Default config, valid == physical +!pto.tile_buf + +// Dynamic valid region +!pto.tile_buf + +// Non-default config +!pto.tile_buf +``` + +`?` denotes a dynamic symbol resolved at runtime (via `pto.alloc_tile` operands or `pto.set_validshape`). + +### 2.6 AddressSpace + +Defines the physical storage location of a buffer in the Ascend NPU memory hierarchy. + +| Value | Int | Mnemonic | Hardware Mapping | +|-------|-----|----------|------------------| +| `Zero` | 0 | `zero` | Default (unspecified). | +| `GM` | 1 | `gm` | Global Memory. | +| `MAT` | 2 | `mat` | L1 Cache. | +| `LEFT` | 3 | `left` | L0A (left matrix buffer). | +| `RIGHT` | 4 | `right` | L0B (right matrix buffer). | +| `ACC` | 5 | `acc` | L0C (accumulator). | +| `VEC` | 6 | `vec` | UB (unified buffer). | +| `BIAS` | 7 | `bias` | Bias buffer. | +| `SCALING` | 8 | `scaling` | Scaling buffer. | + +**Attribute syntax:** `loc=` (for example `loc=vec`). + +### 2.7 Tile Buf Config + +Composite attribute for tile-buffer layout/fractal/pad. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `bLayout` | `BLayoutAttr` | Base layout (RowMajor / ColMajor). | +| `sLayout` | `SLayoutAttr` | Secondary layout (NoneBox / RowMajor / ColMajor). | +| `sFractalSize` | `IntegerAttr (i32)` | Secondary fractal size. | +| `pad` | `PadValueAttr` | Pad value policy. | + +**Syntax:** `#pto.tile_buf_config` + +**BLayout:** + +| Value | Int | Mnemonic | +|-------|-----|----------| +| `RowMajor` | 0 | `row_major` | +| `ColMajor` | 1 | `col_major` | + +**SLayout:** + +| Value | Int | Mnemonic | +|-------|-----|----------| +| `NoneBox` | 0 | `none_box` | +| `RowMajor` | 1 | `row_major` | +| `ColMajor` | 2 | `col_major` | + +**PadValue:** + +| Value | Int | Mnemonic | +|-------|-----|----------| +| `Null` | 0 | `null` | +| `Zero` | 1 | `zero` | +| `Max` | 2 | `max` | +| `Min` | 3 | `min` | + +### 2.8 Layout + +Global tensor layout attribute for `tensor_view` and `partition_tensor_view`. Tile buffers additionally use **Tile Buf Config** (§2.7) to describe physical/fractal layout. + +| Value | Int | Mnemonic | Description | +|-------|-----|----------|-------------| +| `ND` | 0 | `nd` | Row-major (Normal-Dimension). | +| `DN` | 1 | `dn` | Column-major (Dimension-Normal). | +| `NZ` | 2 | `nz` | Fractal / blocked layout. | + +**Attribute syntax:** `#pto.layout` + +### 2.9 PadMode (for loads) + +Padding mode for `pto.tload`. + +| Value | Int | Description | +|-------|-----|-------------| +| `PadNull` | 0 | No padding. | +| `PadFirstElem` | 1 | Pad using the first element. | +| `PadValue` | 2 | Pad using a specified value. | + +### 2.10 Shared Scalar and Control-Flow Ops + +Tile programs commonly interleave `pto` instructions with a small set of supporting ops: + +- **`func`** — `func.func`, `func.return`, `func.call`. +- **`arith`** — scalar constants/casts (`arith.constant`, `arith.index_cast`, `arith.bitcast`, `arith.extf`/`truncf`/…), integer/float arithmetic, bitwise/shift, compares/select, extended and min/max ops. +- **`scf`** — `scf.for`, `scf.if`, `scf.yield`; several other structured control-flow forms are lowered through `cf`. + +These supporting ops are included here only insofar as tile programs need function structure, scalar computation, and structured control flow; full coverage of those surfaces is out of scope for this reference. + + + +## 3. Pointer & View Operations + +> **Category:** Address arithmetic, tensor-view construction, tile-buffer allocation +> **Pipeline:** None (all ops are metadata / view construction; no HW side effect) + +These instructions build the address, view, and tile-buffer metadata that later DMA and compute instructions consume. None of them moves data. + +--- + +### `pto.addptr` + +- **syntax:** +```mlir +%result = pto.addptr %base, %offset : !pto.ptr -> !pto.ptr +``` +- **semantics:** `result = ptr + offset`, with `offset` counted in **elements** (not bytes). + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%base` | `!pto.ptr` | Base pointer. | +| `%offset` | `index` | Element offset. | + +**Constraints:** + +- Result type must match the input pointer type. +- The op is pure (no side effects). + +**Example:** + +```mlir +%ptr_off = pto.addptr %base, %offset : !pto.ptr -> !pto.ptr +``` + +--- + +### `pto.castptr` + +- **syntax:** +```mlir +%p_ptr = pto.castptr %addr : i64 -> !pto.ptr +%p_ptr2 = pto.castptr %p_ptr : !pto.ptr -> !pto.ptr +%addr2 = pto.castptr %p_ptr : !pto.ptr -> i64 +``` +- **semantics:** Explicit cast between integer addresses and `!pto.ptr`, or between two `!pto.ptr` types. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `input` | integer \| `!pto.ptr<...>` | Source value. | + +**Constraints:** + +- Integer-to-integer casts are rejected; use normal integer cast ops. +- Descriptor types (`!pto.tensor_view<...>`, `!pto.partition_tensor_view<...>`) are not legal direct inputs; extract an address first. +- Pointer-to-pointer casts are only legal when source and destination stay in the same PTO memory space (`gm` or `vec`). +- The op is pure. + +**Example:** + +```mlir +%p0 = pto.castptr %addr : i64 -> !pto.ptr +%p1 = pto.castptr %p0 : !pto.ptr -> !pto.ptr +%a2 = pto.castptr %p1 : !pto.ptr -> i64 +``` + +--- + +### `pto.make_tensor_view` + +- **syntax:** +```mlir +%tv = pto.make_tensor_view %ptr, shape = [%m, %n], strides = [%s0, %s1] + : !pto.tensor_view +``` +- **semantics:** Construct a global tensor view from a pointer, declaring the physical base and strides. No allocation, no data movement. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%ptr` | `AnyType` | Source pointer (must be `!pto.ptr` with element type matching the result). | +| `shape` | `Variadic` | Dynamic shape dimensions. | +| `strides` | `Variadic` | Dynamic strides. | +| `layout` | `LayoutAttr` (optional) | `nd` / `dn` / `nz` hint. | + +**Constraints:** + +- `ptr` element type must match the result element type. +- `shape` and `strides` operand counts must match the tensor_view rank. +- If `layout` is provided with static shapes/strides, it must be consistent with the inferred layout. + +**Example:** + +```mlir +%tv = pto.make_tensor_view %ptr, shape = [%m, %n], strides = [%s0, %s1] + : !pto.tensor_view +``` + +--- + +### `pto.get_tensor_view_dim` + +- **syntax:** +```mlir +%dim = pto.get_tensor_view_dim %tv, %idx : !pto.tensor_view<...> -> index +``` +- **semantics:** Return the runtime size of dimension `%idx` from a `tensor_view`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tv` | `!pto.tensor_view<...>` | Logical tensor view. | +| `%idx` | `index` | Dimension index (0-based). | + +**Example:** + +```mlir +%h = pto.get_tensor_view_dim %tv, %c0 : !pto.tensor_view -> index +``` + +--- + +### `pto.get_tensor_view_stride` + +- **syntax:** +```mlir +%stride = pto.get_tensor_view_stride %tv, %idx : !pto.tensor_view<...> -> index +``` +- **semantics:** Return the logical stride of dimension `%idx`, measured in **elements** (not bytes). + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tv` | `!pto.tensor_view<...>` or memref form | Tensor view or its lowered memory-reference form. | +| `%idx` | `index` | Dimension index (0-based). | + +**Example:** + +```mlir +%s0 = pto.get_tensor_view_stride %tv, %c0 : !pto.tensor_view -> index +``` + +--- + +### `pto.tensor_view_addr` + +- **syntax:** +```mlir +%result = pto.tensor_view_addr %src : !pto.tensor_view<...> -> memref<...> +%result = pto.tensor_view_addr %src : !pto.tensor_view<...> -> !pto.ptr +``` +- **semantics:** Extract the underlying address view from a `tensor_view` or `partition_tensor_view`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%src` | `!pto.tensor_view<...>` or `!pto.partition_tensor_view<...>` | Source view descriptor. | + +**Constraints:** + +- The result type must be either the lowered memref view or a GM pointer `!pto.ptr` to the same underlying storage. +- The op is pure and does not move data. + +**Example:** + +```mlir +%base = pto.tensor_view_addr %tv : !pto.tensor_view -> !pto.ptr +``` + +`pto.tensor_view_addr` exposes the underlying address represented by the view descriptor. When the result type is a memref, it exposes the lowered view directly. When the result type is `!pto.ptr<..., gm>`, it exposes the same address in pointer form. During compiler-internal lowering, the operand may already be rewritten to a memref form; in that case this op is folded away or rewritten to an equivalent memref-to-ptr cast. + +--- + +### `pto.partition_view` + +- **syntax:** +```mlir +%pv = pto.partition_view %tv, offsets = [%o0, %o1], sizes = [%s0, %s1] + : !pto.tensor_view<...> -> !pto.partition_tensor_view<...> +``` +- **semantics:** `result = source[offsets, sizes]` — a logical window on a `tensor_view`. Captures both static and dynamic shapes; does not move data. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tv` | `TensorViewType` | Input tensor view. | +| `offsets` | `Variadic` | Dynamic offsets. | +| `sizes` | `Variadic` | Dynamic sizes. | + +**Constraints:** + +- `offsets`/`sizes` counts must match the rank of `source`. + +**Example:** + +```mlir +%pv = pto.partition_view %tv, offsets = [%off0, %off1], sizes = [%s0, %s1] + : !pto.tensor_view<1024x512xf16> -> !pto.partition_tensor_view<16x16xf16> +``` + +--- + +### `pto.alloc_tile` + +- **syntax:** +```mlir +%tb = pto.alloc_tile : !pto.tile_buf<...> +%tb2 = pto.alloc_tile valid_row = %vr valid_col = %vc : !pto.tile_buf +%tb3 = pto.alloc_tile addr = %ad : !pto.tile_buf<...> +``` +- **semantics:** Declare the lifetime of a tile buffer. Each call produces an **independent** tile-buffer instance. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `addr` | `Optional` | Optional start address. If omitted, assigned by the implementation. | +| `valid_row` | `Optional` | Dynamic valid-row count (required when result `v_row = ?`). | +| `valid_col` | `Optional` | Dynamic valid-col count (required when result `v_col = ?`). | + +**Constraints:** + +- If result `v_row`/`v_col` are dynamic (`?`), the corresponding operands must be present. +- If result `v_row`/`v_col` are static, the corresponding operands must be absent. + +**Example:** + +```mlir +%tb = pto.alloc_tile : !pto.tile_buf +``` + +--- + +### `pto.subset` + +- **syntax:** +```mlir +%sub = pto.subset %src[%i, %j] sizes [rows, cols] : !pto.tile_buf<...> +``` +- **semantics:** `result = source[offsets]` with static `sizes`. Creates a strided view of a parent tile. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%src` | `pto.tile_buf` | Parent tile buffer. | +| `offsets` | `Variadic` | Runtime offsets `[i, j]`. | +| `sizes` | `I64ArrayAttr` | Static shape `[rows, cols]`. | + +**Constraints:** + +- Boxed-vs-non-boxed behavior is derived from the source's tile config (`blayout`, `slayout`, `fractal`) and element type. +- For non-boxed layouts (`slayout=none_box`), no additional subset-specific structural checks are enforced. +- For boxed layouts: + - `sizes` must have length 2 and both subset sizes must be positive. + - Subset sizes must be multiples of the inferred inner boxed shape. + - `offsets` must have length 2; constant offsets must be non-negative and multiples of the inferred inner boxed shape. + - Source tile shape must be statically known. + - For boxed row-major tiles: subset must keep the full source column extent, and the column offset must be the constant `0`. + - For boxed col-major tiles: subset must keep the full source row extent, and the row offset must be the constant `0`. +- The inferred result reuses the source's element type, address space, and tile config. `valid_shape` is derived from the parent valid shape and constant offsets, or dynamic when offsets are dynamic. + +**Example:** + +```mlir +%sub = pto.subset %src[%i, %j] sizes [32, 32] + : !pto.tile_buf +``` + +--- + +### `pto.set_validshape` + +- **syntax:** +```mlir +pto.set_validshape %src, %valid_row, %valid_col : !pto.tile_buf +``` +- **semantics:** Update the runtime `v_row`/`v_col` metadata on an existing **dynamic** tile buffer. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%src` | `pto.tile_buf` | Dynamic rank-2 tile buffer. | +| `%valid_row` | `index` | Runtime valid row count. | +| `%valid_col` | `index` | Runtime valid column count. | + +**Constraints:** + +- `%src` must be rank-2 and use `v_row = ?` and `v_col = ?` on both dimensions. +- Tile programs use `pto.tile_buf`; memref forms are a lowering artifact and are not part of this surface. +- Constant `valid_row`/`valid_col` must be non-negative and `<=` the tile's static shape bounds. + +**Example:** + +```mlir +%src = pto.alloc_tile : !pto.tile_buf +pto.set_validshape %src, %vr, %vc : !pto.tile_buf +``` + +--- + +### `pto.tile_buf_addr` + +- **syntax:** +```mlir +%ub_ptr = pto.tile_buf_addr %tile : !pto.tile_buf<...> -> !pto.ptr +%ub_ref = pto.tile_buf_addr %tile : !pto.tile_buf<...> -> memref<...> +``` +- **semantics:** Extract the address of a `pto.tile_buf`'s data region. Returns either a typed PTO pointer (`!pto.ptr`) or a memref view, depending on the requested result type. Pure op: no data movement, no pipeline activity. + +This op is the **boundary between tile-buffer instructions and pointer-based vector instructions**. Inside a `pto.vecscope` body, use `pto.tile_buf_addr` to materialize a vec-space pointer from a tile handle allocated outside the scope; vector load/store ops such as `pto.vlds` and `pto.vsts` then consume that pointer. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `%tile` | `pto.tile_buf` or tile-bound memref | Tile handle whose data-region address is taken. | + +**Results:** `!pto.ptr` or `memref<...>`. Memref results use the tile's static shape and address space; pointer results use the tile's element type and memory space (e.g. `vec`). + +**Constraints:** + +- Result must be either a typed PTO pointer or a memref view; no other result types are accepted. +- When a memref result is requested, the lowered form uses the tile's static shape and address space. +- `pto.tile_buf_addr` is **only legal inside `pto.vecscope` / `pto.strict_vecscope`**. Outside a vector scope, tile handles must be consumed by tile-level ops (`pto.tload`, `pto.tstore`, `pto.tadd`, …) rather than by address extraction. Conversely, tile-level ops must **not** appear inside `pto.vecscope`. + +**Example (inside `pto.vecscope`):** + +```mlir +%tile = pto.alloc_tile addr = %c0_i64 valid_row = %r + : !pto.tile_buf + +pto.vecscope { + %ub = pto.tile_buf_addr %tile + : !pto.tile_buf -> !pto.ptr + // ... vector-scope loads/stores on %ub ... +} +``` + +See [`03-vector-load-store.md`](PTO-micro-Instruction-SPEC-v0.4.md#micro-03-vector-load-store) for the pointer-based +vector load/store side of this handoff. + + + +## 4. DMA Data Movement + +> **Category:** GM↔on-chip DMA for tile buffers +> **Pipelines:** PIPE_MTE2 (GM→UB), PIPE_MTE3 (UB→GM), PIPE_FIX (when source is `loc=acc`) + +This chapter documents the public tile DMA instructions `pto.tload` and `pto.tstore`. Other raw scalar load/store helpers are outside the current tile-instruction subset and are not covered here. + +--- + +### `pto.tload` + +- **syntax:** +```mlir +pto.tload ins(%src : !pto.partition_tensor_view<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** Physical DMA transfer from a global partition view into a local tile buffer. For each element `(i, j)` in the destination valid region: `dst[i, j] = src[i, j]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PartitionTensorViewType` | Source partition view. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- Tile element type ∈ `{i8, i16, i32, i64, f16, bf16, f32}`. +- Destination tile must use `loc=vec`. +- Destination tile element type and source partition element type must have the same bitwidth. +- Runtime: source partition extents and destination valid region must be positive. + +**Pipeline:** `PIPE_MTE2`. + +**Example:** + +```mlir +pto.tload ins(%pv : !pto.partition_tensor_view<16x16xf16>) + outs(%tb : !pto.tile_buf) +``` + +--- + +### `pto.tstore` + +- **syntax:** +```mlir +pto.tstore ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.partition_tensor_view<...>) +``` +- **semantics:** Store a 2-D tile buffer back to a 2-D partition view. For each element `(i, j)` in the source valid region: `dst[i, j] = src[i, j]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer. | +| `dst` | `PartitionTensorViewType` | Destination partition view. | + +**Constraints:** + +- `src` must be `!pto.tile_buf`, `dst` must be `!pto.partition_tensor_view`. +- Static dst shape dims and static src valid-shape dims must be positive. +- `src.loc ∈ {vec, mat, acc}`. +- For `loc=vec` / `loc=mat`: src element type ∈ `{i8, i16, i32, i64, f16, bf16, f32}`; src/dst element bitwidth must match. +- For `loc=acc`: + - src element type must be `i32` or `f32`. + - dst element type ∈ `{i32, f32, f16, bf16}`. + +**Pipeline:** + +- `src.loc=acc` uses **PIPE_FIX**. +- `src.loc=vec` / `src.loc=mat` uses **PIPE_MTE3**. + +**Example:** + +```mlir +pto.tstore ins(%tb : !pto.tile_buf) + outs(%pv : !pto.partition_tensor_view<16x16xf16>) +``` + + + +## 5. Vector Arithmetic and Activation Operations + +> **Category:** Base tile-local VEC arithmetic +> **Pipeline:** PIPE_V + +This chapter documents the TileLib arithmetic families that keep the same output tile shape as their source tiles. These instructions operate on `!pto.tile_buf` values in `loc=vec` and cover tile-tile arithmetic, tile-scalar arithmetic, unary math, and activation ops. + +Reduction, partial, bitwise, conversion, broadcast / expansion, selection, and fill / padding families are documented in Chapters 6-12. + +--- + +### 5.1 Binary Tile-Tile Arithmetic + +Tile-tile arithmetic families: + +| Op | Semantics | +|----|-----------| +| `pto.tadd` | `dst[i, j] = src0[i, j] + src1[i, j]` | +| `pto.tsub` | `dst[i, j] = src0[i, j] - src1[i, j]` | +| `pto.tmul` | `dst[i, j] = src0[i, j] * src1[i, j]` | +| `pto.tdiv` | `dst[i, j] = src0[i, j] / src1[i, j]` | +| `pto.tmax` | `dst[i, j] = max(src0[i, j], src1[i, j])` | +| `pto.tmin` | `dst[i, j] = min(src0[i, j], src1[i, j])` | + +#### Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.tile_buf` | First source tile buffer. | +| `src1` | `pto.tile_buf` | Second source tile buffer. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0`, `src1`, and `dst` must be shape-compatible tile buffers on `loc=vec`. +- The valid region must match across all three tiles. +- Element type legality is target-defined; ops specialize over the tile dtype selected at expansion time. +- `pto.tdiv` uses element-wise division; **undefined behavior** on divide-by-zero. + +**Example:** + +```mlir +pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +### 5.2 Tile-Scalar Arithmetic + +Tile-scalar families: + +| Op | Supported operand form(s) | Semantics | +|----|---------------------------|-----------| +| `pto.tadds` | `tile, scalar` | `dst[i, j] = src[i, j] + scalar` | +| `pto.tsubs` | `tile, scalar` | `dst[i, j] = src[i, j] - scalar` | +| `pto.tmuls` | `tile, scalar` | `dst[i, j] = src[i, j] * scalar` | +| `pto.tdivs` | `tile, scalar` and `scalar, tile` | `dst = src / scalar` or `dst = scalar / src` | +| `pto.tmaxs` | `tile, scalar` | `dst[i, j] = max(src[i, j], scalar)` | +| `pto.tmins` | `tile, scalar` | `dst[i, j] = min(src[i, j], scalar)` | + +#### Common Syntax + +For `pto.tadds`, `pto.tsubs`, `pto.tmuls`, `pto.tmaxs`, and `pto.tmins`: + +```mlir +pto. ins(%src, %scalar : !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) +``` + +For `pto.tdivs`: + +```mlir +pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) + +pto.tdivs ins(%scalar, %src : , !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer. | +| `scalar` | signless integer / floating-point scalar | Scalar broadcast across the tile. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must be shape-compatible `loc=vec` tile buffers. +- The scalar element type must be compatible with the tile element type. +- `pto.tdivs` is the only scalar family with two public operand orders. **Undefined behavior** on divide-by-zero (either `scalar==0` or any `src[i,j]==0` in the `scalar/src` form). + +**Example:** + +```mlir +pto.tadds ins(%a, %s : !pto.tile_buf, f32) + outs(%c : !pto.tile_buf) +``` + +```mlir +pto.tdivs ins(%s, %a : f32, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +### 5.3 Unary Math + +All ops below share the common form: + +```mlir +pto. ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics | +|----|-----------| +| `pto.tabs` | `dst = abs(src)` | +| `pto.tneg` | `dst = -src` | +| `pto.texp` | `dst = exp(src)` | +| `pto.tlog` | `dst = ln(src)` | +| `pto.tsqrt` | `dst = sqrt(src)` | +| `pto.trsqrt` | `dst = 1 / sqrt(src)` | +| `pto.trecip` | `dst = 1 / src` | + +**Constraints:** + +- `src` and `dst` must have the same valid region. +- These ops are numeric Tile Instruction ops on `loc=vec`. +- **Undefined behavior** on out-of-domain inputs: `tlog(<=0)`, `tsqrt(<0)`, `trsqrt(<=0)`, `trecip(0)`. + +**Example:** + +```mlir +pto.tabs ins(%a : !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +### 5.4 Activation Operations + +Activation family: + +| Op | Semantics | +|----|-----------| +| `pto.trelu` | `dst[i, j] = max(0, src[i, j])` | +| `pto.tlrelu` | `dst[i, j] = src[i, j] > 0 ? src[i, j] : slope * src[i, j]` | + +#### Common Forms + +ReLU: + +```mlir +pto.trelu ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +Leaky ReLU: + +```mlir +pto.tlrelu ins(%src, %slope : !pto.tile_buf<...>, f32) + outs(%dst : !pto.tile_buf<...>) +``` + +**Constraints:** + +- `src` and `dst` must have the same valid region. +- `pto.trelu` supports `f16`, `f32`, and `i32`. +- `pto.tlrelu` supports `f16` and `f32`, with the slope passed as an `f32` scalar operand. +- Both ops execute on `loc=vec` tiles via the vector pipeline. + +**Example:** + +```mlir +pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + + + +## 6. Reduction Operations + +> **Category:** Tile-local VEC reductions +> **Pipeline:** PIPE_V + +This chapter documents the TileLib reduction families. These ops reduce one or more source dimensions into smaller destination tiles and are organized into row-reduction and column-reduction groups. + +--- + +### 6.1 Row Reductions + +Row reductions reduce each row of `%src` into one element stored at `%dst[row, 0]`. The op shape carries a scratch tile operand `%tmp` to keep the operand list aligned with the A2/A3 PTO instruction interface (see [§1.7](#17-scratch-operands-and-a2a3-compatibility)). + +#### Common Syntax + +```mlir +pto. ins(%src, %tmp : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics | +|----|-----------| +| `pto.trowsum` | `dst[i, 0] = sum_j src[i, j]` | +| `pto.trowprod` | `dst[i, 0] = prod_j src[i, j]` | +| `pto.trowmax` | `dst[i, 0] = max_j src[i, j]` | +| `pto.trowmin` | `dst[i, 0] = min_j src[i, j]` | +| `pto.trowargmax` | `dst[i, 0] = argmax_j src[i, j]` | +| `pto.trowargmin` | `dst[i, 0] = argmin_j src[i, j]` | + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer. | +| `tmp` | `pto.tile_buf` | Scratch tile for A2/A3 interface compatibility. | +| `dst` | `pto.tile_buf` | Destination tile storing one result per source row. | + +**Constraints:** + +- `dst.v_row` should match `src.v_row`. +- `dst.v_col` should be `1`. +- `pto.trowargmax` and `pto.trowargmin` require an integer destination element type for the row-local index result. +- Numeric widening / narrowing inside the reduction is target-defined by the selected template (e.g. `trowsum` may widen `i16` accumulation internally before storing to `dst`). + +**Example:** + +```mlir +pto.trowsum ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +### 6.2 Column Reductions + +Column reductions reduce each column of `%src` into one element stored at `%dst[0, col]`. + +#### Common Syntax + +```mlir +pto. ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics | +|----|-----------| +| `pto.tcolsum` | `dst[0, j] = sum_i src[i, j]` | +| `pto.tcolprod` | `dst[0, j] = prod_i src[i, j]` | +| `pto.tcolmax` | `dst[0, j] = max_i src[i, j]` | +| `pto.tcolmin` | `dst[0, j] = min_i src[i, j]` | + +**Constraints:** + +- `dst.v_row` should be `1`. +- `dst.v_col` should match `src.v_col`. +- Templates assume prefix-aligned valid regions and row-major VEC tiles. + +**Example:** + +```mlir +pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + + + +## 7. Partial Elementwise Operations + +> **Category:** Tile-local VEC partial-shape compute +> **Pipeline:** PIPE_V + +This chapter documents the TileLib partial elementwise families. These ops combine two tiles whose valid regions may differ, but whose overlap starts at `[0, 0]`. + +--- + +### Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +| Op | Semantics on the overlap region | +|----|----------------------------------| +| `pto.tpartadd` | `dst = src0 + src1` | +| `pto.tpartmul` | `dst = src0 * src1` | +| `pto.tpartmax` | `dst = max(src0, src1)` | +| `pto.tpartmin` | `dst = min(src0, src1)` | + +**Constraints:** + +- Let `big` ∈ {`src0`, `src1`} be the operand whose valid shape equals `dst.valid_shape`, and `small` be the other operand. Exactly one operand plays each role. +- `small.valid_shape` must be a prefix-aligned sub-rectangle of `dst.valid_shape` (i.e. starting at `[0, 0]`). +- For `pto.tpartadd` and `pto.tpartmul`: outside the overlap (where only `big` covers `dst`), `dst` takes `big`'s value. +- For `pto.tpartmax` and `pto.tpartmin`: A5 templates initialize `dst` with the dtype extremum before merging the operands, so uncovered regions follow the template's pad-extremum behavior. + +**Example:** + +```mlir +pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + + + +## 8. Bitwise and Shift Operations + +> **Category:** Tile-local integer VEC compute +> **Pipeline:** PIPE_V + +This chapter documents the integer-only TileLib bitwise and shift families. + +--- + +### 8.1 Unary Bitwise NOT: `pto.tnot` + +- **syntax:** +```mlir +pto.tnot ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst = ~src`. + +**Constraints:** + +- Tile element types must be integer types. +- `src` and `dst` must have the same valid region. + +**Example:** + +```mlir +pto.tnot ins(%a : !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +### 8.2 Binary Tile-Tile Bitwise and Shift Families + +Tile-tile bitwise and shift families: + +| Op | Semantics | +|----|-----------| +| `pto.tand` | `dst = src0 & src1` | +| `pto.tor` | `dst = src0 \| src1` | +| `pto.txor` | `dst = src0 ^ src1` | +| `pto.tshl` | `dst = src0 << src1` | +| `pto.tshr` | `dst = src0 >> src1` | + +#### Common Forms + +For `pto.tand`, `pto.tor`, `pto.tshl`, and `pto.tshr`: + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +`pto.txor` carries an extra scratch tile `%tmp` for A2/A3 interface compatibility (see [§1.7](#17-scratch-operands-and-a2a3-compatibility)): + +```mlir +pto.txor ins(%src0, %src1, %tmp : !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Constraints:** + +- Tile element types must be integer types. +- `src0`, `src1`, and `dst` must have the same valid region. + +**Example:** + +```mlir +pto.tand ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%c : !pto.tile_buf) +``` + +--- + +### 8.3 Tile-Scalar Bitwise and Shift Families + +Tile-scalar bitwise and shift families: + +| Op | Semantics | +|----|-----------| +| `pto.tands` | `dst = src & scalar` | +| `pto.tors` | `dst = src \| scalar` | +| `pto.txors` | `dst = src ^ scalar` | +| `pto.tshls` | `dst = src << scalar` | +| `pto.tshrs` | `dst = src >> scalar` | + +#### Common Forms + +For `pto.tands`, `pto.tors`, `pto.tshls`, and `pto.tshrs`: + +```mlir +pto. ins(%src, %scalar : !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) +``` + +`pto.txors` carries an extra scratch tile `%tmp` for A2/A3 interface compatibility: + +```mlir +pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf<...>, , + !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Constraints:** + +- Tile element types must be integer types. +- `src` and `dst` must have the same valid region. +- The scalar operand must be an integer-compatible shift / bitwise scalar. + +**Example:** + +```mlir +pto.tands ins(%a, %s : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) +``` + + + +## 9. Type Conversion + +> **Category:** Element-wise type conversion +> **Pipeline:** PIPE_V + +This chapter documents the element-wise tile conversion instruction `pto.tcvt` and the rounding modes it uses. + +--- + +### `RoundMode` + +Rounding modes for `pto.tcvt`. + +| Value | Int | Description | +|-------|-----|-------------| +| `NONE` | 0 | No rounding. | +| `RINT` | 1 | Round to nearest integer. | +| `ROUND` | 2 | Round `f16` away from zero. | +| `FLOOR` | 3 | Round toward negative infinity. | +| `CEIL` | 4 | Round toward positive infinity. | +| `TRUNC` | 5 | Truncate toward zero. | +| `ODD` | 6 | Round to odd. | +| `CAST_RINT` | 7 | Cast with round-to-nearest (default). | + +**Attribute syntax:** `#pto` + +--- + +### `pto.tcvt` + +- **syntax:** +```mlir +pto.tcvt ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) + {rmode = #pto} +``` +- **semantics:** `dst[i, j] = cast(src[i, j], rmode)` element-wise. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile. | +| `dst` | `pto.tile_buf` | Destination tile (different element type). | +| `rmode` | `RoundModeAttr` | Default `CAST_RINT`. | + +**Constraints:** + +- `src`/`dst` must be shape/valid-region compatible. +- This reference does not define extra legality rules for the `(src, dst)` type pair. **Undefined behavior** on conversion pairs not supported by the target hardware; consult the A2/A3 and A5 hardware specs for legal pairs. + +**Example:** + +```mlir +pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + {rmode = #pto} +``` + + + +## 10. Broadcast and Expansion Operations + +> **Category:** Tile-local VEC broadcast and expansion compute +> **Pipeline:** PIPE_V + +This chapter documents the TileLib broadcast, row-expansion, and column-expansion families. These ops populate destination tiles by broadcasting one logical scalar across a larger region — either from a standalone scalar operand, one source value per destination row, or one source value per destination column. + +--- + +### 10.1 Scalar Broadcast: `pto.texpands` + +- **syntax:** +```mlir +pto.texpands ins(%scalar : ) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[i, j] = scalar` for every element inside `dst`'s valid region. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | signless integer / floating-point scalar | Scalar value broadcast into the destination tile. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- The TileLib template is VEC-oriented and fills `dst.valid_shape`. +- The scalar type must be compatible with `dst.dtype`. + +**Example:** + +```mlir +pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) +``` + +--- + +### 10.2 Row-Wise Broadcast: `pto.trowexpand` + +- **syntax:** +```mlir +pto.trowexpand ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[row, col] = src[row, 0]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile carrying one logical scalar per destination row. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must have the same number of valid rows. +- `src` must encode exactly one logical source value per destination row. +- Templates target row-major VEC layouts. + +**Example:** + +```mlir +pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +### 10.3 Row-Wise Broadcast Arithmetic and Transform Families + +The row-expansion family combines a full tile `%src0` with a per-row scalar carrier `%src1`: + +| Op | Semantics | +|----|-----------| +| `pto.trowexpandadd` | `dst[row, col] = src0[row, col] + src1[row, 0]` | +| `pto.trowexpandsub` | `dst[row, col] = src0[row, col] - src1[row, 0]` | +| `pto.trowexpandmul` | `dst[row, col] = src0[row, col] * src1[row, 0]` | +| `pto.trowexpanddiv` | `dst[row, col] = src0[row, col] / src1[row, 0]` | +| `pto.trowexpandmax` | `dst[row, col] = max(src0[row, col], src1[row, 0])` | +| `pto.trowexpandmin` | `dst[row, col] = min(src0[row, col], src1[row, 0])` | +| `pto.trowexpandexpdif` | `dst[row, col] = exp(src0[row, col] - src1[row, 0])` | + +#### Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.tile_buf` | Main source tile. | +| `src1` | `pto.tile_buf` | Tile carrying one logical scalar per destination row. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0` and `dst` must be shape/valid-region compatible. +- `src1` must provide one logical scalar per destination row. +- Templates target row-major VEC layouts. +- `pto.trowexpanddiv` and `pto.trowexpandexpdif` are floating-point-only. + +**Example:** + +```mlir +pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +### 10.4 Column-Wise Broadcast: `pto.tcolexpand` + +- **syntax:** +```mlir +pto.tcolexpand ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[row, col] = src[0, col]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile carrying one logical scalar per destination column. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must have the same number of valid columns. +- `src` must encode exactly one logical source value per destination column. +- Templates target row-major VEC layouts. + +**Example:** + +```mlir +pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +### 10.5 Column-Wise Broadcast Arithmetic and Transform Families + +The column-expansion family combines a full tile `%src0` with a per-column scalar carrier `%src1`: + +| Op | Semantics | +|----|-----------| +| `pto.tcolexpandadd` | `dst[row, col] = src0[row, col] + src1[0, col]` | +| `pto.tcolexpandsub` | `dst[row, col] = src0[row, col] - src1[0, col]` | +| `pto.tcolexpandmul` | `dst[row, col] = src0[row, col] * src1[0, col]` | +| `pto.tcolexpanddiv` | `dst[row, col] = src0[row, col] / src1[0, col]` | +| `pto.tcolexpandmax` | `dst[row, col] = max(src0[row, col], src1[0, col])` | +| `pto.tcolexpandmin` | `dst[row, col] = min(src0[row, col], src1[0, col])` | +| `pto.tcolexpandexpdif` | `dst[row, col] = exp(src0[row, col] - src1[0, col])` | + +#### Common Syntax + +```mlir +pto. ins(%src0, %src1 : !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.tile_buf` | Main source tile. | +| `src1` | `pto.tile_buf` | Tile carrying one logical scalar per destination column. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0` and `dst` must be shape/valid-region compatible. +- `src1` must provide one logical scalar per destination column. +- Templates target row-major VEC layouts. +- `pto.tcolexpanddiv` and `pto.tcolexpandexpdif` are floating-point-only. + +**Example:** + +```mlir +pto.tcolexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + + + +## 11. Selection Operations + +> **Category:** Tile-local VEC selection compute +> **Pipeline:** PIPE_V + +This chapter documents the TileLib selection families. These ops select between data sources under control of a packed predicate-mask tile. + +The mask tile carries packed predicate bytes in UB. Templates load predicate bits directly with predicate-load helpers such as `plds`, then use `vsel` to choose the data path. + +`pto.tsel` and `pto.tsels` carry an extra `%tmp` operand for A2/A3 interface compatibility (see [§1.7](#17-scratch-operands-and-a2a3-compatibility)). + +--- + +### 11.1 `pto.tsel` + +- **syntax:** +```mlir +pto.tsel ins(%mask, %src0, %src1, %tmp : + !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>, !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[i, j] = mask[i, j] ? src0[i, j] : src1[i, j]`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.tile_buf` | Packed predicate-mask carrier. | +| `src0` | `pto.tile_buf` | Value selected when the predicate bit is true. | +| `src1` | `pto.tile_buf` | Value selected when the predicate bit is false. | +| `tmp` | `pto.tile_buf` | Scratch tile for A2/A3 interface compatibility. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src0`, `src1`, and `dst` must have the same shape and valid region. +- The `tsel` template specializes the mask carrier as an `i8` tile with packed predicate bytes. + +**Example:** + +```mlir +pto.tsel ins(%mask, %a, %b, %tmp : + !pto.tile_buf, !pto.tile_buf, + !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +### 11.2 `pto.tsels` + +- **syntax:** +```mlir +pto.tsels ins(%mask, %src, %tmp, %scalar : + !pto.tile_buf<...>, !pto.tile_buf<...>, + !pto.tile_buf<...>, ) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** `dst[i, j] = mask[i, j] ? src[i, j] : scalar`. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `pto.tile_buf` | Packed predicate-mask carrier. | +| `src` | `pto.tile_buf` | Source tile selected when the predicate bit is true. | +| `tmp` | `pto.tile_buf` | Scratch tile for A2/A3 interface compatibility. | +| `scalar` | signless integer / floating-point scalar | Scalar selected when the predicate bit is false. | +| `dst` | `pto.tile_buf` | Destination tile buffer. | + +**Constraints:** + +- `src` and `dst` must have the same shape and valid region. +- `tsels` accepts packed-mask carrier tiles with `i8`, `i16`, or `i32` element types. + +**Example:** + +```mlir +pto.tsels ins(%mask, %src, %tmp, %scalar : + !pto.tile_buf, !pto.tile_buf, + !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) +``` + + + +## 12. Fill and Padding Operations + +> **Category:** Tile-local fill, pad, and expansion materialization +> **Pipeline:** PIPE_V + +This chapter documents the TileLib fill / padding families. These ops preserve or materialize valid data and then synthesize the remaining destination region from the destination tile's padding policy. + +The destination tile's `pad` / `pad_value` configuration determines which value is written into the synthesized padding or expansion region. + +--- + +### 12.1 `pto.tfillpad` + +- **syntax:** +```mlir +pto.tfillpad ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** copy valid data from `src` into `dst`, then fill the remaining destination region according to `dst`'s pad policy. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile. | +| `dst` | `pto.tile_buf` | Destination tile carrying the pad configuration. | + +**Constraints:** + +- Source and destination element types must be compatible. +- The destination tile must carry a meaningful pad configuration. +- This family is VEC-oriented. + +**Example:** + +```mlir +pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +### 12.2 `pto.tfillpad_expand` + +- **syntax:** +```mlir +pto.tfillpad_expand ins(%src : !pto.tile_buf<...>) + outs(%dst : !pto.tile_buf<...>) +``` +- **semantics:** copy valid data from `src` into `dst`, then fill row/column expansion according to `dst`'s pad policy when the destination valid region or backing shape is larger than the source. + +**Parameter Table:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.tile_buf` | Source tile. | +| `dst` | `pto.tile_buf` | Larger destination tile carrying the pad configuration. | + +**Constraints:** + +- `dst` may be larger than `src` in valid region or physical shape. +- The fill value is derived from `dst.pad_value`. +- A unified VEC-oriented template handles the supported element families. + +**Example:** + +```mlir +pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` diff --git a/docs/release/vpto-spec-v0.1.md b/docs/release/vpto-spec-v0.1.md new file mode 100644 index 000000000..a2949a485 --- /dev/null +++ b/docs/release/vpto-spec-v0.1.md @@ -0,0 +1,4885 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +It is not a dedicated `pto` op. In the PTO micro Instruction, this scope is modeled as a specialized `scf.for` loop annotated with `llvm.loop.aivector_scope`. This gives the compiler a natural structural boundary for identifying the code block that must be lowered into a discrete VF hardware instruction sequence. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +scf.for %dummy = %c0 to %c1 step %c1 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} +``` + +### Example: Abs + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +scf.for %dummy = %c0 to %c1 step %c1 { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} {llvm.loop.aivector_scope} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Core Types + +#### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +#### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +#### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +#### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +scf.for %arg2 = %c0 to %c1 step %c1 { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} {llvm.loop.aivector_scope} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +#### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit), not an integer vector. + +**Mask Granularity:** + +The mask is 256 bits in length, where each bit controls 1 byte of data. This means mask granularity varies by element type: + +| Element Type | Bits/Element | Mask Bits per Element | +|--------------|--------------|----------------------| +| `f32`/`i32` | 32 | 4 bits | +| `f16`/`bf16`/`i16` | 16 | 2 bits | +| `f8`/`i8` | 8 | 1 bit | + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 8 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +scf.for %dummy = %c0 to %c1 step %c1 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} {llvm.loop.aivector_scope} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | +| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | None | +| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + +--- + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + +--- + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. + +**Distribution modes:** + +| Mode | Description | C Semantics | +|------|-------------|-------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | +| `BRC_B8/B16/B32` | Broadcast single element | `dst[i] = UB[base]` for all i | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | +| `BLK` | Block load | Blocked access pattern | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. + +**Distribution modes:** + +| Mode | Description | C Semantics | +|------|-------------|-------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. + +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This op updates only the residual alignment state. A matching flush op is + still required to emit the trailing bytes. + +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Flush alignment state with scalar offset. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstu` + +- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index` +- **semantics:** Unaligned store with align + offset state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset_in` is the current + logical byte/element displacement, `%value` is the vector being stored, and + `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated alignment/tail state and `%offset_out` is the + next offset after applying the selected post-update rule. +- **constraints and limitations:** + The alignment state MUST be threaded in program order. A terminating flush + form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered + tail bytes. + +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. + +--- + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. The + current PTO micro Instruction representation models that selector as an attribute rather than a + separate operand. + +```c +for (int i = 0; i < N; i++) + dst[i] = input_scalar_or_element; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate predicate from pattern. + +**Patterns:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate tail mask — first N lanes active. + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate predicate state together with updated scalar state. +``` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +**Part tokens:** `LOWER`, `HIGHER` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. + +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vrsqrt` + +- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds reciprocal-square-root values per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. +- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Inactive lanes follow the predication + behavior defined for this family. On the current surface, inactive lanes are + treated as zeroing lanes. + +--- + +##### `pto.vsubs` + +- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] - scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Integer or floating-point legality depends on + the selected type family in lowering. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common numeric cases. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + +--- + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + dst[i] = convert(src[i], T0, T1, round_mode); +``` + +- **inputs:** + `%input` is the source vector; attributes select rounding, saturation, and + even/odd placement when the conversion changes width. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. `PART_EVEN` / + `PART_ODD` is only meaningful for width-changing forms that pack two source + streams into one destination register. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "ROUND_MODE" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], round_mode); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `ROUND_MODE` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `ROUND_MODE` must be one of `ROUND_R`, `ROUND_A`, `ROUND_F`, + `ROUND_C`, or `ROUND_Z`. `BW` must match the element width: `b16` for + `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "ROUND_R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"} + : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "ROUND_F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +``` + +--- + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Slide / Shift + +##### `pto.vslide` + +- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Concatenate two vectors and extract N-element window at offset. + +```c +// Conceptually: tmp[0..2N-1] = {src1, src0} +// dst[i] = tmp[amt + i] +if (amt >= 0) + for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i]; +``` + +**Use case:** Sliding window operations, shift register patterns. + +- **inputs:** `%src0` and `%src1` provide the concatenated source window and + `%amt` selects the extraction offset. +- **outputs:** `%result` is the extracted destination window. +- **constraints and limitations:** `pto.vslide` operates on the logical + concatenation of `%src1` and `%src0`. The source order and extraction offset + MUST be preserved exactly. + +--- + +##### `pto.vshift` + +- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Single-source slide (shift with zero fill). + +```c +for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src[i - amt] : 0; +``` + +- **inputs:** `%src` is the source vector and `%amt` is the slide amount. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** This surface represents the single-source + slide/shift family. Zero-fill versus other fill behavior MUST match the + selected form. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg` +- **semantics:** Expand — scatter front elements to active positions. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. + +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +--- + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. + +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + +--- + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + +--- + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | also used for `__VEC_SCOPE__` dummy-loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- the `__VEC_SCOPE__` contract in PTO micro Instruction is modeled as a specialized `scf.for` annotated with `llvm.loop.aivector_scope` +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +--- + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/release/vpto-spec-v0.2.md b/docs/release/vpto-spec-v0.2.md new file mode 100644 index 000000000..90b632a14 --- /dev/null +++ b/docs/release/vpto-spec-v0.2.md @@ -0,0 +1,5074 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +- `vreg`: `!pto.vreg` + Fixed-width VPTO vector type with total width exactly 256 bytes. +- `mask`: `!pto.mask` + Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. +- `align`: `!pto.align` +- `buf`: buffer-like LLVM pointer type accepted by the dialect +- `buf_like`: `memref<...>` or `!llvm.ptr` for stateless/predicate + `vld*/vst*` families +- `idx`: `index` +- `i32`: `i32` +- `i64`: `i64` + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 8 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | +| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | None | +| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV_B32` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B32` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_*`** on **`RV_VSTI`** are **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV_B32` | `RV_VLDI` | **9** | +| `DINTLV_B16` | `RV_VLDI` | **9** | +| `DINTLV_B8` | `RV_VLDI` | **9** | +| `BRC_B32` | `RV_VLD` | **9** | +| `BRC_B8` | `RV_VLD` | **9** | +| `BRC_B16` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV_B32` | `RV_VSTI` | **12** | +| `INTLV_B16` | `RV_VSTI` | **12** | +| `INTLV_B8` | `RV_VSTI` | **12** | +| `UNPK_B8` | `RV_VLD` | **9** | +| `UNPK_B16` | `RV_VLD` | **9** | +| `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B32` | `RV_VSTI` | **9** | +| `NORM_B16` | `RV_VSTI` | **9** | +| `NORM_B8` | `RV_VSTI` | **9** | +| `PK_B32` | `RV_VSTI` | **9** | +| `PK_B16` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK_B8`, `UNPK_B16`, `UNPK_B32` | **9** cycles | +| `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `DINTLV_B16`, `DINTLV_B8` | **9** cycles (same `RV_VLDI` + `dist:DINTLV_*` path as `DINTLV_B32`) | +| `BRC_B32` | **9** cycles | +| `BRC_B8`, `BRC_B16` | **9** cycles (`RV_VLD`) | +| `BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US_*`, `DS_*`, `SPLT*` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM_B8`, `NORM_B16`, `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16`, `PK_B32` | **9** cycles | +| `INTLV_B32` (`pto.vstx2`) | **12** cycles | +| `INTLV_B16`, `INTLV_B8` | **12** cycles (same interleave store path as `INTLV_B32`) | +| `MRG4CHN_B8`, `MRG2CHN_*` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC_B32` | Broadcast single element | `dst[i] = UB[base]` for all i | **9** cycles | +| `BRC_B8`, `BRC_B16` | Broadcast first lane element | Same idea at B8/B16 width | **9** cycles | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | **9** cycles | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | **9** cycles | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | **9** cycles | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | **9** cycles | +| `DINTLV_B16`, `DINTLV_B8` | Deinterleave 16-bit / 8-bit | Pair lanes from interleaved UB | **9** cycles | +| `BDINTLV` | Block deinterleave | (see PTO headers for exact tiling) | **9** cycles | +| `BLK` | Block load | Blocked / tiled access pattern (see PTO headers) | **9** cycles (`dist:BRC_BLK` on `RV_VLD`) | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. +- **Latency:** **`DINTLV_B32` → 9** cycles on `RV_VLDI`. **`DINTLV_B16` / `DINTLV_B8` → 9** cycles on `RV_VLDI`. **`BDINTLV` → 9** cycles on `RV_VLDI`. + +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. +- **Latency:** **9** cycles. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. +- **Latency:** **9** cycles. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | **9** cycles | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | **9** cycles | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | **9** cycles | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. +- **Latency:** **`INTLV_B32` / `INTLV_B16` / `INTLV_B8` → 12** cycles on `RV_VSTI`. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. +- **Latency:** **9** cycles. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. +- **Latency:** **9** cycles. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This op updates only the residual alignment state. A matching flush op is + still required to emit the trailing bytes. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstu` + +- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index` +- **semantics:** Unaligned store with align + offset state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset_in` is the current + logical byte/element displacement, `%value` is the vector being stored, and + `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated alignment/tail state and `%offset_out` is the + next offset after applying the selected post-update rule. +- **constraints and limitations:** + The alignment state MUST be threaded in program order. A terminating flush + form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered + tail bytes. +- **Latency:** **9** cycles. + +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is paired with `f32` +vector compares or selects. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. The + current PTO micro Instruction representation models that selector as an attribute rather than a + separate operand. + +```c +for (int i = 0; i < N; i++) + dst[i] = input_scalar_or_element; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate predicate from pattern. + +**Patterns:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate tail mask — first N lanes active. + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate predicate state together with updated scalar state. + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +**Part tokens:** `LOWER`, `HIGHER` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. + +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrsqrt` | `RV_VSQRT` / `RV_VDIV` | **17** / **17** | **22** / **22** | — | +| `pto.vrec` | `RV_VDIV` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vrsqrt` + +- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds reciprocal-square-root values per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. +- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Inactive lanes follow the predication + behavior defined for this family. On the current surface, inactive lanes are + treated as zeroing lanes. + +--- + +##### `pto.vsubs` + +- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] - scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Integer or floating-point legality depends on + the selected type family in lowering. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common numeric cases. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + dst[i] = convert(src[i], T0, T1, round_mode); +``` + +- **inputs:** + `%input` is the source vector; attributes select rounding, saturation, and + even/odd placement when the conversion changes width. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. `PART_EVEN` / + `PART_ODD` is only meaningful for width-changing forms that pack two source + streams into one destination register. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "ROUND_MODE" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], round_mode); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `ROUND_MODE` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `ROUND_MODE` must be one of `ROUND_R`, `ROUND_A`, `ROUND_F`, + `ROUND_C`, or `ROUND_Z`. `BW` must match the element width: `b16` for + `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "ROUND_R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"} + : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "ROUND_F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Slide / Shift + +##### `pto.vslide` + +- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Concatenate two vectors and extract N-element window at offset. + +```c +// Conceptually: tmp[0..2N-1] = {src1, src0} +// dst[i] = tmp[amt + i] +if (amt >= 0) + for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i]; +``` + +**Use case:** Sliding window operations, shift register patterns. + +- **inputs:** `%src0` and `%src1` provide the concatenated source window and + `%amt` selects the extraction offset. +- **outputs:** `%result` is the extracted destination window. +- **constraints and limitations:** `pto.vslide` operates on the logical + concatenation of `%src1` and `%src0`. The source order and extraction offset + MUST be preserved exactly. + +--- + +##### `pto.vshift` + +- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Single-source slide (shift with zero fill). + +```c +for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src[i - amt] : 0; +``` + +- **inputs:** `%src` is the source vector and `%amt` is the slide amount. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** This surface represents the single-source + slide/shift family. Zero-fill versus other fill behavior MUST match the + selected form. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg` +- **semantics:** Expand — scatter front elements to active positions. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. + +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. + +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/release/vpto-spec-v0.3.md b/docs/release/vpto-spec-v0.3.md new file mode 100644 index 000000000..8de281795 --- /dev/null +++ b/docs/release/vpto-spec-v0.3.md @@ -0,0 +1,5349 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdiff`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | +| `INTLV` | `b8`, `b16`, `b32` | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +#### Movement + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +##### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdiff` + +- **syntax:** `%result = pto.vexpdiff %input, %max, "EVEN|ODD" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdiff %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdiff %logits, %max_bc, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/docs/sample.pto b/docs/sample.pto new file mode 100644 index 000000000..956b7ba4c --- /dev/null +++ b/docs/sample.pto @@ -0,0 +1,57 @@ +module attributes {pto.target_arch = "a5"} { + func.func @abs_kernel_2d(%arg0: memref, %arg1: memref) { + %c4096_i64 = arith.constant 4096 : i64 + %c0_i64 = arith.constant 0 : i64 + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [%c32, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref to memref> + %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [%c32, %c32], strides: [%c32, %c1] {layout = #pto.layout} : memref to memref> + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %0 = builtin.unrealized_conversion_cast %memspacecast : memref to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %1 = llvm.extractvalue %0[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %2 = llvm.inttoptr %c0_i64 : i64 to !llvm.ptr<6> + %3 = arith.index_castui %c32 : index to i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c4_i64 = arith.constant 4 : i64 + %4 = arith.muli %3, %c32_i64 : i64 + %5 = arith.muli %c1_i64, %4 : i64 + %6 = arith.muli %5, %c4_i64 : i64 + %7 = arith.muli %4, %c4_i64 : i64 + %8 = arith.muli %3, %c4_i64 : i64 + %c128_i64 = arith.constant 128 : i64 + %9 = llvm.getelementptr %1[%c0_i64] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i8 + a5vm.set_loop2_stride_outtoub %6, %c4096_i64 : i64, i64 + a5vm.set_loop1_stride_outtoub %7, %c4096_i64 : i64, i64 + a5vm.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + a5vm.copy_gm_to_ubuf %9, %2, %3, %3, %c0_i64, %3, %8, %c0_i64, %c0_i64, %c0_i64, %c128_i64, %c128_i64 {a5vm.element_type = "u32", data_select_bit = false, layout = "nd", ub_pad = false} : !llvm.ptr<1>, !llvm.ptr<6>, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64 + a5vm.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + a5vm.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + %10 = llvm.inttoptr %c4096_i64 : i64 to !llvm.ptr<6> + %c0 = arith.constant 0 : index + %11 = arith.muli %c32, %c32 : index + %c64 = arith.constant 64 : index + %12 = arith.index_castui %11 : index to i32 + pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = a5vm.plt_b32 %arg4 : i32 -> !a5vm.mask, i32 + %17 = a5vm.vlds %2[%arg3] : !llvm.ptr<6> -> !a5vm.vreg<64xf32> + %18 = a5vm.vabs %17, %mask {mode = "MODE_ZEROING"} : !a5vm.vreg<64xf32>, !a5vm.mask -> !a5vm.vreg<64xf32> + a5vm.vsts %18, %10[%arg3], %mask : !a5vm.vreg<64xf32>, !llvm.ptr<6>, !a5vm.mask + scf.yield %scalar_out : i32 + } + } + a5vm.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + a5vm.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + %memspacecast_1 = memref.memory_space_cast %arg1 : memref to memref + %13 = builtin.unrealized_conversion_cast %memspacecast_1 : memref to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.extractvalue %13[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> + %15 = llvm.getelementptr %14[%c0_i64] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i8 + a5vm.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + a5vm.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + a5vm.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + a5vm.copy_ubuf_to_gm %10, %15, %3, %3, %c0_i64, %c32_i64, %8, %c0_i64, %c128_i64, %c128_i64 {a5vm.element_type = "u32", layout = "nd"} : !llvm.ptr<6>, !llvm.ptr<1>, i64, i64, i64, i64, i64, i64, i64, i64 + a5vm.pipe_barrier "PIPE_ALL" + return + } +} diff --git a/docs/tilelang-dsl-syntax-sugar-proposals.md b/docs/tilelang-dsl-syntax-sugar-proposals.md new file mode 100644 index 000000000..16661a43c --- /dev/null +++ b/docs/tilelang-dsl-syntax-sugar-proposals.md @@ -0,0 +1,404 @@ +# TileLang DSL Syntax Sugar Proposals + +## Overview + +This document proposes syntax sugar enhancements for the TileLang Python DSL to improve programming ergonomics while maintaining close correspondence with the underlying VPTO IR. The current DSL design closely mirrors VPTO instructions, which can lead to verbose and error-prone code. These proposals aim to provide higher-level abstractions that compile down to the existing VPTO operations. + +## Current Usability Challenges + +### 1. **Low-Level Pointer Operations** +```python +# Current: manual byte offset management +ub_in = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +ub_out = pto.castptr(4096, pto.ptr(pto.f32, MemorySpace.UB)) +next_ptr = pto.addptr(ub_ptr, 4096) +``` +**Problem**: Users must manage byte offsets and memory spaces manually. + +### 2. **Verbose Copy Operations** +The `pto.copy_ubuf_to_ubuf` / `pto.mte_ub_ub` operand contract is low-level: +- source pointer, destination pointer, `sid` +- `n_burst`, `len_burst`, `src_gap`, `dst_gap` + +**Problem**: Correctly setting burst and gap parameters is error-prone, especially for multi-dimensional data. + +### 3. **Precise Mask Type Matching** +```python +# Must ensure mask granularity matches element type +mask32 = pto.pset_b32("PAT_ALL") # f32 requires b32 mask +mask16 = pto.pset_b16("PAT_ALL") # f16 requires b16 mask +``` +**Problem**: Type error messages are not intuitive and easy to confuse. + +### 4. **Strict Vector Scope Requirements** +```python +# strict_vecscope requires explicit capture of all variables +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Can only use captured variables +``` +**Problem**: Increases boilerplate code, especially when multiple variables need capture. + +### 5. **Manual Synchronization Management** +```python +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` +**Problem**: Easy to forget synchronization or use wrong event IDs. + +### 6. **Byte Offsets vs. Element Indices** +```python +# Need to calculate byte offsets +vec = pto.vlds(ub_ptr, lane * 256) # Assuming f32, 4 bytes per element +``` +**Problem**: Users must understand underlying memory layout. + +## Proposed Syntax Sugar Enhancements + +### 1. **Array View Abstraction** + +#### Current API +```python +# Low-level pointer operations +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +vec = pto.vlds(ub_ptr, 64 * 4) # Load 64th f32 element +``` + +#### Proposed Syntax Sugar +```python +# Create array views +ub_array = pto.ub_array(256, pto.f32, base_offset=0) # 256-element f32 UB array +gm_array = pto.gm_array(1024, pto.f32, src) # GM pointer array view + +# Element access with automatic offset calculation +element = ub_array[64] # Get 64th element (auto-calculates byte offset) +slice = ub_array[128:256] # Slice operation + +# Array assignment (compiles to appropriate copy operations) +ub_array[0:64] = gm_array[0:64] # Compiles to copy_gm_to_ubuf + +# Multi-dimensional arrays +ub_2d = pto.ub_array((256, 128), pto.f32) # 2D array +row = ub_2d[32, :] # Row slice +col = ub_2d[:, 64] # Column slice +``` + +#### Implementation Notes +- `ub_array[64]` → `pto.vlds(ub_ptr, 64 * sizeof(f32))` +- `ub_array[0:64] = gm_array[0:64]` → Appropriate `copy_gm_to_ubuf` call with stride calculations +- Array views are compile-time constructs with no runtime overhead + +### 2. **Simplified Copy Operations** + +#### Current API +```python +pto.copy_gm_to_ubuf(src, dst, 0, 32, 128, 0, 0, False, 0, 128, 128) +``` + +#### Proposed Syntax Sugar +```python +# Full array copy +pto.copy_gm_to_ub(gm_array, ub_array) + +# Slice copy with automatic stride calculation +pto.copy_gm_to_ub(gm_array[0:64], ub_array[128:192]) + +# Copy with element count +pto.copy_gm_to_ub(gm_array, ub_array, count=64) + +# Transpose copy +pto.copy_gm_to_ub(gm_array, ub_array, transpose=True) + +# Multi-dimensional copy with automatic stride inference +pto.copy_gm_to_ub(gm_2d[0:32, :], ub_2d[:, 0:64]) + +# Chained operations +(pto.copy_gm_to_ub(gm_array, ub_array) + .then(pto.copy_ub_to_ub(ub_array, ub_temp)) + .then(pto.copy_ub_to_gm(ub_temp, dst_array))) +``` + +### 3. **Automatic Mask Inference** + +#### Current API +```python +# Must specify mask type explicitly +mask32 = pto.pset_b32("PAT_ALL") +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) +``` + +#### Proposed Syntax Sugar +```python +# Automatic mask type inference +mask = pto.pset("PAT_ALL") # Inferred as mask_b32 for f32 vectors +out = pto.vabs(vec_f32, mask) # Type-safe, auto-matched + +# Vector method syntax (more Pythonic) +out = vec_f32.abs(mask="PAT_ALL") +out = vec_f32.add(other_vec, mask=pto.pset("PAT_EVEN")) +out = vec_f32.max(scalar, mask="PAT_ALL") + +# Mask creation from comparison +mask = vec_f32 >= pto.f32(0.0) # Creates appropriate mask_b32 +mask = vec_f32 < threshold # Auto-infers mask type + +# Mask operations with auto-typing +combined = mask1 & mask2 # Bitwise AND with type preservation +inverted = ~mask # Logical NOT +``` + +### 4. **Simplified Synchronization Primitives** + +#### Current API +```python +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +# ... computation ... +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### Proposed Syntax Sugar +```python +# Context manager for automatic synchronization +with pto.sync_between(PIPE.MTE2, PIPE.V, event=EVENT.ID0): + # set_flag called on entry, wait_flag on exit + pto.copy_gm_to_ub(src, dst) + compute_block() + +# Decorator for function-level synchronization +@pto.synchronized(from_pipe=PIPE.MTE2, to_pipe=PIPE.V) +def compute_block(): + # Automatic synchronization before and after + pass + +# Pipeline synchronization chain +with pto.pipeline([ + (PIPE.MTE2, PIPE.V, EVENT.ID0), + (PIPE.V, PIPE.MTE3, EVENT.ID1), + (PIPE.MTE3, PIPE.S, EVENT.ID2) +]): + # Multi-stage synchronization + stage1() + stage2() + stage3() +``` + +### 5. **Element-Level Indexing Operations** + +#### Current API +```python +# Byte offset calculation required +vec = pto.vlds(ub_ptr, lane * 256) # Need to know f32 is 4 bytes +``` + +#### Proposed Syntax Sugar +```python +# Element-level indexing +vec = pto.vlde(ub_array, lane) # Automatic byte offset calculation +pto.vste(vec, ub_array, lane) # Element-level store + +# Array view methods +vec = ub_array.load_element(lane) +ub_array.store_element(lane, vec) + +# Batch operations +vectors = ub_array.load_elements([0, 64, 128, 192]) +ub_array.store_elements([256, 320, 384], vectors) + +# Strided access +stride = ub_array.load_stride(start=0, end=1024, step=64) +``` + +### 6. **Type Inference Simplification** + +#### Current API +```python +# Explicit type annotations required +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + +#### Proposed Syntax Sugar +```python +# Automatic type inference for constants +remaining = pto.constant(1024) # Inferred as i32 or i64 from context +step = pto.constant(64, type=pto.i32) # Explicit type specification + +# Typed range with automatic inference +for i in pto.range(0, 1024, 64): # i automatically gets correct machine type + # i is pto.i32 + +# Function argument type inference +@pto.vkernel +def kernel(x): # Type inferred from usage + return x * pto.constant(2) # x type inferred from multiplication + +# Variable type inference from operations +result = pto.constant(10) + pto.constant(20) # result is pto.i32 +``` + +### 7. **More Flexible Vector Scopes** + +#### Current API +```python +# Explicit capture required +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + for i in range(lb, ub, step): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, mask) +``` + +#### Proposed Syntax Sugar +```python +# Automatic variable capture +with pto.vector_scope(): + # Variables used in scope are automatically captured + for i in pto.range(start, end, step): + vec = src_array.load_element(i) + dst_array.store_element(i, vec.abs()) + +# Decorator for vectorized functions +@pto.vectorize +def compute_element(src, dst, index): + vec = src.load_element(index) + dst.store_element(index, vec.abs()) + +# Apply vectorized function across range +pto.vector_map(compute_element, src_array, dst_array, range(0, 1024, 64)) + +# Lambda support +pto.vector_map(lambda x: x.abs(), src_array, dst_array) +``` + +### 8. **Built-in Utility Functions** + +#### Common Pattern Encapsulation +```python +# Vector map/reduce operations +result = pto.vector_map(abs, src_array, dst_array) # Element-wise mapping +sum = pto.vector_reduce(add, array) # Reduction +max_val = pto.vector_reduce(max, array) # Maximum reduction + +# Vector zip/unzip +zipped = pto.vector_zip(src1, src2, dst) # Interleave +unzipped1, unzipped2 = pto.vector_unzip(src, dst1, dst2) # Deinterleave + +# Mathematical functions +result = pto.vector_sin(array) +result = pto.vector_exp(array) +result = pto.vector_relu(array) +result = pto.vector_sigmoid(array) + +# Statistical operations +mean = pto.vector_mean(array) +variance = pto.vector_variance(array) +min_val, max_val = pto.vector_minmax(array) + +# Linear algebra (small-scale) +dot_product = pto.vector_dot(vec1, vec2) +norm = pto.vector_norm(array) +``` + +## Implementation Strategy + +These syntax sugar enhancements can be implemented through: + +1. **Python Decorators and Context Managers**: For synchronization and vector scopes +2. **Wrapper Classes**: `UBArray`, `GMArray`, `Vector` classes that encapsulate low-level operations +3. **Operator Overloading**: Support for `[]`, `:`, arithmetic operators on wrapper classes +4. **Type Inference System**: Context-based machine type inference +5. **Compile-time Transformation**: Conversion of high-level syntax to low-level VPTO operations before IR generation + +## Compatibility with VPTO IR + +**Key Principle**: All syntax sugar must ultimately lower to existing VPTO operations. + +### Lowering Examples + +| Syntax Sugar | VPTO IR Equivalent | +|--------------|-------------------| +| `ub_array[64]` | `pto.vlds(ub_ptr, 64 * sizeof(f32))` | +| `pto.copy_gm_to_ub(src_array, dst_array)` | Appropriate `copy_gm_to_ubuf` call with calculated strides | +| `with pto.sync_between(...):` | `set_flag` + `wait_flag` pair | +| `mask = vec_f32 >= pto.f32(0.0)` | `pto.pge_b32(vec_f32, pto.f32(0.0))` | +| `vec_f32.abs(mask="PAT_ALL")` | `pto.vabs(vec_f32, pto.pset_b32("PAT_ALL"))` | + +## Prioritization + +### High Priority (Immediate Value) +1. Array view abstraction +2. Simplified copy operations +3. Automatic mask inference + +### Medium Priority (Significant Ergonomics Improvement) +4. Element-level indexing +5. Type inference simplification +6. Flexible vector scopes + +### Low Priority (Advanced Features) +7. Enhanced synchronization primitives +8. Built-in utility functions + +## Migration Path + +The existing low-level API will remain available for performance-critical code or direct VPTO IR correspondence. Syntax sugar will be provided as an optional layer that can be mixed with low-level operations. + +```python +# Mixed usage example +@pto.vkernel +def mixed_kernel(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM)): + # Low-level: manual pointer setup + ub_in = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) + + # High-level: array view for computation + ub_array = pto.ub_array(256, pto.f32, base_ptr=ub_in) + + # Mixed: low-level copy, high-level computation + pto.copy_gm_to_ubuf(src, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + with pto.vector_scope(): + for i in pto.range(0, 256, 64): + vec = ub_array.load_element(i) + result = vec.abs(mask="PAT_ALL") + ub_array.store_element(i, result) + + # Low-level: copy back + pto.copy_ubuf_to_gm(ub_in, dst, 0, 32, 128, 0, 128, 128) +``` + +## Next Steps + +1. **Prototype Implementation**: Start with array view abstraction and simplified copy operations +2. **User Feedback**: Gather feedback from performance engineers on the proposed syntax +3. **Gradual Rollout**: Implement enhancements in phases, starting with high-priority items +4. **Documentation**: Update DSL guide with syntax sugar examples and migration guides +5. **Testing**: Ensure all syntax sugar correctly lowers to VPTO IR and maintains performance + +These enhancements will significantly improve the TileLang DSL's usability while maintaining the close correspondence with underlying VPTO IR that performance engineers require. + +1. 软件流水线(Software Pipelining)的表达成本 +在 NPU 上写 Vector 级算子,最难的往往不是数值计算,而是利用 UB (Unified Buffer) 进行 Double/Multi-Buffering(乒乓缓存),并手动排布内存搬运与计算的流水线。 + +现状挑战:如果开发者全靠手写 set_flag、wait_flag,以及手动维护 Ping-Pong 缓冲的偏移量,代码会迅速膨胀且极易死锁或读写冲突。 + +优化建议:DSL 在保留底层原语的同时,可以提供稍微高级一点的流水线抽象。例如,引入 pto.CircularBuffer(tile, num_stages=2) 的概念,让开发者可以专注于“当前 stage 的计算”,而由底层生成器自动完成不同 stage 的指针轮转和 Flag 同步。 + +2. Python 宿主变量 vs MLIR SSA 变量的心智模型边界 +因为 DSL 的本质是用 Python 元编程来生成 MLIR(静态图),开发者在写代码时很容易混淆“Python 运行期的值”和“NPU 运行期的值”。 + +现状挑战:手册中提到“变量的自动合并”(比如 if 分支产生合并),这涉及到复杂的 SSA 转换。特别是在 for 循环中,**循环携带状态(Loop-carried state)**的处理往往是个痛点。如果开发者在循环外定义了一个 Python 列表或字典,在循环内去修改它,这在生成 MLIR 的 scf.for 时是无法正确映射的。 + +优化建议:需要有极其明确的类型系统提示或语法边界,强制区分编译期求值的变量(Meta-variables)和生成的 MLIR Value。可以考虑借鉴 Triton 的方式,提供类似 tl.constexpr 的装饰或类型,让开发者清楚哪些分支在生成 MLIR 时会被静态展开,哪些会真正生成 scf.if。 + +3. 地址计算(Address Generation)的易错性 +即使是对底层开发者,手动计算字节偏移也是痛苦且容易出 Bug 的。 + +现状挑战:i * cols * 4 这种强依赖 f32 占用 4 字节的硬编码,在泛型算子开发中会带来负担(比如想写一个同时兼容 f16 和 f32 的模板算子)。 + +优化建议:提供基于语义的视图(View)操作。保留控制力不代表必须算字节。可以提供类似 tile.get_vector_slice(row_idx, vec_idx) 的接口,它在内部自动 Emit(发射)对应的 MLIR 乘法和加法指令来计算 offset。这不仅防呆,还能让生成的 MLIR 结构更规范。 + +4. Mask 的隐式推导(针对边界处理) +NPU 算子经常要处理尾部不对齐的数据(Tail processing)。 + +优化建议:虽然底层需要具体的 Mask 寄存器配置(如 PAT_ALL),但在 for 循环的最后一步边界处理时,能否提供一个类似 pto.make_mask(remaining_elements) 的宏/内联函数?让它在生成 MLIR 时,自动展开为对应的硬件 plt_b32 等指令,这样可以大幅减少手写冗长边界判断的样板代码。 diff --git a/docs/vpto-spec.md b/docs/vpto-spec.md new file mode 100644 index 000000000..e2dabbeb6 --- /dev/null +++ b/docs/vpto-spec.md @@ -0,0 +1,1478 @@ +# PTO micro Instruction Spec — Merged Draft (A5) + +> **Status:** DRAFT for review +> **Base:** [vpto-spec.md](https://github.com/mouliangyu/PTOAS/blob/feature-vpto-backend/docs/vpto-spec.md) (2026-03-20) +> **Updated:** 2026-04-30 + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, architecture-aware representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as an explicit intermediate representation within the PTO compiler stack. It is designed to accurately express the user-visible architectural information needed for Ascend 950 kernels, including vector lane organization, memory space hierarchy, synchronization, and hardware-specific fusion semantics. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of tile instructions, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.mte_gm_ub`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.mte_ub_gm`) + +The grouped DMA surface in this specification covers `pto.mte_gm_ub` +(GM→UB), `pto.mte_ub_gm` (UB→GM), and `pto.mte_ub_ub` / `pto.mte_ub_l1` +(UB→UB or UB→CBUF). + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](isa/micro-isa/03-vector-load-store.md) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.mte_gm_ub %7, %2, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.mte_ub_gm %8, %14, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Cluster Programming Model + +#### Overview + +An A5 cluster contains one **Cube block** (AIC) and two **Vector blocks** (AIV0, AIV1). Each +block runs an **independent program** under its own Scalar Unit (SU), with its own issue queues: + +| Block | Issue Queues | +|---|---| +| Cube (AIC) | MTE2, MTE1, CUBE, FIXP | +| Vector (AIV) | MTE2, VEC, MTE3 | + +There is no implicit synchronization between blocks. All coordination between the Cube and Vector +programs is **explicit**, via the primitives described below. + + +``` +┌─────────────────────────────────────── A5 CLUSTER ───────────────────────────────────────┐ +│ │ +│ ┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ │ +│ │ CUBE CORE (AIC) │ │ VECTOR 0 (AIV0) │ │ VECTOR 1 (AIV1) │ │ +│ │ │ │ subblock_id = 0 │ │ subblock_id = 1 │ │ +│ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ +│ │ │ Scalar Unit │ │ │ │ Scalar Unit │ │ │ │ Scalar Unit │ │ │ +│ │ │ (SU) │ │ │ │ (SU) │ │ │ │ (SU) │ │ │ +│ │ │ runs cube │ │ │ │ runs vec │ │ │ │ runs vec │ │ │ +│ │ │ program │ │ │ │ program │ │ │ │ program │ │ │ +│ │ └───────────────┘ │ │ └───────────────┘ │ │ └───────────────┘ │ │ +│ │ ── Issue Queues ─ │ │ ── Issue Queues ─ │ │ ── Issue Queues ─ │ │ +│ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │ +│ │ │ MTE2 │ │ │ │ MTE2 │ │ │ │ MTE2 │ │ │ +│ │ │ GM → L1 │ │ │ │ GM → UB │ │ │ │ GM → UB │ │ │ +│ │ ├───────────────┤ │ │ ├───────────────┤ │ │ ├───────────────┤ │ │ +│ │ │ MTE1 │ │ │ │ VEC │ │ │ │ VEC │ │ │ +│ │ │ L1 → L0A/B │ │ │ │ SIMD compute │ │ │ │ SIMD compute │ │ │ +│ │ ├───────────────┤ │ │ ├───────────────┤ │ │ ├───────────────┤ │ │ +│ │ │ CUBE │ │ │ │ MTE3 │ │ │ │ MTE3 │ │ │ +│ │ │ MMAD (L0C) │ │ │ │ UB → GM │ │ │ │ UB → GM │ │ │ +│ │ ├───────────────┤ │ │ └───────────────┘ │ │ └───────────────┘ │ │ +│ │ │ FIXP │ │ │ │ │ │ │ +│ │ │ L0C → UB │ │ │ │ │ │ │ +│ │ │ (fixpipe) │ │ │ │ │ │ │ +│ │ └───────────────┘ │ │ │ │ │ │ +│ └─────────────────────┘ └─────────────────────┘ └─────────────────────┘ │ +│ │ +│ ┌────────────────────── SC (System Controller) ──────────────────────────────────────┐ │ +│ │ │ │ +│ │ 32 semaphores · 4-bit counter each · shared for C→V and V→C directions │ │ +│ │ │ │ +│ │ ┌──────────────────────────────────────────────────────────────────────────────┐ │ │ +│ │ │ sema_id 0 –15 │ [ 0][ 1][ 2][ 3][ 4][ 5][ 6][ 7][ 8][ 9][10][11][12][13][14][15] │ │ │ +│ │ │ │ ↕ C→V / V→C ↕ │ │ │ +│ │ │ │ communicate with AIV0 (subblock_id=0) │ │ │ +│ │ ├──────────────────────────────────────────────────────────────────────────────┤ │ │ +│ │ │ sema_id 16–31 │ [16][17][18][19][20][21][22][23][24][25][26][27][28][29][30][31] │ │ │ +│ │ │ │ ↕ C→V / V→C ↕ │ │ │ +│ │ │ │ communicate with AIV1 (subblock_id=1) │ │ │ +│ │ └──────────────────────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ │ → 16 sema_id pairs (0–15) available for 1:2 C:V sync per slot │ │ +│ │ │ │ +│ │ set_intra_block(trigger_pipe, sema_id) ──► increments semaphore │ │ +│ │ wait_intra_core(wait_pipe, sema_id) ──► stalls pipe until semaphore > 0 │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────────────────────┘ │ +└───────────────────────────────────────────────────────────────────────────────────────────┘ +``` + +#### Intra-Cluster Synchronization + +Within a cluster, the PTO micro ISA provides two levels of synchronization: + +**Intra-core pipeline sync** (`pto.set_flag` / `pto.wait_flag`): coordinates the asynchronous +pipelines *within a single block* — for example, ensuring MTE2 completes a GM→UB load before +the VEC pipeline begins computation. This does not cross block boundaries. + +**Inter-block sync** (`pto.set_intra_block` / `pto.wait_intra_core`): coordinates between the +Cube block and a Vector block within the same cluster. The sender specifies which **local +pipeline** commits the signal, ensuring the preceding operation on that pipeline has completed +before the signal is issued. The receiver specifies which **local pipeline** should stall until +the signal arrives. This is the fundamental IPC primitive for Cube–Vector cooperation on A5. + +> **Note:** `pto.set_cross_core` / `pto.wait_cross_core` operate at **multi-cluster** scope and +> are not used for intra-cluster communication. + +#### Intra-Cluster Data Paths + +A5 provides dedicated on-chip data paths between the Cube and Vector blocks, bypassing Global +Memory entirely. These are the **recommended high-performance paths** for intra-cluster tile +exchange. + +##### C→V: Cube L0C → Vector UB (fixpipe) + +The **fixpipe** instruction transfers data directly from Cube's L0C buffer to a Vector block's UB. +Because Cube natively produces results in **NZ fractal layout** and Vector operates on **ND +(row-major) layout**, fixpipe performs the layout conversion in hardware: + +``` +Cube L0C (NZ layout) ──[fixpipe, NZ2ND]──▶ Vector UB (ND layout) +``` + +Fixpipe supports a **dual-destination mode**: a single transfer can write to *both* AIV0's UB and +AIV1's UB simultaneously, with the tile split in hardware along either the row axis +(`DualModeSplitM`) or the column axis (`DualModeSplitN`): + +| Split | AIV0 receives | AIV1 receives | +|---|---|---| +| Split-M (rows) | Upper `[M/2, N]` in ND | Lower `[M/2, N]` in ND | +| Split-N (cols) | Left `[M, N/2]` in ND | Right `[M, N/2]` in ND | + +This 1→2 broadcast with in-hardware tile split is the architectural basis for 1:2 +Cube-to-Vector tile distribution. + +##### V→C: Vector UB → Cube L1 + +The reverse path transfers data from a Vector block's UB into Cube's L1 buffer. +A key architectural constraint: Cube's L1 stores tiles in **NZ fractal layout** (e.g. +`K1M1M0K0` — for fp16: `K0=16`, `M0=16`) so they can be loaded into L0A/L0B for MMAD +computation. Since Vector produces tiles in **ND layout**, the layout conversion from ND to NZ +must be applied as part of the V→C transfer: + +``` +Vector UB (ND layout) ──[ND→NZ movement]──▶ Cube L1 (NZ K1M1M0K0) +``` + +For 1:2 mode, both AIV0 and AIV1 each transfer a sub-tile into Cube's L1. The two sub-tiles are +assembled into a single contiguous NZ Mat tile in L1, ready for use as a LeftTile or RightTile +input to MMAD: + +| Split | AIV0 writes to L1 | AIV1 writes to L1 | Assembled in L1 | +|---|---|---|---| +| Split-M (rows) | `[K/2, N]` NZ at base | `[K/2, N]` NZ at offset | Full `[K, N]` NZ Mat tile | +| Split-N (cols) | `[K, N/2]` NZ at base | `[K, N/2]` NZ at offset | Full `[K, N]` NZ Mat tile | + +##### Fallback: GM-Staged Transfer + +When the local data path is not applicable, data can be exchanged via a **Global Memory staging +buffer**: the producer DMAs data to GM, and the consumer DMAs from GM. This path incurs off-chip +bandwidth cost and higher latency, but serves as a general fallback. + +#### Cube Internal Buffer Layout: NZ Fractal Format + +All cube unit internal buffers (L1/cbuf, L0A, L0B, L0C) use a **fractal NZ layout** rather than +row-major ND. Understanding this layout is essential when authoring cube data-movement ops. + +##### Definition + +Given hardware constant `C0 = 32 bytes`, for element type with byte width `E = sizeof(T)`: + +- Inner tile width: `K0 = N0 = C0 / E` (e.g. `K0 = 16` for fp16/bf16) +- Inner tile height: `M0 = 16` + +NZ re-indexing for a logical `[M, K]` tensor: + +``` +NZ index: (k1, m1, m0, k0) + where k1 = k / K0, k0 = k % K0 + m1 = m / M0, m0 = m % M0 +Physical layout: K1 x M1 x M0 x K0 (last dimension contiguous) +``` + +##### Per-buffer NZ Layouts + +| Buffer | Logical shape | Physical NZ layout | Notes | +|--------|--------------|-------------------|-------| +| L1 (cbuf) - Tensor A | `[M, K]` | `K1 M1 M0 K0` | Row-major A staged into NZ layout | +| L1 (cbuf) - Tensor B | `[K, N]` | `K1 N1 K0 N0` | Row-major B staged into NZ layout | +| L0A (left operand) | - | `K1 M1 M0 K0` | FRACTAL_NZ (A5) / FRACTAL_ZZ (A3): same NZ order as L1 cbuf | +| L0B (right operand) | - | `K1 N1 N0 K0` | FRACTAL_ZN: row-major outer, col-major inner (K0 innermost) | +| L0C (accumulator) | `[M, N]` | `N1 M1 M0 N0` | output of MMAD (FRACTAL_NZ: col-major outer, row-major inner) | + +##### Data Flow: GM -> L1 -> L0A/B -> L0C + +``` ++------------------------------------------------------------------------------+ +| GEMM Data Layout: GM -> L1 (NZ) -> L0A/B -> L0C | ++------------------------------------------------------------------------------+ + +STEP 1 - Global Memory (ND, row-major) +-------------------------------------- + Tensor A [M, K] Tensor B [K, N] + (K is the contiguous axis) (N is the contiguous axis) + + col-> k0 k1 ... kK-1 col-> n0 n1 ... nN-1 +row| +--------------------+ row| +--------------------+ + m0 | a00 a01 ... | k0 | b00 b01 ... | + m1 | a10 a11 ... | k1 | b10 b11 ... | + ... | | ... | | + mM-1| | kK-1 | | + +--------------------+ +--------------------+ + Physical: A[m*K + k] Physical: B[k*N + n] + + +STEP 2 - GM -> L1 (cbuf): NDtoNZ fractal repack +------------------------------------------------- + Use the structured cube load surface to stage row-major A and B into L1 NZ layout. + + A in L1: K1 x M1 x M0 x K0 B in L1: K1 x N1 x K0 x N0 + For each outer block (k1, m1): For each outer block (k1, n1): + +----------------------------+ +----------------------------+ + | M0 rows x K0 cols | | K0 rows x N0 cols | + | (16x16 elems contiguous) | | (16x16 elems contiguous) | + | m0| k0-> [0 .. K0-1] | | k0| n0-> [0 .. N0-1] | + | 0 [a a a a ...] | | 0 [b b b b ...] | + | 1 [a a a a ...] | | 1 [b b b b ...] | + | ... | | ... | + | M0-1 [a a a a ...] | | K0-1 [b b b b ...] | + +----------------------------+ +----------------------------+ + Physical: A_nz[k1][m1][m0][k0] Physical: B_nz[k1][n1][k0][n0] + + + + NOTE: For GEMM with row-major A/B, stage both operands from GM to L1 as + logical ND-to-NZ movement. If the source is already in a transposed logical + layout, express that at the structured load level instead of relying on a + later interpretation of the same bytes. + + +STEP 3 - L1 -> L0A / L0B +-------------------------- + L0A: cbuf K1 M1 M0 K0 --mte_l1_l0a--> L0A K1 M1 M0 K0 (FRACTAL_NZ on A5) + L0B: cbuf K1 N1 K0 N0 --mte_l1_l0b--> L0B K1 N1 N0 K0 (FRACTAL_ZN, K0 innermost) + + Why transpose at L1->L0B and not at GM->L1? + -------------------------------------------- + The cube reduction axis is K. L0B requires K innermost (N1 K1 K0 N0) + so the cube hardware reads all K0 elements per cycle without striding. + The inner-box transpose is performed as part of the structured right-load + movement itself; no separate user-visible pass is required. + Each 512B fractal z-block is permuted as it moves from L1 to L0B. + + L0A tile (cube LEFT port): L0B tile (cube RIGHT port): + +---------------------+ +---------------------+ + | shape: [M0, K0] | x | shape: [K0, N0] | + | M0 rows, K0 cols | | K0 rows, N0 cols | + | K innermost (fast) | | K innermost (fast) | + +---------------------+ +---------------------+ + | | + +-----------------+--------------------+ + | pto.mad (MMAD) + v + +STEP 4 - L0C output layout: N1 M1 M0 N0 +----------------------------------------- + For each outer block (n1, m1): + +------------------------------+ + | M0 rows x N0 cols | + | = result sub-tile of C[M,N] | + | n0-> [0 .. N0-1] | + | m0| [c c c c ...] | + | [c c c c ...] | + +------------------------------+ + Physical: C_nz[n1][m1][m0][n0] -> C_nd[m1*M0+m0][n1*N0+n0] + + Writeback: FIXPIPE MTE ops convert the L0C NZ result to the requested + destination layout and memory space. + + +Full pipeline summary +---------------------- + GM (ND) L1/cbuf (NZ) L0A/B (NZ) L0C (NZ) GM (ND) + + A[M,K] --mte_gm_l1_frac/mte_gm_l1--> K1 M1 M0 K0 --mte_l1_l0a--> K1 M1 M0 K0 -+ + +-MAD-> N1 M1 M0 N0 --> C[M,N] + B[K,N] --mte_gm_l1_frac/mte_gm_l1--> K1 N1 K0 N0 --mte_l1_l0b--> K1 N1 N0 K0 -+ + ^ + transpose as part of mte_l1_l0b when requested + NOT at GM->L1 +``` + + + +#### Programming Model + +The common pattern for Cube–Vector co-programming is a **software pipeline**: the Cube and Vector +programs run a coordinated loop where each iteration the Cube produces a tile and the Vector +consumes it (or vice versa), with explicit `pto.set_intra_block` / `pto.wait_intra_core` +handshakes at each step to maintain correct data ordering. + +The PTO micro ISA exposes all the hardware primitives above directly. Higher-level constructs +that simplify this pattern (such as in-order FIFO abstractions) can be implemented as software +libraries on top of these primitives; they are not part of the ISA itself. + + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +#### `pto.store_vfsimt_info` + +- **syntax:** `pto.store_vfsimt_info %dim_z, %dim_y, %dim_x : i32, i32, i32` +- **operands:** `i32, i32, i32` +- **semantics:** Configure the SIMT VF launch descriptor consumed by a subsequent SIMT entry invocation. The three operands are the launch dimensions in `z, y, x` order. Lowering packs them into the `llvm.hivm.store.vfsimt.info` intrinsic payload. +- **placement:** This op must appear in the outer non-SIMT caller. It must not appear inside a function marked with `pto.simt_entry`. + +```c +store_vfsimt_info(dim_z, dim_y, dim_x); +``` + +#### `pto.get_tid_x` + +- **syntax:** `%tx = pto.get_tid_x : i32` +- **result:** `i32` +- **semantics:** Return the current SIMT lane X coordinate inside the active VF launch. + +```c +tx = get_tid_x(); +``` + +#### `pto.get_tid_y` + +- **syntax:** `%ty = pto.get_tid_y : i32` +- **result:** `i32` +- **semantics:** Return the current SIMT lane Y coordinate inside the active VF launch. + +```c +ty = get_tid_y(); +``` + +#### `pto.get_tid_z` + +- **syntax:** `%tz = pto.get_tid_z : i32` +- **result:** `i32` +- **semantics:** Return the current SIMT lane Z coordinate inside the active VF launch. + +```c +tz = get_tid_z(); +``` + +Example: + +```mlir +module attributes {pto.target_arch = "a5"} { + func.func @simt_store_tid_kernel(%out: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %dim_z = arith.constant 1 : i32 + %dim_y = arith.constant 32 : i32 + %dim_x = arith.constant 32 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + pto.store_vfsimt_info %dim_z, %dim_y, %dim_x : i32, i32, i32 + func.call @simt_write(%ub_out) : (!pto.ptr) -> () + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.dma_store %ub_out, %out, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + return + } + + func.func @simt_write(%dst: !pto.ptr) attributes {pto.simt_entry} { + %tx = pto.get_tid_x : i32 + %ty = pto.get_tid_y : i32 + %tz = pto.get_tid_z : i32 + %c8_i32 = arith.constant 8 : i32 + %c16_i32 = arith.constant 16 : i32 + %c32_i32 = arith.constant 32 : i32 + %ty_shift = arith.shli %ty, %c8_i32 : i32 + %tz_shift = arith.shli %tz, %c16_i32 : i32 + %xy = arith.ori %tx, %ty_shift : i32 + %xyz = arith.ori %xy, %tz_shift : i32 + %lane_base = arith.muli %ty, %c32_i32 : i32 + %tid = arith.addi %lane_base, %tx : i32 + %tid_idx = arith.index_castui %tid : i32 to index + pto.store %xyz, %dst[%tid_idx] : !pto.ptr, i32 + return + } +} +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### VMS4 Status Query + +#### `pto.get_vms4_sr` + +- **syntax:** `%list0, %list1, %list2, %list3 = pto.get_vms4_sr : i16, i16, i16, i16` +- **results:** four `i16` values +- **semantics:** Read `VMS4_SR` and return the finished element counts for + source lists 0, 1, 2, and 3. After an exhausted `pto.vmrgsort4`, these are + the per-source-list executed counts. + +| Bits | Meaning | +|------|---------| +| `[15:0]` | finished count for source list 0 | +| `[31:16]` | finished count for source list 1 | +| `[47:32]` | finished count for source list 2 | +| `[63:48]` | finished count for source list 3 | + +```c +status = VMS4_SR; +list0 = (uint16_t)(status & 0xffff); +list1 = (uint16_t)((status >> 16) & 0xffff); +list2 = (uint16_t)((status >> 32) & 0xffff); +list3 = (uint16_t)((status >> 48) & 0xffff); +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### `pto.load` + +- **syntax:** `%value = pto.load %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a VPTO pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr` or a memref operand that + will be normalized to a PTO pointer before LLVM emission. `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This is the preferred + scalar memory op for VPTO/SIMT authoring. + +#### `pto.store` + +- **syntax:** `pto.store %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a VPTO pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr` or a memref operand that will be normalized to a PTO + pointer before LLVM emission. `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This is the + preferred scalar memory op for VPTO/SIMT authoring. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load %1[%c4] : !pto.ptr -> f32 + pto.store %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg, !pto.mask -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +For A5 reduction result types: + +- `pto.vcadd` widens `i8 -> i16`, `u8 -> u16`, `i16 -> i32`, and `u16 -> u32`, + with the lane count halved in each widening case. +- `pto.vcadd` keeps the same result type for `f16`, `f32`, `i32`, and `u32`. + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference +# Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is available in the linked files. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](isa/micro-isa/01-pipeline-sync.md) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](isa/micro-isa/02-dma-copy.md) | Public DMA transfer interface between GM↔UB, UB→UB, and UB→L1 | 4 | `pto.mte_gm_ub`, `pto.mte_ub_gm`, `pto.mte_ub_ub`, `pto.mte_ub_l1` | +| 3 | [Vector Load/Store](isa/micro-isa/03-vector-load-store.md) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](isa/micro-isa/04-predicate-load-store.md) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](isa/micro-isa/05-materialization-predicate.md) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](isa/micro-isa/06-unary-vector-ops.md) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](isa/micro-isa/07-binary-vector-ops.md) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](isa/micro-isa/08-vec-scalar-ops.md) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](isa/micro-isa/09-conversion-ops.md) | Type conversion with rounding/saturation control | 4 | `pto.vcvt`, `pto.vtrc`, `pto.vbitcast`, `pto.pbitcast` | +| 10 | [Reduction Ops](isa/micro-isa/10-reduction-ops.md) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](isa/micro-isa/11-compare-select.md) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](isa/micro-isa/12-data-rearrangement.md) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](isa/micro-isa/13-dsa-sfu-ops.md) | Specialized ops, index generation, and sorting helpers | 10 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4`, `pto.get_vms4_sr` | +| 14 | [Arith (Shared MLIR Dialect)](isa/micro-isa/14-shared-arith.md) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](isa/micro-isa/15-shared-scf.md) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | +| 16 | [Cube Matrix Multiply](isa/micro-isa/16-cube-matmul.md) | GM↔L1 (`l1`/cbuf) staging, L1 (`l1`)↔UB/BT/FB side moves, L1→L0A/L0B loads, L0C (`l0c`) matmul, and FIXPIPE MTE writeback | 19 | `pto.mte_gm_l1`, `pto.mte_l1_ub`, `pto.mte_gm_l1_frac`, `pto.mte_l1_bt`, `pto.mte_l1_fb`, `pto.mte_l1_l0a`, `pto.mte_l1_l0b`, `pto.mte_l1_l0a_mx`, `pto.mte_l1_l0b_mx`, `pto.mad`, `pto.mad_acc`, `pto.mad_bias`, `pto.mad_mx`, `pto.mad_mx_acc`, `pto.mad_mx_bias`, `pto.mte_l0c_l1`, `pto.mte_l0c_gm`, `pto.mte_l0c_ub` | + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.mte_gm_ub` | +| UB→GM DMA | 2 | `pto.mte_ub_gm` | +| UB→UB / UB→L1 copy | 2 | `pto.mte_ub_ub`, `pto.mte_ub_l1` | +| GM→L1 | 16 | `pto.mte_gm_l1`, `pto.mte_gm_l1_frac` | +| L1→UB | 16 | `pto.mte_l1_ub` | +| L1→BT | 16 | `pto.mte_l1_bt` | +| L1→FB | 16 | `pto.mte_l1_fb` | +| L1→L0A / L1→L0B | 16 | `pto.mte_l1_l0a`, `pto.mte_l1_l0b`, `pto.mte_l1_l0a_mx`, `pto.mte_l1_l0b_mx` | +| L0C→L1 / GM / UB (FIXPIPE MTE) | 16 | `pto.mte_l0c_l1`, `pto.mte_l0c_gm`, `pto.mte_l0c_ub` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_B8` / `NORM_B16` / `NORM_B32` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Cube matmul family (zero-init / accumulate / bias-init; shared clauses `unit_flag`, `disable_gemv`, `sat`, `tf32_mode`, `n_dir`) | 16 | `pto.mad`, `pto.mad_acc`, `pto.mad_bias`, `pto.mad_mx`, `pto.mad_mx_acc`, `pto.mad_mx_bias` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt`, `pto.vbitcast`, `pto.pbitcast` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | + +### Cube Operation Surface + +- `pto.mte_l1_bt` +- `pto.mte_l1_fb` +- `pto.mte_gm_l1` +- `pto.mte_gm_l1_frac` +- `pto.mte_l1_ub` +- `pto.mte_l1_l0a` +- `pto.mte_l1_l0b` +- `pto.mte_l1_l0a_mx` +- `pto.mte_l1_l0b_mx` +- `pto.mad` +- `pto.mad_acc` +- `pto.mad_bias` +- `pto.mad_mx` +- `pto.mad_mx_acc` +- `pto.mad_mx_bias` +- `pto.mte_l0c_l1` +- `pto.mte_l0c_gm` +- `pto.mte_l0c_ub` + +--- + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdif %logits, %max_bc, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +*For detailed semantics, C-style pseudocode, and CCE mappings, see the individual group documentation files.* + +--- + +## Part IV: PTO Tile Instruction + +PTO Tile Instruction is a high-performance instruction surface built on top of PTO micro Instruction. Each tile instruction encapsulates a tile-granular pattern — DMA between GM and on-chip buffers, vector arithmetic over a whole tile, reductions, broadcast / expansion, selection, padding — and internally expands to a sequence of micro-instruction primitives (`pto.vlds`, `pto.vsts`, `pto.vadd`, mask ops, sync flags, …). + +The full PTO Tile Instruction reference starts from [Tile and PTO Tile Instruction overview](isa/tile-op/01-tile-overview.md). It covers: + +- [Tile and PTO Tile Instruction overview](isa/tile-op/01-tile-overview.md) — tile concept, on-chip placement, physical shape vs valid region, conventions +- [Types & Attributes](isa/tile-op/02-types-and-attributes.md) — `!pto.tile_buf`, `!pto.tensor_view`, address spaces, layout, pad +- [Pointer & View](isa/tile-op/03-pointer-and-view.md) — tensor views, partitions, tile allocation, valid-shape updates +- [DMA Data Movement](isa/tile-op/04-dma-data-movement.md) — `pto.tload` / `pto.tstore` +- [Vector Arithmetic](isa/tile-op/05-vector-arithmetic.md) — `pto.tadd / tsub / tmul / tdiv / tmax / tmin`, tile-scalar forms, unary math, activations +- [Reductions](isa/tile-op/06-reduction-ops.md), [Partial Elementwise](isa/tile-op/07-partial-elementwise.md), [Bitwise & Shift](isa/tile-op/08-bitwise-shift-ops.md), [Type Conversion](isa/tile-op/09-type-conversion.md), [Broadcast & Expansion](isa/tile-op/10-broadcast-and-expansion-ops.md), [Selection](isa/tile-op/11-selection-ops.md), [Fill & Padding](isa/tile-op/12-fill-and-padding-ops.md) + +For the boundary between Tile Instruction and the micro instruction surface (when to drop into `pto.vecscope` and how `pto.tile_buf_addr` bridges the two), see [Tile and PTO Tile Instruction overview §1.10](isa/tile-op/01-tile-overview.md#110-mixing-pto-tile-instruction-and-pto-micro-instruction). + +--- + +## Appendix: Discussion Points + +### Part I + +1. **mem_bar as pto op:** Should `pto.mem_bar` be a formal pto dialect op, or is there an existing mechanism? +2. **UB size parameterization:** Is 256KB always fixed, or should spec allow for architecture variants? +3. **MERGING predication:** Intentionally omitted (SW-emulated, perf overhead). Revisit if needed later. + +### Part II + +1. **Predication in C semantics:** Should every op's C code explicitly show the `if (mask[i])` guard, or assume all-active and note predication separately? +2. **VLane terminology:** Using "VLane" instead of "DataBlock" — confirm this naming is preferred. + +### Part 3A + +1. **pto.vdupi:** Is this distinct from `pto.vdup` with an immediate operand, or can `pto.vdup` handle both? +2. **Predicate ops (pand/por/pxor and predicate movement forms):** These need MLIR op definitions and verifier rules. Confirm priority. + +### Part 3B + +1. **Section 10 removals:** 4 interleave ops removed (not on A5). If multi-arch support is needed later, these would need conditional inclusion. + +### Part 3C + +2. **Store dist family completeness:** `vsts` currently covers `NORM_B8`, `NORM_B16`, `NORM_B32`, `1PT_B8`, `1PT_B16`, `1PT_B32`, `PK_B16`, `PK_B32`, `PK_B64`, `PK4_B32`, `MRG4CHN_B8`, `MRG2CHN_B8`, and `MRG2CHN_B16`, while `vstsx2` covers `INTLV_B8` / `INTLV_B16` / `INTLV_B32`. `MRG4CHN_B8` / `MRG2CHN_B8` / `MRG2CHN_B16` are preserved in the VPTO surface, but the current hardware still reports them as unsupported via verifier warning and they are not expected to validate at runtime on A5 today. +3. **vcvt width-changing pattern:** The even/odd + `vor` pattern for forms such as `f32 -> f16` is the standard compiler lowering. Confirm this is the intended representation in the spec. +4. **Stateful store ops (Section 14):** These are complex with SSA state threading. Are they all needed for A5, or can some be simplified? diff --git a/include/PTO/IR/CMakeLists.txt b/include/PTO/IR/CMakeLists.txt index 89e858f55..c0dcef288 100644 --- a/include/PTO/IR/CMakeLists.txt +++ b/include/PTO/IR/CMakeLists.txt @@ -41,6 +41,12 @@ if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/PTOInterfaces.td) mlir_tablegen(PTOInterfaces.cpp.inc -gen-op-interface-defs) endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/VPTOInterfaces.td) + set(LLVM_TARGET_DEFINITIONS VPTOInterfaces.td) + mlir_tablegen(VPTOInterfaces.h.inc -gen-op-interface-decls) + mlir_tablegen(VPTOInterfaces.cpp.inc -gen-op-interface-defs) +endif() + # ============================================================ # 3. 注册聚合目标 # ============================================================ diff --git a/include/PTO/IR/PTO.h b/include/PTO/IR/PTO.h index b8c6f90b2..5117df15c 100644 --- a/include/PTO/IR/PTO.h +++ b/include/PTO/IR/PTO.h @@ -47,6 +47,7 @@ //===----------------------------------------------------------------------===// #include "PTO/IR/PTOInterfaces.h.inc" +#include "PTO/IR/VPTOInterfaces.h.inc" //===----------------------------------------------------------------------===// // PTO Attributes @@ -66,6 +67,65 @@ // PTO Dialect Operations //===----------------------------------------------------------------------===// +namespace mlir { +namespace pto { + +//===----------------------------------------------------------------------===// +// S Fractal Size Constants +//===----------------------------------------------------------------------===// + +/// Fractal size for mxBox layout (16x2 inner block, 32 bytes total). +inline constexpr int32_t kFractalMxSize = 32; + +/// Fractal size for AB matrices in matmul (16xN inner block, 512 bytes). +inline constexpr int32_t kFractalABSize = 512; + +/// Fractal size for C matrix in matmul (16x16 inner block, 1024 bytes). +inline constexpr int32_t kFractalCSize = 1024; + +struct DmaLoopConfig { + Value count; + Value srcStride; + Value dstStride; +}; + +struct DmaPadConfig { + Value value; + Value leftCount; + Value rightCount; +}; + +struct AccStoreModeConfig { + AccStoreMode mode; + std::optional split; + std::optional loop0SrcStride; +}; + +struct CubeLoadFracShapeConfig { + Value nValue; + Value dValue; +}; + +struct CubeLoadFracSrcLayoutConfig { + Value srcInnerStride; + std::optional srcOuterStride; +}; + +struct CubeLoadFracDstGroupConfig { + Value groupCount; + Value dstLoop2Stride; + Value dstLoop3Stride; + Value dstLoop4Stride; +}; + +struct CubeLoadFracCtrlConfig { + Value l2CacheCtrl; + Value smallc0En; +}; + +} // namespace pto +} // namespace mlir + #define GET_OP_CLASSES #include "PTO/IR/PTOOps.h.inc" @@ -120,6 +180,7 @@ class ScopedPTOParserTargetArch { /// Function attribute that marks an explicit PTO kernel entry. inline constexpr llvm::StringLiteral kPTOEntryAttrName = "pto.entry"; inline constexpr llvm::StringLiteral kLegacyHACCEntryAttrName = "hacc.entry"; +inline constexpr llvm::StringLiteral kPTOSimtEntryAttrName = "pto.simt_entry"; /// Return true if the function carries an explicit entry marker. bool hasExplicitPTOEntryAttr(func::FuncOp func); diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index fc1e0ca2f..d063408ff 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -120,6 +120,55 @@ def PTO_PipeAttr : PTO_Attr<"Pipe", "pipe"> { }]; } +//===----------------------------------------------------------------------===// +// MemBar +//===----------------------------------------------------------------------===// + +def PTO_MEMBAR_VV_ALL : I32EnumAttrCase<"VV_ALL", 0>; +def PTO_MEMBAR_VST_VLD : I32EnumAttrCase<"VST_VLD", 1>; +def PTO_MEMBAR_VLD_VST : I32EnumAttrCase<"VLD_VST", 2>; +def PTO_MEMBAR_VST_VST : I32EnumAttrCase<"VST_VST", 3>; +def PTO_MEMBAR_VS_ALL : I32EnumAttrCase<"VS_ALL", 4>; +def PTO_MEMBAR_VST_LD : I32EnumAttrCase<"VST_LD", 5>; +def PTO_MEMBAR_VLD_ST : I32EnumAttrCase<"VLD_ST", 6>; +def PTO_MEMBAR_VST_ST : I32EnumAttrCase<"VST_ST", 7>; +def PTO_MEMBAR_SV_ALL : I32EnumAttrCase<"SV_ALL", 8>; +def PTO_MEMBAR_ST_VLD : I32EnumAttrCase<"ST_VLD", 9>; +def PTO_MEMBAR_LD_VST : I32EnumAttrCase<"LD_VST", 10>; +def PTO_MEMBAR_ST_VST : I32EnumAttrCase<"ST_VST", 11>; +def PTO_MEMBAR_SS_ALL : I32EnumAttrCase<"SS_ALL", 12>; +def PTO_MEMBAR_ST_LD : I32EnumAttrCase<"ST_LD", 13>; +def PTO_MEMBAR_LD_ST : I32EnumAttrCase<"LD_ST", 14>; +def PTO_MEMBAR_ST_ST : I32EnumAttrCase<"ST_ST", 15>; + +def PTO_MemBarEnum : PTO_I32Enum< + "MemBarKind", "PTO low-level memory barrier kind", [ + PTO_MEMBAR_VV_ALL, + PTO_MEMBAR_VST_VLD, + PTO_MEMBAR_VLD_VST, + PTO_MEMBAR_VST_VST, + PTO_MEMBAR_VS_ALL, + PTO_MEMBAR_VST_LD, + PTO_MEMBAR_VLD_ST, + PTO_MEMBAR_VST_ST, + PTO_MEMBAR_SV_ALL, + PTO_MEMBAR_ST_VLD, + PTO_MEMBAR_LD_VST, + PTO_MEMBAR_ST_VST, + PTO_MEMBAR_SS_ALL, + PTO_MEMBAR_ST_LD, + PTO_MEMBAR_LD_ST, + PTO_MEMBAR_ST_ST + ]>; + +def PTO_MemBarAttr : PTO_Attr<"MemBar", "membar"> { + let parameters = (ins EnumParameter:$kind); + let assemblyFormat = "`<` params `>`"; + let description = [{ + Low-level memory barrier kind for VPTO `pto.mem_bar`. + }]; +} + //===----------------------------------------------------------------------===// // Sync Op Type (High Level Abstraction) //===----------------------------------------------------------------------===// @@ -410,6 +459,57 @@ def PTO_RoundModeAttr : EnumAttr { let summary = "rounding mode attribute"; } +//===----------------------------------------------------------------------===// +// PrecisionMode +//===----------------------------------------------------------------------===// + +def PTO_PrecisionModeEnum : PTO_I32Enum< + "PrecisionMode", "PTO precision mode for elementwise math ops", [ + I32EnumAttrCase<"DEFAULT", 0>, + I32EnumAttrCase<"HIGH_PRECISION", 1> + ]>; + +def PTO_PrecisionModeAttr + : EnumAttr { + let summary = "precision mode attribute for elementwise math ops"; +} + +//===----------------------------------------------------------------------===// +// MAD semantic controls +//===----------------------------------------------------------------------===// + +def PTO_MadUnitFlagModeEnum : PTO_I32Enum< + "MadUnitFlagMode", "PTO MAD producer unit-flag mode", [ + I32EnumAttrCase<"CheckOnly", 0, "check_only">, + I32EnumAttrCase<"CheckAndSet", 1, "check_and_set"> + ]>; + +def PTO_MadUnitFlagModeAttr + : EnumAttr { + let summary = "MAD producer unit-flag mode attribute"; +} + +def PTO_Tf32ModeEnum : PTO_I32Enum< + "Tf32Mode", "PTO MAD TF32 rounding mode", [ + I32EnumAttrCase<"RoundEven", 0, "round_even">, + I32EnumAttrCase<"RoundAway", 1, "round_away"> + ]>; + +def PTO_Tf32ModeAttr : EnumAttr { + let summary = "MAD TF32 rounding mode attribute"; +} + +def PTO_MadSatModeEnum : PTO_I32Enum< + "MadSatMode", "PTO MAD floating exceptional-value saturation mode", [ + I32EnumAttrCase<"Sat", 0, "sat">, + I32EnumAttrCase<"NoSat", 1, "nosat"> + ]>; + +def PTO_MadSatModeAttr + : EnumAttr { + let summary = "MAD floating exceptional-value saturation mode attribute"; +} + //===----------------------------------------------------------------------===// // SaturationMode //===----------------------------------------------------------------------===// @@ -508,13 +608,142 @@ def PTO_ReduceOpAttr : EnumAttr { def PTO_ReluPreModeEnum : PTO_I32Enum< "ReluPreMode", "PTO TSTORE relu pre mode", [ I32EnumAttrCase<"NoRelu", 0, "no_relu">, - I32EnumAttrCase<"NormalRelu", 1, "normal_relu"> + I32EnumAttrCase<"NormalRelu", 1, "normal_relu">, + I32EnumAttrCase<"ScalarRelu", 2, "scalar_relu">, + I32EnumAttrCase<"VectorRelu", 3, "vector_relu">, + I32EnumAttrCase<"Pwl", 4, "pwl"> ]>; def PTO_ReluPreModeAttr : EnumAttr { let summary = "TSTORE relu pre mode attribute"; } +def PTO_UnitFlagCtrlEnum : PTO_I32Enum< + "AccStoreUnitFlagCtrl", "PTO accumulator store unit-flag mode", [ + I32EnumAttrCase<"Off", 0, "off">, + I32EnumAttrCase<"CheckOnly", 2, "check_only">, + I32EnumAttrCase<"CheckAndClear", 3, "check_and_clear"> + ]>; + +def PTO_UnitFlagCtrlAttr : EnumAttr { + let summary = "acc_store unit-flag control attribute"; +} + +def PTO_AccStoreQuantPreModeEnum : PTO_I32Enum< + "AccStoreQuantPreMode", "PTO accumulator store pre-quantization mode", [ + I32EnumAttrCase<"NoConvert", 0, "no_convert">, + I32EnumAttrCase<"F32F16", 1, "f32_f16">, + I32EnumAttrCase<"QF322HIF8PreVec", 2, "qf322hif8_pre_vec">, + I32EnumAttrCase<"QF322HIF8PreScalar", 3, "qf322hif8_pre_scalar">, + I32EnumAttrCase<"QF322HIF8PreHybridVec", 4, "qf322hif8_pre_hybrid_vec">, + I32EnumAttrCase<"QF322HIF8PreHybridScalar", 5, "qf322hif8_pre_hybrid_scalar">, + I32EnumAttrCase<"DEQS32IntVec", 6, "deqs32_int_vec">, + I32EnumAttrCase<"DEQS32IntScalar", 7, "deqs32_int_scalar">, + I32EnumAttrCase<"REQ8Vec", 8, "req8_vec">, + I32EnumAttrCase<"REQ8Scalar", 9, "req8_scalar">, + I32EnumAttrCase<"DEQF16Vec", 10, "deqf16_vec">, + I32EnumAttrCase<"DEQF16Scalar", 11, "deqf16_scalar">, + I32EnumAttrCase<"QF322FP8PreVec", 12, "qf322fp8_pre_vec">, + I32EnumAttrCase<"QF322FP8PreScalar", 13, "qf322fp8_pre_scalar">, + I32EnumAttrCase<"QF322F32PreVec", 14, "qf322f32_pre_vec">, + I32EnumAttrCase<"QF322F32PreScalar", 15, "qf322f32_pre_scalar">, + I32EnumAttrCase<"F32BF16", 16, "f32_bf16">, + I32EnumAttrCase<"QF162B8PreVec", 17, "qf162b8_pre_vec">, + I32EnumAttrCase<"QF162B8PreScalar", 18, "qf162b8_pre_scalar">, + I32EnumAttrCase<"QF162S4PreVec", 19, "qf162s4_pre_vec">, + I32EnumAttrCase<"QF162S4PreScalar", 20, "qf162s4_pre_scalar">, + I32EnumAttrCase<"REQ4Vec", 21, "req4_vec">, + I32EnumAttrCase<"REQ4Scalar", 22, "req4_scalar">, + I32EnumAttrCase<"QF322B8PreVec", 23, "qf322b8_pre_vec">, + I32EnumAttrCase<"QF322B8PreScalar", 24, "qf322b8_pre_scalar">, + I32EnumAttrCase<"QF322S4PreVec", 25, "qf322s4_pre_vec">, + I32EnumAttrCase<"QF322S4PreScalar", 26, "qf322s4_pre_scalar">, + I32EnumAttrCase<"DEQS16Vec", 27, "deqs16_vec">, + I32EnumAttrCase<"DEQS16Scalar", 28, "deqs16_scalar">, + I32EnumAttrCase<"QF162S16PreVec", 29, "qf162s16_pre_vec">, + I32EnumAttrCase<"QF162S16PreScalar", 30, "qf162s16_pre_scalar">, + I32EnumAttrCase<"QF322F16PreVec", 31, "qf322f16_pre_vec">, + I32EnumAttrCase<"QF322F16PreScalar", 32, "qf322f16_pre_scalar">, + I32EnumAttrCase<"QF322BF16PreVec", 33, "qf322bf16_pre_vec">, + I32EnumAttrCase<"QF322BF16PreScalar", 34, "qf322bf16_pre_scalar">, + I32EnumAttrCase<"QS322BF16PreVec", 35, "qs322bf16_pre_vec">, + I32EnumAttrCase<"QS322BF16PreScalar", 36, "qs322bf16_pre_scalar"> + ]>; + +def PTO_AccStoreQuantPreModeAttr : EnumAttr { + let summary = "acc_store pre-quantization mode attribute"; +} + +def PTO_AccStoreAtomicTypeEnum : PTO_I32Enum< + "AccStoreAtomicType", "PTO accumulator store atomic type", [ + I32EnumAttrCase<"F32", 1, "f32">, + I32EnumAttrCase<"F16", 2, "f16">, + I32EnumAttrCase<"S16", 3, "s16">, + I32EnumAttrCase<"S32", 4, "s32">, + I32EnumAttrCase<"S8", 5, "s8">, + I32EnumAttrCase<"BF16", 6, "bf16"> + ]>; + +def PTO_AccStoreAtomicTypeAttr : EnumAttr { + let summary = "acc_store atomic type attribute"; +} + +def PTO_AccStoreAtomicOpEnum : PTO_I32Enum< + "AccStoreAtomicOp", "PTO accumulator store atomic op", [ + I32EnumAttrCase<"Add", 0, "add">, + I32EnumAttrCase<"Max", 1, "max">, + I32EnumAttrCase<"Min", 2, "min"> + ]>; + +def PTO_AccStoreAtomicOpAttr : EnumAttr { + let summary = "acc_store atomic op attribute"; +} + +def PTO_AccStoreModeEnum : PTO_I32Enum< + "AccStoreMode", "PTO accumulator store layout conversion mode", [ + I32EnumAttrCase<"Nz2nd", 0, "nz2nd">, + I32EnumAttrCase<"Nz2dn", 1, "nz2dn">, + I32EnumAttrCase<"Nz2nz", 2, "nz2nz"> + ]>; + +def PTO_AccStoreModeAttr : EnumAttr { + let summary = "acc_store layout conversion mode"; +} + +def PTO_AccStoreUbDstModeEnum : PTO_I32Enum< + "AccStoreUbDstMode", "PTO accumulator store UB destination mode", [ + I32EnumAttrCase<"Single", 0, "single">, + I32EnumAttrCase<"SplitM", 1, "split_m">, + I32EnumAttrCase<"SplitN", 2, "split_n"> + ]>; + +def PTO_AccStoreUbDstModeAttr + : EnumAttr { + let summary = "acc_store_ub destination mode attribute"; +} + +def PTO_AccStoreSatModeEnum : PTO_I32Enum< + "AccStoreSatMode", "PTO accumulator store saturation mode", [ + I32EnumAttrCase<"Sat", 0, "sat">, + I32EnumAttrCase<"NoSat", 1, "nosat">, + I32EnumAttrCase<"SatPreserveNan", 2, "sat_preserve_nan"> + ]>; + +def PTO_AccStoreSatModeAttr + : EnumAttr { + let summary = "acc_store saturation mode attribute"; +} + +def PTO_CubeLoadFracModeEnum : PTO_I32Enum< + "CubeLoadFracMode", "PTO cube fractal load source layout mode", [ + I32EnumAttrCase<"Nd2nz", 0, "nd2nz">, + I32EnumAttrCase<"Dn2nz", 1, "dn2nz"> + ]>; + +def PTO_CubeLoadFracModeAttr : EnumAttr { + let summary = "cube_load_frac source layout conversion mode"; +} + def PTO_GatherOOBEnum : PTO_I32Enum< "GatherOOB", "PTO MGATHER out-of-bounds mode", [ I32EnumAttrCase<"Undefined", 0, "undefined">, diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index c502bb5d2..3c01aea8b 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -53,7 +53,7 @@ def TileBufOrMemRef : def ScalarPtrOrMemRef : TypeConstraint< CPred<"::mlir::pto::isScalarPtrOrMemRef($_self)">, - "Ptr or MemRef in GM">; + "Ptr or GM MemRef">; def ScalarType : AnyTypeOf<[AnySignlessInteger, AnyFloat], "numeric (integer/float)">; @@ -76,6 +76,8 @@ class PTO_DpsOp traits = []> class PTO_Op traits = []> : Op; +include "PTO/IR/VPTOOps.td" + //===----------------------------------------------------------------------===// // Pointer/View Ops (for your front-end IR) //===----------------------------------------------------------------------===// @@ -153,6 +155,31 @@ def IntToPtrOp : PTO_Op<"inttoptr", [Pure]> { }]; } +def CastPtrOp : PTO_Op<"castptr", [Pure]> { + let summary = "Cast between integer and !pto.ptr, or between !pto.ptr types"; + let description = [{ + Performs an explicit pointer-domain cast. + + Supported cases: + - integer -> !pto.ptr + - !pto.ptr -> integer + - !pto.ptr -> !pto.ptr + - memref<..., space> -> !pto.ptr (extract the aligned base ptr) + + Pointer-to-pointer casts must stay within the same PTO memory space. Cross + space casts such as gm <-> ub are rejected by the verifier. + }]; + + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + //===----------------------------------------------------------------------===// // Scalar pointer load/store //===----------------------------------------------------------------------===// @@ -239,8 +266,9 @@ def PartitionViewOp : PTO_Op<"partition_view", [AttrSizedOperandSegments]> { } // Helper: tensor_view or memref (after lowering tensor_view to memref). -def TensorViewOrMemRef : - AnyTypeOf<[TensorViewType, AnyMemRef], "TensorView or MemRef">; +def TensorViewLikeOrMemRef : + AnyTypeOf<[TensorViewType, PartitionTensorViewType, AnyMemRef], + "TensorView, PartitionTensorView, or MemRef">; // Get the size of a dimension of a tensor_view or its lowered memref view. // Result type: Index (use arith.index_cast if i32 is needed). @@ -257,7 +285,27 @@ def GetTensorViewDimOp : PTO_Op<"get_tensor_view_dim", [Pure]> { : memref<...>, index -> index }]; let arguments = (ins - TensorViewOrMemRef:$tensor_view, + TensorViewLikeOrMemRef:$tensor_view, + Index:$dim_index + ); + let results = (outs Index:$result); + let assemblyFormat = [{ + $tensor_view `,` $dim_index `:` qualified(type($tensor_view)) `->` qualified(type($result)) + attr-dict + }]; +} + +// Get the logical stride of a tensor_view dimension in elements. +// Result type: Index (use arith.index_cast if i32 is needed). +def GetTensorViewStrideOp : PTO_Op<"get_tensor_view_stride", [Pure]> { + let summary = "Get the stride of a dimension of a tensor_view."; + let description = [{ + Returns the stride, measured in elements, of the given dimension of a + logical tensor view. This op accepts either !pto.tensor_view or the memref + it is lowered to. + }]; + let arguments = (ins + TensorViewLikeOrMemRef:$tensor_view, Index:$dim_index ); let results = (outs Index:$result); @@ -2413,8 +2461,10 @@ def SetFlagOp : PTO_Op<"set_flag"> { PTO_EventAttr:$event_id ); let results = (outs); - let assemblyFormat = [{ - `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2426,8 +2476,10 @@ def WaitFlagOp : PTO_Op<"wait_flag"> { PTO_EventAttr:$event_id ); let results = (outs); - let assemblyFormat = [{ - `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2461,6 +2513,9 @@ def WaitFlagDynOp : PTO_Op<"wait_flag_dyn"> { // Buffer-ID Synchronization (A5) //===----------------------------------------------------------------------===// +def PTO_PipeLikeAttr + : AnyAttrOf<[PTO_PipeEventTypeAttr, PTO_SyncOpTypeAttr, PTO_PipeAttr]>; + def GetBufOp : PTO_Op<"get_buf"> { let summary = "Acquire a buffer-id token for a sync op type (A5)"; let description = [{ @@ -2477,7 +2532,7 @@ def GetBufOp : PTO_Op<"get_buf"> { }]; let arguments = (ins - PTO_PipeEventTypeLikeAttr:$op_type, + PTO_PipeLikeAttr:$op_type, I32Attr:$buf_id, DefaultValuedAttr:$mode ); @@ -2486,8 +2541,10 @@ def GetBufOp : PTO_Op<"get_buf"> { let hasVerifier = 1; - let assemblyFormat = [{ - `[` $op_type `,` $buf_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2501,7 +2558,7 @@ def RlsBufOp : PTO_Op<"rls_buf"> { }]; let arguments = (ins - PTO_PipeEventTypeLikeAttr:$op_type, + PTO_PipeLikeAttr:$op_type, I32Attr:$buf_id, DefaultValuedAttr:$mode ); @@ -2510,8 +2567,10 @@ def RlsBufOp : PTO_Op<"rls_buf"> { let hasVerifier = 1; - let assemblyFormat = [{ - `[` $op_type `,` $buf_id `]` attr-dict + let extraClassDeclaration = [{ + void print(::mlir::OpAsmPrinter &p); + static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); }]; } @@ -2604,6 +2663,71 @@ def SyncAllOp : PTO_Op<"syncall", [AttrSizedOperandSegments]> { let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// SIMD Bridge Ops +//===----------------------------------------------------------------------===// + +def SimdTileToMemrefOp : PTO_Op<"simd.tile_to_memref", [Pure]> { + let summary = "Bridge cast from tile_buf to memref in OP-Lib bodies."; + let description = [{ + This op is the canonical bridge marker for OP-Lib templates to expose a + memref view from tile-like values while keeping external ABI as + !pto.tile_buf. + In tile_buf world, src is !pto.tile_buf and dst is the corresponding + memref bridge type. + After memref-world lowering, src may already be memref and this op remains + as a marker for backend lowering (EmitC) to materialize tile data access. + }]; + + let arguments = (ins TileBufOrMemRef:$src); + let results = (outs AnyMemRef:$dst); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `to` type($dst) + }]; +} + +// --------------------------------------------------------------------------- +// TileBuf intrinsics — used in TileLang DSL-generated template functions. +// These ops extract memref address and valid shape from tile_buf parameters. +// After inline, FoldTileBufIntrinsics resolves them to concrete values. +// --------------------------------------------------------------------------- + +def TileValidRowsOp : PTO_Op<"tile_valid_rows", [Pure]> { + let summary = "Extract valid row count from a tile_buf."; + let description = [{ + Returns the valid row count (v_row) of a `tile_buf` as an `index` value. + When the tile_buf has a static v_row, FoldTileBufIntrinsics folds this + into `arith.constant`. When v_row is dynamic (`?`), the fold resolves + it to the runtime index value carried by the tile_buf. + }]; + + let arguments = (ins TileBufType:$src); + let results = (outs Index:$result); + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($result) + }]; +} + +def TileValidColsOp : PTO_Op<"tile_valid_cols", [Pure]> { + let summary = "Extract valid column count from a tile_buf."; + let description = [{ + Returns the valid column count (v_col) of a `tile_buf` as an `index` value. + When the tile_buf has a static v_col, FoldTileBufIntrinsics folds this + into `arith.constant`. When v_col is dynamic (`?`), the fold resolves + it to the runtime index value carried by the tile_buf. + }]; + + let arguments = (ins TileBufType:$src); + let results = (outs Index:$result); + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($result) + }]; +} //===----------------------------------------------------------------------===// // FFT Configuration Operation @@ -3425,7 +3549,9 @@ def TColExpandDivOp : PTO_TOp<"tcolexpanddiv", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -3804,7 +3930,9 @@ def TDivOp : PTO_TOp<"tdiv", [ let arguments = (ins PTODpsType:$src0, PTODpsType:$src1, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -3834,7 +3962,9 @@ def TDivSOp : PTO_TOp<"tdivs", [ let arguments = (ins AnyType:$src, AnyType:$scalar, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -3917,7 +4047,9 @@ def TExpOp : PTO_TOp<"texp", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -4355,7 +4487,9 @@ def TLogOp : PTO_TOp<"tlog", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -5118,7 +5252,9 @@ def TRecipOp: PTO_TOp<"trecip", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -5311,7 +5447,9 @@ def TRowExpandDivOp: PTO_TOp<"trowexpanddiv", [ PTODpsType:$src0, PTODpsType:$src1, Optional:$tmp, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -5702,7 +5840,9 @@ def TRsqrtOp: PTO_TOp<"trsqrt", [ let arguments = (ins PTODpsType:$src, PTODpsType:$dst, - Optional:$tmp + Optional:$tmp, + DefaultValuedAttr:$precision_mode ); let results = (outs); @@ -6012,7 +6152,9 @@ def TSqrtOp: PTO_TOp<"tsqrt", [ let arguments = (ins PTODpsType:$src, - PTODpsType:$dst + PTODpsType:$dst, + DefaultValuedAttr:$precision_mode ); let results = (outs); diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 665014fc3..cc0c871e9 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -18,18 +18,25 @@ include "mlir/Interfaces/DataLayoutInterfaces.td" include "PTO/IR/PTODialect.td" include "PTO/IR/PTOAttrs.td" -// ---- !pto.ptr ---- +// ---- !pto.ptr ---- def PtrType : TypeDef { let mnemonic = "ptr"; let parameters = (ins - "mlir::Type":$elementType + "mlir::Type":$elementType, + "mlir::pto::AddressSpaceAttr":$memorySpace ); - let assemblyFormat = "`<` $elementType `>`"; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; let builders = [ TypeBuilder<(ins "Type":$elementType), [{ - return Base::get($_ctxt, elementType); + return Base::get($_ctxt, elementType, + mlir::pto::AddressSpaceAttr::get($_ctxt, + mlir::pto::AddressSpace::GM)); + }]>, + TypeBuilder<(ins "Type":$elementType, + "mlir::pto::AddressSpaceAttr":$memorySpace), [{ + return Base::get($_ctxt, elementType, memorySpace); }]> ]; } @@ -311,3 +318,5 @@ def F4E2M1x2Type : TypeDef { + let description = [{ + Interface for semantic MAD-family ops. The interface is only a uniform view + over the existing IR operands and attributes; it does not duplicate operand + state into a separate lowering descriptor. + }]; + let cppNamespace = "::mlir::pto"; + let methods = [ + InterfaceMethod<"Return the left matrix operand.", + "::mlir::Value", "getLhs", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getLhs(); }]>, + InterfaceMethod<"Return the right matrix operand.", + "::mlir::Value", "getRhs", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getRhs(); }]>, + InterfaceMethod<"Return the accumulator/destination operand.", + "::mlir::Value", "getDst", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getDst(); }]>, + InterfaceMethod<"Return M.", + "::mlir::Value", "getM", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getM(); }]>, + InterfaceMethod<"Return N.", + "::mlir::Value", "getN", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getN(); }]>, + InterfaceMethod<"Return K.", + "::mlir::Value", "getK", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getK(); }]>, + InterfaceMethod<"Return true for MX MAD forms.", + "bool", "isMadMxFamily">, + InterfaceMethod<"Return true if the op carries a bias operand.", + "bool", "hasBiasOperand">, + InterfaceMethod<"Return true if the op reads the existing accumulator.", + "bool", "readsAccumulator">, + InterfaceMethod<"Return true if TF32 mode is a legal semantic attribute.", + "bool", "supportsTf32Mode">, + InterfaceMethod<"Return the bias operand, or null for non-bias forms.", + "::mlir::Value", "getBiasOrNull">, + InterfaceMethod<"Return true if the accumulator source is bias.", + "bool", "initializesAccumulatorWithBias", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.hasBiasOperand(); }]>, + InterfaceMethod<"Return true if the accumulator is initialized to zero.", + "bool", "initializesAccumulatorWithZero", (ins), [{}], + /*defaultImplementation=*/[{ + return !$_op.readsAccumulator() && !$_op.hasBiasOperand(); + }]>, + InterfaceMethod<"Return unit_flag_mode.", + "::mlir::Attribute", "getUnitFlagModeAttr", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getUnitFlagModeAttr(); }]>, + InterfaceMethod<"Return true if disable_gemv is set.", + "bool", "getDisableGemv", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getDisableGemv(); }]>, + InterfaceMethod<"Return sat_mode.", + "::mlir::Attribute", "getSatModeAttr", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getSatModeAttr(); }]>, + InterfaceMethod<"Return tf32_mode.", + "::mlir::Attribute", "getTf32ModeAttr">, + InterfaceMethod<"Return true if n_dir is set.", + "bool", "getNDir", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getNDir(); }]> + ]; +} + +def PTO_MadRawOpInterface : OpInterface<"MadRawOpInterface"> { + let description = [{ + Interface for raw MAD-family ops. Raw ops carry already-materialized X_t + and optionally a bias operand; intrinsic selection is derived from this + interface and the pointer element types. + }]; + let cppNamespace = "::mlir::pto"; + let methods = [ + InterfaceMethod<"Return the left matrix operand.", + "::mlir::Value", "getLhs", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getLhs(); }]>, + InterfaceMethod<"Return the right matrix operand.", + "::mlir::Value", "getRhs", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getRhs(); }]>, + InterfaceMethod<"Return the accumulator/destination operand.", + "::mlir::Value", "getDst", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getDst(); }]>, + InterfaceMethod<"Return X_t.", + "::mlir::Value", "getXt", (ins), [{}], + /*defaultImplementation=*/[{ return $_op.getXt(); }]>, + InterfaceMethod<"Return true for MX MAD forms.", + "bool", "isMadMxFamily">, + InterfaceMethod<"Return true if the op carries a bias operand.", + "bool", "hasBiasOperand">, + InterfaceMethod<"Return the bias operand, or null for non-bias forms.", + "::mlir::Value", "getBiasOrNull"> + ]; +} + +#endif // MLIR_DIALECT_PTO_IR_VPTOINTERFACES diff --git a/include/PTO/IR/VPTOOps.td b/include/PTO/IR/VPTOOps.td new file mode 100644 index 000000000..61e00e3ad --- /dev/null +++ b/include/PTO/IR/VPTOOps.td @@ -0,0 +1,2771 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VPTOOps.td - PTO low-level operations ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VPTOOPS +#define MLIR_DIALECT_PTO_IR_VPTOOPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "PTO/IR/VPTOInterfaces.td" + +def PTO_VectorType : Type($_self)">, + "PTO low-level vector type">; +def PTO_MaskTypeConstraint : Type($_self)">, + "PTO low-level mask type">; +def PTO_B8MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB8()">, + "PTO low-level b8 mask type">; +def PTO_B16MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB16()">, + "PTO low-level b16 mask type">; +def PTO_B32MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self) && ::llvm::cast<::mlir::pto::MaskType>($_self).isB32()">, + "PTO low-level b32 mask type">; +def PTO_AlignTypeConstraint : Type($_self)">, + "PTO low-level align type">; + +def PTO_BufferType : Type< + CPred<"::llvm::isa<::mlir::pto::PtrType>($_self)">, + "pointer-like buffer type">; +def PTO_ScalingPtrType : Type< + CPred<"::llvm::isa<::mlir::pto::PtrType>($_self) && " + "::llvm::cast<::mlir::pto::PtrType>($_self).getMemorySpace().getAddressSpace() == ::mlir::pto::AddressSpace::SCALING">, + "scaling pointer type">; +def PTO_BufferLikeType : AnyTypeOf<[AnyMemRef, PTO_BufferType], + "memref or pointer-like buffer type">; +def PTO_AccStorePayloadType : AnyTypeOf<[AnySignlessInteger, AnyFloat, PTO_ScalingPtrType], + "signless integer, float scalar, or scaling pointer">; +def PTO_AccStoreClipPayloadType : AnyTypeOf<[AnyInteger, AnyFloat], + "integer or float clip scalar">; + +def PTOLoadOp : PTO_Op<"load", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Load one scalar element from a VPTO pointer-like operand."; + + let arguments = (ins + PTO_BufferLikeType:$ptr, + Index:$offset + ); + + let results = (outs AnyType:$value); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $ptr `[` $offset `]` attr-dict `:` type($ptr) `->` type($value) + }]; +} + +def PTOStoreOp : PTO_Op<"store", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Store one scalar element to a VPTO pointer-like operand."; + + let arguments = (ins + PTO_BufferLikeType:$ptr, + Index:$offset, + AnyType:$value + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $ptr `[` $offset `]` attr-dict `:` type($ptr) `,` type($value) + }]; +} + +def TensorViewAddrOp : PTO_Op<"tensor_view_addr", [Pure]> { + let summary = "Extract address from a tensor view."; + let description = [{ + Returns the address view carried by a `tensor_view` or + `partition_tensor_view` value. The result may be either a memref view or a + typed PTO pointer, depending on the requested destination type. + + In authoring-form IR this op preserves the descriptor-style surface; + during view-to-memref lowering it collapses to the underlying memref value + or to a memref-derived pointer. + + This op may also accept a memref operand after earlier view lowering, in + which case it behaves as an identity marker and is removed by lowering. + }]; + + let arguments = (ins AnyTypeOf<[TensorViewType, PartitionTensorViewType, AnyMemRef], + "TensorViewLike or MemRef">:$src); + let results = (outs PtrOrMemRef:$dst); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($dst) + }]; +} + +def TileBufAddrOp : PTO_Op<"tile_buf_addr", [Pure]> { + let summary = "Extract address from a tile_buf."; + let description = [{ + Returns the address view of the data region of a `tile_buf` value. + The result may be either a memref view or a typed PTO pointer, depending + on the requested destination type. Memref results use the tile's static + shape and address space. + + This op is emitted by TileLang DSL templates and resolved by the + FoldTileBufIntrinsics pass after inlining. Hand-written `.pto` may also + use it directly on the memref result of `pto.bind_tile` / lowered + `pto.alloc_tile`. + }]; + + let arguments = (ins AnyTypeOf<[TileBufType, AnyMemRef], + "tile_buf or tile-bound memref">:$src); + let results = (outs PtrOrMemRef:$dst); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src attr-dict `:` qualified(type($src)) `->` type($dst) + }]; +} + +def VecScopeOp : PTO_Op<"vecscope", [SingleBlock, NoTerminator]> { + let summary = "Structured region container for one VPTO vector scope"; + let description = [{ + `pto.vecscope` marks a structured vector-scope interval without overloading + a dummy carrier loop with scope metadata. Lowering and emission passes may + use the region boundary to preserve loop shape while treating the enclosed + body as one VPTO vector interval. + }]; + + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = "$body attr-dict"; +} + +def StrictVecScopeOp : PTO_Op<"strict_vecscope", [SingleBlock, NoTerminator, + IsolatedFromAbove]> { + let summary = "Structured VPTO vector scope with explicit captures only"; + let description = [{ + `pto.strict_vecscope` is the strict form of `pto.vecscope`. Values used by + the body must be passed explicitly through op operands and corresponding + block arguments; implicit SSA capture from the surrounding scope is + rejected. + }]; + + let arguments = (ins Variadic:$captures); + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let assemblyFormat = [{ + `(` $captures `)` $body attr-dict `:` functional-type($captures, results) + }]; +} + +def PTO_MemBarOp : PTO_Op<"mem_bar"> { + let summary = "Low-level VPTO memory barrier"; + let description = [{ + Low-level memory ordering barrier that lowers to one of the + `llvm.hivm.mem.bar.*` intrinsics exposed by Bisheng. + }]; + + let arguments = (ins PTO_MemBarAttr:$kind); + let results = (outs); + + let hasCustomAssemblyFormat = 1; +} + + +class PTO_BinaryI64ConfigOp : PTO_Op { + let arguments = (ins + I64:$first, + I64:$second + ); + + let results = (outs); + + let assemblyFormat = [{ + $first `,` $second attr-dict `:` type($first) `,` type($second) + }]; +} + +class PTO_BinaryI64PureOp : PTO_Op { + let arguments = (ins + I64:$first, + I64:$second + ); + + let results = (outs I64:$result); + + let assemblyFormat = [{ + $first `,` $second attr-dict `:` type($first) `,` type($second) `->` type($result) + }]; +} + +class PTO_UnaryI64ConfigOp : PTO_Op { + let arguments = (ins I64:$value); + let results = (outs); + + let assemblyFormat = [{ + $value attr-dict `:` type($value) + }]; +} + +class PTO_NullaryI64PureOp : PTO_Op { + let arguments = (ins); + let results = (outs I64:$result); + + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; +} + +class PTO_NullaryI32PureOp : PTO_Op { + let arguments = (ins); + let results = (outs I32:$result); + + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; +} + +class PTO_NullaryConfigOp : PTO_Op { + let arguments = (ins); + let results = (outs); + let assemblyFormat = [{ attr-dict }]; +} + +def PTO_SetMovPadValOp : PTO_Op<"set_mov_pad_val"> { + let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat], + "integer/float scalar">:$value); + let results = (outs); + let hasVerifier = 1; + + let assemblyFormat = [{ + $value attr-dict `:` type($value) + }]; +} + +def PTO_SetLoop2StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_outtoub">; +def PTO_SetLoop1StrideOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_outtoub">; +def PTO_SetLoopSizeOutToUbOp : PTO_BinaryI64ConfigOp<"set_loop_size_outtoub">; +def PTO_SetLoop2StrideOutToL1Op : PTO_UnaryI64ConfigOp<"set_loop2_stride_outtol1">; +def PTO_SetLoop1StrideOutToL1Op : PTO_UnaryI64ConfigOp<"set_loop1_stride_outtol1">; +def PTO_SetLoopSizeOutToL1Op : PTO_UnaryI64ConfigOp<"set_loop_size_outtol1">; +def PTO_SetLoop2StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop2_stride_ubtoout">; +def PTO_SetLoop1StrideUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop1_stride_ubtoout">; +def PTO_SetLoopSizeUbToOutOp : PTO_BinaryI64ConfigOp<"set_loop_size_ubtoout">; +def PTO_SetLoop3ParaOp : PTO_BinaryI64ConfigOp<"set_loop3_para">; +def PTO_SetChannelParaOp : PTO_BinaryI64ConfigOp<"set_channel_para">; +def PTO_SetMte2NzParaOp : PTO_UnaryI64ConfigOp<"set_mte2_nz_para">; +def PTO_SetPadValOutToL1Op : PTO_UnaryI64ConfigOp<"set_pad_val_outtol1">; +def PTO_GetCtrlOp : PTO_NullaryI64PureOp<"get_ctrl">; +def PTO_GetVms4SrOp : PTO_Op<"get_vms4_sr", [Pure]> { + let arguments = (ins); + let results = (outs I16:$list0, I16:$list1, I16:$list2, I16:$list3); + + let assemblyFormat = [{ + attr-dict `:` type($list0) `,` type($list1) `,` type($list2) `,` type($list3) + }]; +} +def PTO_SetCtrlOp : PTO_UnaryI64ConfigOp<"set_ctrl">; +def PTO_StoreVfSimtInfoOp : PTO_Op<"store_vfsimt_info"> { + let summary = "Configure SIMT VF launch dimensions"; + let arguments = (ins I32:$dimZ, I32:$dimY, I32:$dimX); + let results = (outs); + let assemblyFormat = [{ + $dimZ `,` $dimY `,` $dimX attr-dict `:` type($dimZ) `,` type($dimY) `,` type($dimX) + }]; +} +def PTO_Sbitset0Op : PTO_BinaryI64PureOp<"sbitset0">; +def PTO_Sbitset1Op : PTO_BinaryI64PureOp<"sbitset1">; +def PTO_SetQuantPreOp : PTO_UnaryI64ConfigOp<"set_quant_pre">; +def PTO_SetReluAlphaOp : PTO_UnaryI64ConfigOp<"set_relu_alpha">; +def PTO_SetFixClipReluOp : PTO_UnaryI64ConfigOp<"set_fix_clip_relu">; +def PTO_SetFpcOp : PTO_UnaryI64ConfigOp<"set_fpc">; +def PTO_SetAtomicS32Op : PTO_NullaryConfigOp<"set_atomic_s32">; +def PTO_SetAtomicS8Op : PTO_NullaryConfigOp<"set_atomic_s8">; +def PTO_GetTidXOp : PTO_NullaryI32PureOp<"get_tid_x">; +def PTO_GetTidYOp : PTO_NullaryI32PureOp<"get_tid_y">; +def PTO_GetTidZOp : PTO_NullaryI32PureOp<"get_tid_z">; + +def PTO_CopyGmToUbufOp : PTO_Op<"copy_gm_to_ubuf", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$left_padding_count, + I64:$right_padding_count, + I1:$data_select_bit, + I64:$l2_cache_ctl, + I64:$gm_stride, + I64:$ub_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` + $left_padding_count `,` $right_padding_count `,` $data_select_bit `,` $l2_cache_ctl `,` $gm_stride `,` $ub_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($left_padding_count) `,` + type($right_padding_count) `,` type($data_select_bit) `,` type($l2_cache_ctl) `,` type($gm_stride) `,` type($ub_stride) + }]; +} + +def PTO_MteGmUbOp : PTO_Op<"mte_gm_ub", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$l2_cache_ctl, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_stride, + I64:$nburst_dst_stride, + Variadic:$loop_counts, + Variadic:$loop_src_strides, + Variadic:$loop_dst_strides, + Optional>:$pad_value, + Optional:$left_padding_count, + Optional:$right_padding_count + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$l2CacheCtl, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::llvm::ArrayRef<::mlir::pto::DmaLoopConfig>":$loops, + "::std::optional<::mlir::pto::DmaPadConfig>":$pad + )>, + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$l2CacheCtl, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop1, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop2, + "::std::optional<::mlir::pto::DmaPadConfig>":$pad + )> + ]; + +} + +def PTO_CopyUbufToUbufOp : PTO_Op<"copy_ubuf_to_ubuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` $src_stride `,` $dst_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_CopyCbufToUbufOp : PTO_Op<"copy_cbuf_to_ubuf", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` $src_stride `,` $dst_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_CopyUbufToCbufOp : PTO_Op<"copy_ubuf_to_cbuf", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` $src_stride `,` $dst_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_MteUbUbOp : PTO_Op<"mte_ub_ub", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $len_burst + `nburst` `(` $n_burst `,` $src_stride `,` $dst_stride `)` + attr-dict `:` type($source) `,` type($destination) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` + type($dst_stride) + }]; +} + +def PTO_MteUbL1Op : PTO_Op<"mte_ub_l1", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $len_burst + `nburst` `(` $n_burst `,` $src_stride `,` $dst_stride `)` + attr-dict `:` type($source) `,` type($destination) `,` type($n_burst) `,` + type($len_burst) `,` type($src_stride) `,` + type($dst_stride) + }]; +} + +def PTO_CopyGmToCbufOp : PTO_Op<"copy_gm_to_cbuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$n_burst, + I64:$len_burst, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $n_burst `,` $len_burst `,` $src_stride `,` $dst_stride + attr-dict `:` type($source) `,` type($destination) `,` type($n_burst) `,` type($len_burst) `,` type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_CopyGmToCbufMultiNd2NzOp : PTO_Op<"copy_gm_to_cbuf_multi_nd2nz"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$loop1_src_stride, + I64:$l2_cache_ctrl, + I64:$n_value, + I64:$d_value, + I64:$loop4_src_stride, + I1:$smallc0_en + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $loop1_src_stride `,` $l2_cache_ctrl `,` + $n_value `,` $d_value `,` $loop4_src_stride `,` $smallc0_en + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` + type($loop1_src_stride) `,` type($l2_cache_ctrl) `,` type($n_value) `,` + type($d_value) `,` type($loop4_src_stride) `,` type($smallc0_en) + }]; +} + +def PTO_CopyGmToCbufMultiDn2NzOp : PTO_Op<"copy_gm_to_cbuf_multi_dn2nz"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$sid, + I64:$loop1_src_stride, + I64:$l2_cache_ctrl, + I64:$n_value, + I64:$d_value, + I64:$loop4_src_stride, + I1:$smallc0_en + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $loop1_src_stride `,` $l2_cache_ctrl `,` + $n_value `,` $d_value `,` $loop4_src_stride `,` $smallc0_en + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` + type($loop1_src_stride) `,` type($l2_cache_ctrl) `,` type($n_value) `,` + type($d_value) `,` type($loop4_src_stride) `,` type($smallc0_en) + }]; +} + +def PTO_CopyCbufToBtOp : PTO_Op<"copy_cbuf_to_bt"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I1:$conv_control, + I64:$n_burst, + I64:$len_burst, + I64:$source_gap, + I64:$dst_gap + ); + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $conv_control `,` $n_burst `,` $len_burst `,` + $source_gap `,` $dst_gap + attr-dict `:` type($source) `,` type($destination) `,` type($conv_control) `,` + type($n_burst) `,` type($len_burst) `,` type($source_gap) `,` type($dst_gap) + }]; +} + +def PTO_CopyCbufToFbufOp : PTO_Op<"copy_cbuf_to_fbuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$n_burst, + I64:$len_burst, + I64:$source_gap, + I64:$dst_gap + ); + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $n_burst `,` $len_burst `,` + $source_gap `,` $dst_gap + attr-dict `:` type($source) `,` type($destination) `,` type($n_burst) `,` + type($len_burst) `,` type($source_gap) `,` type($dst_gap) + }]; +} + +def PTO_LoadCbufToCaOp : PTO_Op<"load_cbuf_to_ca"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m_start, + I64:$k_start, + I64:$m_step, + I64:$k_step, + I64:$src_stride, + I64:$dst_stride, + DefaultValuedAttr:$transpose + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m_start `,` $k_start `,` $m_step `,` + $k_step `,` $src_stride `,` $dst_stride attr-dict `:` type($source) `,` + type($destination) `,` type($m_start) `,` type($k_start) `,` + type($m_step) `,` type($k_step) `,` type($src_stride) `,` + type($dst_stride) + }]; +} + +def PTO_LoadCbufToCaS4Op : PTO_Op<"load_cbuf_to_ca_s4"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m_start, + I64:$k_start, + I64:$m_step, + I64:$k_step, + I64:$src_stride, + I64:$dst_stride, + I64:$transpose + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m_start `,` $k_start `,` $m_step `,` + $k_step `,` $src_stride `,` $dst_stride `,` $transpose + attr-dict `:` type($source) `,` type($destination) `,` type($m_start) `,` + type($k_start) `,` type($m_step) `,` type($k_step) `,` type($src_stride) `,` + type($dst_stride) `,` type($transpose) + }]; +} + +def PTO_LoadCbufToCbOp : PTO_Op<"load_cbuf_to_cb"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m_start, + I64:$k_start, + I64:$m_step, + I64:$k_step, + I64:$src_stride, + I64:$dst_stride, + DefaultValuedAttr:$transpose + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m_start `,` $k_start `,` $m_step `,` + $k_step `,` $src_stride `,` $dst_stride attr-dict `:` type($source) `,` + type($destination) `,` type($m_start) `,` type($k_start) `,` + type($m_step) `,` type($k_step) `,` type($src_stride) `,` + type($dst_stride) + }]; +} + +def PTO_LoadCbufToCbS4Op : PTO_Op<"load_cbuf_to_cb_s4"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m_start, + I64:$k_start, + I64:$m_step, + I64:$k_step, + I64:$src_stride, + I64:$dst_stride, + I64:$transpose + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m_start `,` $k_start `,` $m_step `,` + $k_step `,` $src_stride `,` $dst_stride `,` $transpose + attr-dict `:` type($source) `,` type($destination) `,` type($m_start) `,` + type($k_start) `,` type($m_step) `,` type($k_step) `,` type($src_stride) `,` + type($dst_stride) `,` type($transpose) + }]; +} + +def PTO_LoadCbufToCaMxOp : PTO_Op<"load_cbuf_to_ca_mx"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$m, + I64:$k + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $m `,` $k attr-dict `:` type($source) `,` type($destination) `,` type($m) `,` type($k) + }]; +} + +def PTO_LoadCbufToCbMxOp : PTO_Op<"load_cbuf_to_cb_mx"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$x_start_position, + I64:$y_start_position, + I64:$x_step, + I64:$y_step, + I64:$src_stride, + I64:$dst_stride + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $x_start_position `,` $y_start_position `,` + $x_step `,` $y_step `,` $src_stride `,` $dst_stride attr-dict `:` + type($source) `,` type($destination) `,` type($x_start_position) `,` + type($y_start_position) `,` type($x_step) `,` type($y_step) `,` + type($src_stride) `,` type($dst_stride) + }]; +} + +def PTO_CopyMatrixCcToGmOp : PTO_Op<"copy_matrix_cc_to_gm"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$xm, + I64:$xt + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $xm `,` $xt attr-dict `:` type($source) `,` type($destination) `,` type($xm) `,` type($xt) + }]; +} + +def PTO_CopyMatrixCcToCbufOp : PTO_Op<"copy_matrix_cc_to_cbuf"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$config0, + I64:$config1 + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $config0 `,` $config1 attr-dict `:` type($source) `,` + type($destination) `,` type($config0) `,` type($config1) + }]; +} + +def PTO_CopyMatrixCcToUbOp : PTO_Op<"copy_matrix_cc_to_ub"> { + let arguments = (ins + PTO_BufferType:$source, + PTO_BufferType:$destination, + I64:$config0, + I64:$config1 + ); + + let results = (outs); + + let assemblyFormat = [{ + $source `,` $destination `,` $config0 `,` $config1 attr-dict `:` type($source) `,` + type($destination) `,` type($config0) `,` type($config1) + }]; +} + +def PTO_VldsOp : PTO_Op<"vlds", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + OptionalAttr:$dist + ); + + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_VldsPostOp : PTO_Op<"vlds_post", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + OptionalAttr:$dist + ); + + let results = (outs PTO_VectorType:$result, + PTO_BufferLikeType:$updated_source); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) `,` type($updated_source) + }]; +} + +def PTO_Vldsx2Op : PTO_Op<"vldsx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($low) `,` type($high) + }]; +} + +def PTO_VldasOp : PTO_Op<"vldas", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source + ); + + let results = (outs PTO_AlignTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_InitAlignOp : PTO_Op<"init_align", []> { + let arguments = (ins); + + let results = (outs PTO_AlignTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + attr-dict `:` type($result) + }]; +} + +def PTO_SprclrOp : PTO_Op<"sprclr", []> { + let arguments = (ins StrAttr:$spr); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $spr attr-dict + }]; +} + +def PTO_VldusOp : PTO_Op<"vldus", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_AlignTypeConstraint:$align + ); + + let results = (outs + PTO_VectorType:$result, + PTO_AlignTypeConstraint:$updated_align + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $align attr-dict `:` type($source) `,` type($align) `->` type($result) `,` type($updated_align) + }]; +} + +def PTO_UvldOp : PTO_Op<"uvld", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset + ); + + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` attr-dict `:` type($source) `->` type($result) + }]; +} + +def PTO_VbrOp : PTO_Op<"vbr", [Pure]> { + let arguments = (ins AnyType:$value); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value attr-dict `:` type($value) `->` type($result) + }]; +} + +def PTO_VdupOp : PTO_Op<"vdup", [Pure]> { + let arguments = (ins + AnyType:$input, + PTO_MaskTypeConstraint:$mask, + OptionalAttr:$position + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PsetB8Op : PTO_Op<"pset_b8", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B8MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PsetB16Op : PTO_Op<"pset_b16", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B16MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +// NOTE: The op families introduced below are intentionally marked as +// unvalidated scaffolding. They are added to preserve missing CCE builtin +// semantics at the dialect layer, but they have not yet been validated through +// PTO lowering or end-to-end sample execution. +def PTO_PsetB32Op : PTO_Op<"pset_b32", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B32MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB8Op : PTO_Op<"pge_b8", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B8MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB16Op : PTO_Op<"pge_b16", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B16MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PgeB32Op : PTO_Op<"pge_b32", [Pure]> { + let arguments = (ins StrAttr:$pattern); + let results = (outs PTO_B32MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $pattern attr-dict `:` type($result) + }]; +} + +def PTO_PltB8Op : PTO_Op<"plt_b8", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B8MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +def PTO_PltB16Op : PTO_Op<"plt_b16", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B16MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +def PTO_PltB32Op : PTO_Op<"plt_b32", [Pure]> { + let arguments = (ins I32:$scalar); + let results = (outs PTO_B32MaskTypeConstraint:$mask, I32:$scalar_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $scalar attr-dict `:` type($scalar) `->` type($mask) `,` type($scalar_out) + }]; +} + +class PTO_MaskUnaryOp : PTO_Op { + let arguments = (ins PTO_MaskTypeConstraint:$input); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PpackOp : PTO_MaskUnaryOp<"ppack"> { + let arguments = (ins PTO_MaskTypeConstraint:$input, StrAttr:$part); + let assemblyFormat = [{ + $input `,` $part attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PunpackOp : PTO_MaskUnaryOp<"punpack"> { + let arguments = (ins PTO_MaskTypeConstraint:$input, StrAttr:$part); + let assemblyFormat = [{ + $input `,` $part attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_PbitcastOp : PTO_MaskUnaryOp<"pbitcast">; + +def PTO_PnotOp : PTO_Op<"pnot", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PselOp : PTO_Op<"psel", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PandOp : PTO_Op<"pand", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PorOp : PTO_Op<"por", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PxorOp : PTO_Op<"pxor", [Pure]> { + let arguments = (ins + PTO_MaskTypeConstraint:$src0, + PTO_MaskTypeConstraint:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_PldsOp : PTO_Op<"plds", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($result) + }]; +} + +def PTO_PldiOp : PTO_Op<"pldi", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + Index:$offset, + StrAttr:$dist + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `[` $offset `]` `,` $dist attr-dict `:` type($source) `,` type($offset) `->` type($result) + }]; +} + +def PTO_PstiOp : PTO_Op<"psti", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_MaskTypeConstraint:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $dist attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_VabsOp : PTO_Op<"vabs", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +class PTO_UnaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $mask attr-dict `:` type($input) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VexpOp : PTO_UnaryVecOp<"vexp">; +def PTO_VlnOp : PTO_UnaryVecOp<"vln">; +def PTO_VsqrtOp : PTO_UnaryVecOp<"vsqrt">; +def PTO_VnegOp : PTO_UnaryVecOp<"vneg">; +def PTO_VreluOp : PTO_UnaryVecOp<"vrelu">; +def PTO_VnotOp : PTO_UnaryVecOp<"vnot">; +def PTO_VcaddOp : PTO_UnaryVecOp<"vcadd">; + +def PTO_MadOp : PTO_Op<"mad", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$m, + I64:$n, + I64:$k, + OptionalAttr:$unit_flag_mode, + UnitAttr:$disable_gemv, + OptionalAttr:$sat_mode, + OptionalAttr:$tf32_mode, + UnitAttr:$n_dir + ); + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MadAccOp : PTO_Op<"mad_acc", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$m, + I64:$n, + I64:$k, + OptionalAttr:$unit_flag_mode, + UnitAttr:$disable_gemv, + OptionalAttr:$sat_mode, + OptionalAttr:$tf32_mode, + UnitAttr:$n_dir + ); + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MadBiasOp : PTO_Op<"mad_bias", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + PTO_BufferType:$bias, + I64:$m, + I64:$n, + I64:$k, + OptionalAttr:$unit_flag_mode, + UnitAttr:$disable_gemv, + OptionalAttr:$sat_mode, + OptionalAttr:$tf32_mode, + UnitAttr:$n_dir + ); + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MadMxOp : PTO_Op<"mad_mx", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$m, + I64:$n, + I64:$k, + OptionalAttr:$unit_flag_mode, + UnitAttr:$disable_gemv, + OptionalAttr:$sat_mode, + UnitAttr:$n_dir + ); + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MadMxAccOp : PTO_Op<"mad_mx_acc", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$m, + I64:$n, + I64:$k, + OptionalAttr:$unit_flag_mode, + UnitAttr:$disable_gemv, + OptionalAttr:$sat_mode, + UnitAttr:$n_dir + ); + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MadMxBiasOp : PTO_Op<"mad_mx_bias", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + PTO_BufferType:$bias, + I64:$m, + I64:$n, + I64:$k, + OptionalAttr:$unit_flag_mode, + UnitAttr:$disable_gemv, + OptionalAttr:$sat_mode, + UnitAttr:$n_dir + ); + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MadRawOp : PTO_Op<"mad_raw", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$xt + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $dst `,` $xt attr-dict `:` type($lhs) `,` type($rhs) `,` type($dst) `,` type($xt) + }]; +} + +def PTO_MadBiasRawOp : PTO_Op<"mad_bias_raw", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + PTO_BufferType:$bias, + I64:$xt + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $dst `,` $bias `,` $xt attr-dict `:` type($lhs) `,` type($rhs) `,` type($dst) `,` type($bias) `,` type($xt) + }]; +} + +def PTO_MadMxRawOp : PTO_Op<"mad_mx_raw", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + I64:$xt + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $dst `,` $xt attr-dict `:` type($lhs) `,` type($rhs) `,` type($dst) `,` type($xt) + }]; +} + +def PTO_MadMxBiasRawOp : PTO_Op<"mad_mx_bias_raw", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$lhs, + PTO_BufferType:$rhs, + PTO_BufferType:$dst, + PTO_BufferType:$bias, + I64:$xt + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $dst `,` $bias `,` $xt attr-dict `:` type($lhs) `,` type($rhs) `,` type($dst) `,` type($bias) `,` type($xt) + }]; +} + +def PTO_VcmaxOp : PTO_UnaryVecOp<"vcmax">; +def PTO_VcminOp : PTO_UnaryVecOp<"vcmin">; +def PTO_VcgaddOp : PTO_UnaryVecOp<"vcgadd">; +def PTO_VcgmaxOp : PTO_UnaryVecOp<"vcgmax">; +def PTO_VcgminOp : PTO_UnaryVecOp<"vcgmin">; +def PTO_VcpaddOp : PTO_UnaryVecOp<"vcpadd">; + +class PTO_BinaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VaddOp : PTO_BinaryVecOp<"vadd">; +def PTO_VsubOp : PTO_BinaryVecOp<"vsub">; +def PTO_VmulOp : PTO_BinaryVecOp<"vmul">; +def PTO_VdivOp : PTO_BinaryVecOp<"vdiv">; +def PTO_VmaxOp : PTO_BinaryVecOp<"vmax">; +def PTO_VminOp : PTO_BinaryVecOp<"vmin">; +def PTO_VandOp : PTO_BinaryVecOp<"vand">; +def PTO_VorOp : PTO_BinaryVecOp<"vor">; +def PTO_VxorOp : PTO_BinaryVecOp<"vxor">; + +def PTO_VaddcOp : PTO_Op<"vaddc", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VsubcOp : PTO_Op<"vsubc", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VaddcsOp : PTO_Op<"vaddcs", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$carry_in, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $carry_in `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($carry_in) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VsubcsOp : PTO_Op<"vsubcs", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$carry_in, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs + PTO_VectorType:$result, + PTO_MaskTypeConstraint:$carry + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $carry_in `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($carry_in) `,` type($mask) `->` type($result) `,` type($carry) + }]; +} + +def PTO_VshlOp : PTO_BinaryVecOp<"vshl">; +def PTO_VshrOp : PTO_BinaryVecOp<"vshr">; + +def PTO_VselOp : PTO_Op<"vsel", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VcmpOp : PTO_Op<"vcmp", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + PTO_MaskTypeConstraint:$mask, + StrAttr:$cmp_mode + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $mask `,` $cmp_mode attr-dict `:` type($src0) `,` type($src1) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VcmpsOp : PTO_Op<"vcmps", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask, + StrAttr:$cmp_mode + ); + let results = (outs PTO_MaskTypeConstraint:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $scalar `,` $mask `,` $cmp_mode attr-dict `:` type($src) `,` type($scalar) `,` type($mask) `->` type($result) + }]; +} + +class PTO_PredicatePairReorderOp + : PTO_Op { + let arguments = (ins + operandTy:$lhs, + operandTy:$rhs + ); + let results = (outs + operandTy:$low, + operandTy:$high + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_PdintlvB8Op : PTO_PredicatePairReorderOp<"pdintlv_b8", + PTO_B8MaskTypeConstraint>; +def PTO_PdintlvB16Op : PTO_PredicatePairReorderOp<"pdintlv_b16", + PTO_B16MaskTypeConstraint>; +def PTO_PdintlvB32Op : PTO_PredicatePairReorderOp<"pdintlv_b32", + PTO_B32MaskTypeConstraint>; + +def PTO_PintlvB8Op : PTO_PredicatePairReorderOp<"pintlv_b8", + PTO_B8MaskTypeConstraint>; +def PTO_PintlvB16Op : PTO_PredicatePairReorderOp<"pintlv_b16", + PTO_B16MaskTypeConstraint>; +def PTO_PintlvB32Op : PTO_PredicatePairReorderOp<"pintlv_b32", + PTO_B32MaskTypeConstraint>; + +class PTO_VecScalarOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $scalar `,` $mask attr-dict `:` type($input) `,` type($scalar) `,` type($mask) `->` type($result) + }]; +} + +class PTO_VecScalarMaskedOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$input, + AnyType:$scalar, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $scalar `,` $mask attr-dict `:` type($input) `,` type($scalar) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VtrcOp : PTO_Op<"vtrc", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask, + StrAttr:$round_mode + ); + let results = (outs PTO_VectorType:$result); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def PTO_VcvtOp : PTO_Op<"vcvt", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_MaskTypeConstraint:$mask, + OptionalAttr:$rnd, + OptionalAttr:$sat, + OptionalAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_VbitcastOp : PTO_Op<"vbitcast", [Pure]> { + let arguments = (ins + PTO_VectorType:$input + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + +def PTO_VciOp : PTO_Op<"vci", [Pure]> { + let arguments = (ins + AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$index, + OptionalAttr:$order + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $index attr-dict `:` type($index) `->` type($result) + }]; +} + +def PTO_VbitsortOp : PTO_Op<"vbitsort", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$destination, + PTO_BufferType:$source, + PTO_BufferType:$indices, + Index:$repeat_times + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $destination `,` $source `,` $indices `,` $repeat_times attr-dict `:` type($destination) `,` + type($source) `,` type($indices) `,` type($repeat_times) + }]; +} + +def PTO_Vmrgsort4Op : PTO_Op<"vmrgsort4"> { + let arguments = (ins + PTO_BufferType:$destination, + PTO_BufferType:$source0, + PTO_BufferType:$source1, + PTO_BufferType:$source2, + PTO_BufferType:$source3, + I64:$count, + I64:$config + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $destination `,` $source0 `,` $source1 `,` $source2 `,` $source3 `,` $count `,` $config + attr-dict `:` type($destination) `,` type($source0) `,` type($source1) `,` type($source2) `,` + type($source3) `,` type($count) `,` type($config) + }]; +} + +def PTO_Vgather2Op : PTO_Op<"vgather2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VgatherbOp : PTO_Op<"vgatherb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) + }]; +} + +// NOTE: Unvalidated new gather/select/interleave-family abstractions. Added to +// cover CCE builtin families not yet exercised through end-to-end PTO seams. +def PTO_Vgather2BcOp : PTO_Op<"vgather2_bc", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferType:$source, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $offsets `,` $mask attr-dict `:` type($source) `,` type($offsets) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VmulsOp : PTO_VecScalarMaskedOp<"vmuls">; +def PTO_VaddsOp : PTO_VecScalarMaskedOp<"vadds">; +def PTO_VmaxsOp : PTO_VecScalarMaskedOp<"vmaxs">; +def PTO_VminsOp : PTO_VecScalarMaskedOp<"vmins">; +def PTO_VlreluOp : PTO_VecScalarMaskedOp<"vlrelu">; +def PTO_VshlsOp : PTO_VecScalarMaskedOp<"vshls">; +def PTO_VshrsOp : PTO_VecScalarMaskedOp<"vshrs">; + +def PTO_VstsOp : PTO_Op<"vsts", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + OptionalAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask) + }]; +} + +def PTO_VstsPostOp : PTO_Op<"vsts_post", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + OptionalAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs PTO_BufferLikeType:$updated_destination); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask) `->` type($updated_destination) + }]; +} + +def PTO_VscatterOp : PTO_Op<"vscatter", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferType:$destination, + PTO_VectorType:$offsets, + PTO_MaskTypeConstraint:$mask + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $offsets `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($offsets) `,` type($mask) + }]; +} + +def PTO_PstsOp : PTO_Op<"psts", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_MaskTypeConstraint:$value, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `[` $offset `]` `,` $dist attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_CopyUbufToGmOp : PTO_Op<"copy_ubuf_to_gm", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$sid, + I64:$n_burst, + I64:$len_burst, + I64:$reserved, + I64:$burst_dst_stride, + I64:$burst_src_stride + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $sid `,` $n_burst `,` $len_burst `,` + $reserved `,` $burst_dst_stride `,` $burst_src_stride + attr-dict `:` type($source) `,` type($destination) `,` type($sid) `,` type($n_burst) `,` + type($len_burst) `,` type($reserved) `,` + type($burst_dst_stride) `,` type($burst_src_stride) + }]; +} + +def PTO_MteUbGmOp : PTO_Op<"mte_ub_gm", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_stride, + I64:$nburst_dst_stride, + Variadic:$loop_counts, + Variadic:$loop_src_strides, + Variadic:$loop_dst_strides + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::llvm::ArrayRef<::mlir::pto::DmaLoopConfig>":$loops + )>, + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop1, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop2 + )> + ]; +} + +// Cube bridge wrappers: fuse common register config sequences with the core +// cube transfer/load/store ops, similar to dma_load/dma_store wrappers. +def PTO_MteGmL1Op : PTO_Op<"mte_gm_l1", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_stride, + I64:$nburst_dst_stride, + Variadic:$loop_counts, + Variadic:$loop_src_strides, + Variadic:$loop_dst_strides + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::llvm::ArrayRef<::mlir::pto::DmaLoopConfig>":$loops + )>, + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop1, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop2 + )> + ]; +} + +def PTO_MteL1UbOp : PTO_Op<"mte_l1_ub", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_stride, + I64:$nburst_dst_stride, + Variadic:$loop_counts, + Variadic:$loop_src_strides, + Variadic:$loop_dst_strides + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::llvm::ArrayRef<::mlir::pto::DmaLoopConfig>":$loops + )>, + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop1, + "::std::optional<::mlir::pto::DmaLoopConfig>":$loop2 + )> + ]; +} + +def PTO_MteL1BtOp : PTO_Op<"mte_l1_bt", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_gap, + I64:$nburst_dst_gap + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst + )> + ]; +} + +def PTO_MteL1FbOp : PTO_Op<"mte_l1_fb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$len_burst, + I64:$n_burst, + I64:$nburst_src_gap, + I64:$nburst_dst_gap + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::Value":$lenBurst, + "::mlir::pto::DmaLoopConfig":$nburst + )> + ]; +} + +def PTO_MteGmL1FracOp : PTO_Op<"mte_gm_l1_frac", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$n_value, + I64:$d_value, + I64:$src_inner_stride, + I64:$group_count, + I64:$dst_loop2_stride, + I64:$dst_loop3_stride, + I64:$dst_loop4_stride, + I64:$l2_cache_ctrl, + I1:$smallc0_en, + Optional:$src_outer_stride, + PTO_CubeLoadFracModeAttr:$mode + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins + "::mlir::Value":$source, + "::mlir::Value":$destination, + "::mlir::pto::CubeLoadFracMode":$mode, + "::mlir::pto::CubeLoadFracShapeConfig":$shape, + "::mlir::pto::CubeLoadFracSrcLayoutConfig":$srcLayout, + "::mlir::pto::CubeLoadFracDstGroupConfig":$dstGroup, + "::mlir::pto::CubeLoadFracCtrlConfig":$ctrl + )> + ]; +} + +def PTO_MteL1L0aOp : PTO_Op<"mte_l1_l0a", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$m, + I64:$k, + DefaultValuedAttr:$transpose + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $m `,` $k attr-dict `:` type($source) `,` + type($destination) `,` type($m) `,` type($k) + }]; +} + +def PTO_MteL1L0bOp : PTO_Op<"mte_l1_l0b", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$k, + I64:$n, + DefaultValuedAttr:$transpose + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $k `,` $n attr-dict `:` type($source) `,` + type($destination) `,` type($k) `,` type($n) + }]; +} + +def PTO_MteL1L0aMxOp : PTO_Op<"mte_l1_l0a_mx", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$m, + I64:$k + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $m `,` $k attr-dict `:` type($source) `,` + type($destination) `,` type($m) `,` type($k) + }]; +} + +def PTO_MteL1L0bMxOp : PTO_Op<"mte_l1_l0b_mx", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$k, + I64:$n + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $destination `,` $k `,` $n attr-dict `:` type($source) `,` + type($destination) `,` type($k) `,` type($n) + }]; +} + +def PTO_MteL0cL1Op : PTO_Op<"mte_l0c_l1", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$m, + I64:$n, + I64:$src_stride, + I64:$dst_stride, + Optional:$pre_quant, + Optional:$pre_relu, + Optional:$clip_value, + Optional:$split, + Optional:$loop0_src_stride, + Optional:$loop3_count, + Optional:$loop3_src_stride, + Optional:$loop3_dst_stride, + OptionalAttr:$mode, + OptionalAttr:$unit_flag, + OptionalAttr:$pre_quant_mode, + OptionalAttr:$pre_relu_mode, + OptionalAttr:$atomic_type, + OptionalAttr:$atomic_op, + OptionalAttr:$sat_mode + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MteL0cGmOp : PTO_Op<"mte_l0c_gm", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$m, + I64:$n, + I64:$src_stride, + I64:$dst_stride, + Optional:$pre_quant, + Optional:$pre_relu, + Optional:$clip_value, + I64:$sid, + I64:$l2_cache_ctrl, + Optional:$split, + Optional:$loop0_src_stride, + Optional:$loop3_count, + Optional:$loop3_src_stride, + Optional:$loop3_dst_stride, + OptionalAttr:$mode, + OptionalAttr:$unit_flag, + OptionalAttr:$pre_quant_mode, + OptionalAttr:$pre_relu_mode, + OptionalAttr:$atomic_type, + OptionalAttr:$atomic_op, + OptionalAttr:$sat_mode + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PTO_MteL0cUbOp : PTO_Op<"mte_l0c_ub", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + PTO_BufferLikeType:$destination, + I64:$m, + I64:$n, + I64:$src_stride, + I64:$dst_stride, + Optional:$pre_quant, + Optional:$pre_relu, + Optional:$clip_value, + Optional:$sub_blockid, + Optional:$split, + Optional:$loop0_src_stride, + Optional:$loop3_count, + Optional:$loop3_src_stride, + Optional:$loop3_dst_stride, + PTO_AccStoreUbDstModeAttr:$dst_mode, + OptionalAttr:$mode, + OptionalAttr:$unit_flag, + OptionalAttr:$pre_quant_mode, + OptionalAttr:$pre_relu_mode, + OptionalAttr:$sat_mode + ); + + let results = (outs); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +// NOTE: Unvalidated new x2 / pair / align-store-family abstractions. Added to +// reflect CCE builtin families but not yet end-to-end validated. +def PTO_VselrOp : PTO_Op<"vselr", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1 + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result) + }]; +} + +def PTO_VsqzOp : PTO_UnaryVecOp<"vsqz">; + +def PTO_VusqzOp : PTO_Op<"vusqz", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $mask attr-dict `:` type($src) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VpackOp : PTO_Op<"vpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_VsunpackOp : PTO_Op<"vsunpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + Index:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_VzunpackOp : PTO_Op<"vzunpack", [Pure]> { + let arguments = (ins + PTO_VectorType:$src, + Index:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src `,` $part attr-dict `:` type($src) `->` type($result) + }]; +} + +def PTO_Vselrv2Op : PTO_Op<"vselrv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1 + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 attr-dict `:` type($src0) `,` type($src1) `->` type($result) + }]; +} + +def PTO_VintlvOp : PTO_Op<"vintlv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_VdintlvOp : PTO_Op<"vdintlv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($low) `,` type($high) + }]; +} + +def PTO_Vintlvv2Op : PTO_Op<"vintlvv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $part attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_Vdintlvv2Op : PTO_Op<"vdintlvv2", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $part attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VmullOp : PTO_Op<"vmull", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$low, PTO_VectorType:$high); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($low) `,` type($high) + }]; +} + +def PTO_VmulaOp : PTO_Op<"vmula", [Pure]> { + let arguments = (ins + PTO_VectorType:$acc, + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $acc `,` $lhs `,` $rhs `,` $mask attr-dict `:` type($acc) `,` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) + }]; +} + +class PTO_UnmaskedBinaryVecOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +class PTO_BinaryVecMaskedOp : PTO_Op { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs `,` $mask attr-dict `:` type($lhs) `,` type($rhs) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VpreluOp : PTO_BinaryVecMaskedOp<"vprelu">; +def PTO_VexpdifOp : PTO_Op<"vexpdif", [Pure]> { + let arguments = (ins + PTO_VectorType:$input, + PTO_VectorType:$max, + PTO_MaskTypeConstraint:$mask, + StrAttr:$part + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $input `,` $max `,` $mask `,` $part attr-dict `:` type($input) `,` type($max) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VaxpyOp : PTO_Op<"vaxpy", [Pure]> { + let arguments = (ins + PTO_VectorType:$src0, + PTO_VectorType:$src1, + AnyType:$alpha, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $src0 `,` $src1 `,` $alpha `,` $mask attr-dict `:` type($src0) `,` type($src1) `,` type($alpha) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VaddreluconvOp : PTO_Op<"vaddreluconv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_VmulconvOp : PTO_Op<"vmulconv", [Pure]> { + let arguments = (ins + PTO_VectorType:$lhs, + PTO_VectorType:$rhs + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def PTO_Vstsx2Op : PTO_Op<"vstsx2", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$low, + PTO_VectorType:$high, + PTO_BufferLikeType:$destination, + Index:$offset, + StrAttr:$dist, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $low `,` $high `,` $destination `[` $offset `]` `,` $dist `,` $mask attr-dict `:` type($low) `,` type($high) `,` type($destination) `,` type($offset) `,` type($mask) + }]; +} + +def PTO_VsldbOp : PTO_Op<"vsldb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_BufferLikeType:$source, + I16:$block_stride, + I16:$repeat_stride, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs PTO_VectorType:$result); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $source `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($source) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) `->` type($result) + }]; +} + +def PTO_VsstbOp : PTO_Op<"vsstb", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_VectorType:$value, + PTO_BufferLikeType:$destination, + I16:$block_stride, + I16:$repeat_stride, + PTO_MaskTypeConstraint:$mask + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) + }]; +} + +def PTO_VstasOp : PTO_Op<"vstas", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$value, + PTO_BufferLikeType:$destination, + I32:$offset + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination `,` $offset attr-dict `:` type($value) `,` type($destination) `,` type($offset) + }]; +} + +def PTO_VstarOp : PTO_Op<"vstar", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$value, + PTO_BufferLikeType:$destination + ); + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $value `,` $destination attr-dict `:` type($value) `,` type($destination) + }]; +} + +// NOTE: Unvalidated stateful store-family abstractions. These preserve +// align/base/offset update results explicitly in SSA form instead of relying on +// implicit CCE reference updates. +// Keep `base/base_out` pointer-only (`PTO_BufferType`): memref semantics for +// stateful post-update addresses are intentionally out of scope in this change. +def PTO_PstuOp : PTO_Op<"pstu", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + PTO_MaskTypeConstraint:$value, + PTO_BufferType:$base + ); + let results = (outs PTO_AlignTypeConstraint:$align_out, PTO_BufferType:$base_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $value `,` $base attr-dict `:` type($align_in) `,` type($value) `,` type($base) `->` type($align_out) `,` type($base_out) + }]; +} + +def PTO_VstusOp : PTO_Op<"vstus", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + I32:$offset, + PTO_VectorType:$value, + PTO_BufferType:$base + ); + let results = (outs PTO_AlignTypeConstraint:$align_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $offset `,` $value `,` $base attr-dict `:` type($align_in) `,` type($offset) `,` type($value) `,` type($base) `->` type($align_out) + }]; +} + +def PTO_VsturOp : PTO_Op<"vstur", [ + DeclareOpInterfaceMethods + ]> { + let arguments = (ins + PTO_AlignTypeConstraint:$align_in, + PTO_VectorType:$value, + PTO_BufferType:$base, + StrAttr:$mode + ); + let results = (outs PTO_AlignTypeConstraint:$align_out); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $align_in `,` $value `,` $base `,` $mode attr-dict `:` type($align_in) `,` type($value) `,` type($base) `->` type($align_out) + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VPTOOPS diff --git a/include/PTO/IR/VPTOTypeDefs.td b/include/PTO/IR/VPTOTypeDefs.td new file mode 100644 index 000000000..ed62e7655 --- /dev/null +++ b/include/PTO/IR/VPTOTypeDefs.td @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VPTOTypeDefs.td ---------------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS +#define MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" + +def VRegType : TypeDef { + let mnemonic = "vreg"; + let summary = "A 256-byte PTO low-level vector"; + + let parameters = (ins + "int64_t":$elementCount, + "Type":$elementType + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def MaskType : TypeDef { + let mnemonic = "mask"; + let summary = "A PTO low-level predicate/mask register"; + + let parameters = (ins + StringRefParameter<"mask granularity view">:$granularity + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static bool isSupportedGranularity(::llvm::StringRef granularity); + + bool isB8() const { return getGranularity() == "b8"; } + bool isB16() const { return getGranularity() == "b16"; } + bool isB32() const { return getGranularity() == "b32"; } + }]; +} + +def AlignType : TypeDef { + let mnemonic = "align"; + let summary = "A PTO low-level vector_align carrier"; +} + +#endif // MLIR_DIALECT_PTO_IR_VPTOTYPEDEFS diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 068071193..9f8bec808 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -39,6 +39,8 @@ std::unique_ptr createPTOLowerFrontendPipeOpsPass(); std::unique_ptr createPTOInferValidatePipeInitPass(); std::unique_ptr createPTOResolveReservedBuffersPass(); std::unique_ptr createPTOWrapFunctionsInSectionsPass(); +std::unique_ptr createVPTOSplitCVModulePass(); +std::unique_ptr createVPTONormalizeContainerPass(); std::unique_ptr createPTOVerifyTFreePass(); // Creates a pass for ... @@ -71,9 +73,26 @@ std::unique_ptr createPTOValidateIntToPtrUsesPass(); std::unique_ptr createPTOMaterializeTileHandlesPass(); std::unique_ptr createInferPTOLayoutPass(); std::unique_ptr createPTOA5NormalizeTMovPass(); - LogicalResult validateIntToPtrUses(func::FuncOp func); +std::unique_ptr createPTOInferVPTOVecScopePass(); +std::unique_ptr createVPTOExpandWrapperOpsPass(); +std::unique_ptr createPTOVPTOPtrBoundaryPass(); +std::unique_ptr createVPTOPtrNormalizePass(); +std::unique_ptr createVPTOPtrCastCleanupPass(); +LogicalResult validateVPTOAuthoringIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +LogicalResult validateVPTOEmissionIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +std::unique_ptr createPTOValidateVPTOIRPass(); +std::unique_ptr createPTOValidateVPTOEmissionIRPass(); +std::unique_ptr createExpandTileOpPass(); +std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); +std::unique_ptr createFoldTileBufIntrinsicsPass(); +std::unique_ptr +createPTOInlineLibCallPass(const PTOInlineLibCallOptions &options = {}); +void registerPTOViewToMemrefPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 19d4bd7fd..cc6bcf3b1 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -15,6 +15,9 @@ // //===----------------------------------------------------------------------===// +// The VPTO backend is emitted from tools/ptoas rather than a TableGen pass; +// these registrations continue to describe the shared pre-backend pipeline. + #ifndef MLIR_DIALECT_PTO_PASSES #define MLIR_DIALECT_PTO_PASSES @@ -185,6 +188,43 @@ def PTOWrapFunctionsInSections : Pass<"pto-wrap-functions-in-sections", "func::F ]; } +def VPTOSplitCVModule : Pass<"vpto-split-cv-module", "ModuleOp"> { + let summary = "Split a VPTO module with cube/vector sections into kernel modules"; + let description = [{ + Rewrites a single-module VPTO input where one `pto.kernel` function + contains `pto.section.cube` / `pto.section.vector` regions into the + normalized fatobj input form with one child module per kernel kind. + + The pass clones the original module for each target kind, removes the + opposite section kind, and inlines the target section body. It does not + analyze cross-section SSA dependencies; normal MLIR verification catches + illegal uses after the non-target section is removed. + }]; + + let constructor = "mlir::pto::createVPTOSplitCVModulePass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::func::FuncDialect" + ]; +} + +def VPTONormalizeContainer : Pass<"vpto-normalize-container", "ModuleOp"> { + let summary = "Normalize and verify VPTO kernel container shape"; + let description = [{ + Normalizes a VPTO kernel module carrying `pto.kernel_kind` into the canonical + outer container form. Existing containers are left in place. The pass + verifies that the resulting top-level module contains only child modules and + that every child module carries `pto.kernel_kind`. + }]; + + let constructor = "mlir::pto::createVPTONormalizeContainerPass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect" + ]; +} + def PTOLowerFrontendPipeOps : Pass<"pto-lower-frontend-pipe-ops", "func::FuncOp"> { let summary = "Lower frontend TPUSH/TPOP pipe ops to internal pipe ops"; let description = [{ @@ -265,6 +305,97 @@ def PTOResolveReservedBuffers : Pass<"pto-resolve-reserved-buffers", "ModuleOp"> ]; } +def ExpandTileOp : Pass<"pto-expand-tile-op", "ModuleOp"> { + let summary = "Expand tile ops into calls to TileLang DSL template functions"; + let description = [{ + Expands tile-level operations (pto.tadd, pto.tsub, etc.) by invoking the + TileLang Python DSL to instantiate template libraries. The generated + template functions use tile_buf parameters and contain vector-level + implementations (pto.vecscope, pto.vlds, pto.vadd, pto.vsts, etc.). + + Each tile op is replaced by a func.call to the generated template function, + with tile_buf operands passed directly (no type bridging). + + After this pass, the Inline pass inlines template bodies, and + FoldTileBufIntrinsics resolves tile_buf_addr / tile_valid_rows / + tile_valid_cols. + }]; + let constructor = "mlir::pto::createExpandTileOpPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect", + "mlir::func::FuncDialect", + "mlir::scf::SCFDialect", + "mlir::vector::VectorDialect" + ]; + let options = [ + Option<"tilelangPath", "tilelang-path", "std::string", + /*default=*/"\"\"", + "Path to directory of .py tilelang DSL template files">, + Option<"tilelangPkgPath", "tilelang-pkg-path", "std::string", + /*default=*/"\"\"", + "PYTHONPATH for tilelang_dsl package (added to env)">, + Option<"pythonExe", "python-exe", "std::string", + /*default=*/"\"python3\"", + "Python executable for tilelang DSL invocation"> + ]; +} + +def FoldTileBufIntrinsics : Pass<"pto-fold-tile-buf-intrinsics", "mlir::func::FuncOp"> { + let summary = "Fold structured-view intrinsics after template inlining"; + let description = [{ + After TileLang DSL template functions are inlined, the IR contains + structured-view intrinsics whose operands are now bound to concrete values. + + This pass resolves them: + - pto.tile_buf_addr → replaced by the memref recovered from the active + materialized tile handle, or by pto.castptr when the requested result + type is !pto.ptr + - pto.tile_valid_rows → folded to arith.constant if v_row is static, + or replaced with the dynamic valid_row operand carried by the + materialized tile handle + - pto.tile_valid_cols → same as above for v_col + + tensor_view family: + - pto.tensor_view_addr → traces through unrealized_conversion_cast → + subview → reinterpret_cast, then folds to the base memref or to + pto.castptr/pto.addptr on the base memref + - pto.get_tensor_view_dim → folded to arith.constant for static subview + sizes, or to the subview size SSA operand for dynamic dims + - pto.get_tensor_view_stride → folded to the reinterpret_cast stride + operand, multiplied by the subview stride when needed + + Dead unrealized_conversion_cast, memref.subview, and + memref.reinterpret_cast ops exposed by folding are cleaned up after the + rewrite. + }]; + let constructor = "mlir::pto::createFoldTileBufIntrinsicsPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect" + ]; +} + +def PTOInlineLibCall : Pass<"pto-inline-libcall", "ModuleOp"> { + let summary = "Materialize OP-Lib instance bodies and inline OP-Lib calls"; + let description = [{ + Resolves OP-Lib instance declarations generated by OP-Lib lowering, + materializes instance bodies, and inlines OP-Lib calls into caller/fused + helper functions. Function signatures stay in !pto.tile_buf form. + }]; + let constructor = "mlir::pto::createPTOInlineLibCallPass()"; + let dependentDialects = ["mlir::func::FuncDialect", "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect", + "mlir::scf::SCFDialect"]; + let options = [Option< + "debug", "debug", "bool", + /*default=*/"false", + "Enable verbose debug logging for OP-Lib instantiation/inlining">]; +} + def PTOVerifyTFree : Pass<"pto-verify-tfree", "func::FuncOp"> { let summary = "Verify explicit matching pto.tfree placement for pto.tpop"; let description = [{ @@ -332,4 +463,130 @@ def PTOMaterializeTileHandles : Pass<"pto-materialize-tile-handles", "ModuleOp"> ]; } +def PTOInferVPTOVecScope + : Pass<"pto-infer-vpto-vecscope", "func::FuncOp"> { + let summary = + "Infer missing pto.vecscope regions for VPTO vector operation clusters"; + let description = [{ + Runs near the VPTO emission boundary after inlining, canonicalization, + CSE, pointer normalization, and wrapper-op expansion have exposed the final + SSA shape. The pass greedily clusters contiguous VPTO vector operations + into `pto.vecscope` regions while preserving explicit vector-scope carriers + and treating DMA/copy/sync, unresolved calls, terminators, and forbidden + operations as boundaries. + + The inferred `pto.vecscope` form remains resultless. Values whose type is + `!pto.vreg`, `!pto.mask`, or `!pto.align` must not escape the inferred + scope. + }]; + let constructor = "mlir::pto::createPTOInferVPTOVecScopePass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOValidateVPTOIR : Pass<"pto-validate-vpto-ir", "ModuleOp"> { + let summary = + "Validate authoring-stage VPTO legality before ptr-boundary canonicalization"; + let description = [{ + Runs the authoring-stage VPTO legality verifier on post-mainline VPTO IR. + This stage keeps the memref-first authoring surface legal, while checking + the shared structural contracts that must hold before emission-boundary + canonicalization. + }]; + let constructor = "mlir::pto::createPTOValidateVPTOIRPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOValidateVPTOEmissionIR + : Pass<"pto-validate-vpto-emission-ir", "ModuleOp"> { + let summary = + "Validate emission-stage VPTO legality after ptr-boundary canonicalization"; + let description = [{ + Runs the emission-stage VPTO legality verifier on ptr-form VPTO IR after + `PTOVPTOPtrBoundary`. This stage re-checks the shared authoring contracts + and confirms the final emission surface no longer carries memref boundary + state or residual wrapper scaffold. + }]; + let constructor = "mlir::pto::createPTOValidateVPTOEmissionIRPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VPTOExpandWrapperOps + : Pass<"vpto-expand-wrapper-ops", "func::FuncOp"> { + let summary = + "Expand VPTO wrapper ops to emission-ready low-level VPTO IR"; + let description = [{ + Expand higher-level VPTO wrapper operations, including bridge, DMA, cube, + accumulator-store, and MAD semantic forms, to the low-level pointer/raw VPTO + operations consumed by backend emission. + }]; + let constructor = "mlir::pto::createVPTOExpandWrapperOpsPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::LLVM::LLVMDialect", + "mlir::pto::PTODialect"]; +} + +def PTOVPTOPtrBoundary + : Pass<"pto-vpto-ptr-boundary", "ModuleOp"> { + let summary = + "Canonicalize the final VPTO emission boundary from memref-first IR to ptr ABI"; + let description = [{ + Runs the final emission-boundary ptr canonicalization after the backend + mainline has finished its memref-first optimization pipeline. This pass + rewrites eligible memref function arguments to same-space `!pto.ptr`, + rejects memref function results, canonicalizes supported body-level VPTO + buffer-like ops to ptr-form, and drops dead boundary/view scaffold such as + trivial `pto.castptr`, `pto.bind_tile`, `memref.subview`, + `memref.reinterpret_cast`, and `memref.memory_space_cast` once they become + unused. + }]; + let constructor = "mlir::pto::createPTOVPTOPtrBoundaryPass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect"]; +} + +def VPTOPtrNormalize + : Pass<"vpto-ptr-normalize", "ModuleOp"> { + let summary = + "Normalize VPTO ptr-like values and users into a uniform !pto.ptr form"; + let description = [{ + Uses MLIR's conversion framework to normalize VPTO ptr-related forms before + the existing VPTO ptr-boundary canonicalization runs. This pass rewrites + supported tile-buffer and memref view producers such as `pto.tile_buf_addr` + and `memref.subview`, and updates VPTO memory ops to consume the normalized + ptr-form consistently. + }]; + let constructor = "mlir::pto::createVPTOPtrNormalizePass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::arith::ArithDialect"]; +} + +def VPTOPtrCastCleanup + : Pass<"vpto-ptr-cast-cleanup", "ModuleOp"> { + let summary = "Collapse transient ptr-memref-ptr bridge casts after VPTO ptr normalization"; + let description = [{ + Eliminates bridge chains such as + `!pto.ptr -> builtin.unrealized_conversion_cast -> memref.cast -> + builtin.unrealized_conversion_cast -> !pto.ptr` + when the outer ptr types already match. + }]; + let constructor = "mlir::pto::createVPTOPtrCastCleanupPass()"; + let dependentDialects = ["mlir::pto::PTODialect", + "mlir::memref::MemRefDialect"]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/include/PTO/Transforms/VPTOLLVMEmitter.h b/include/PTO/Transforms/VPTOLLVMEmitter.h new file mode 100644 index 000000000..b831ac33a --- /dev/null +++ b/include/PTO/Transforms/VPTOLLVMEmitter.h @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H + +#include +#include + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class ModuleOp; +} + +namespace llvm { +class LLVMContext; +class Module; +class raw_ostream; +} + +namespace mlir::pto { + +struct VPTOEmissionOptions { + bool dumpVPTOIR = false; + std::string targetTriple; + std::string march; + std::string aicoreArch; + std::string defaultTargetCPU; + std::string defaultTargetFeatures; +}; + +struct EmittedLLVMModule { + std::unique_ptr context; + std::unique_ptr module; +}; + +LogicalResult lowerVPTOModuleToLLVMModules( + ModuleOp module, const VPTOEmissionOptions &options, + EmittedLLVMModule &cubeModule, EmittedLLVMModule &vectorModule, + llvm::raw_ostream &diagOS); + +} // namespace mlir::pto + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTER_H diff --git a/include/PTO/Transforms/VPTOLLVMEmitterHelper.h b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h new file mode 100644 index 000000000..0db138273 --- /dev/null +++ b/include/PTO/Transforms/VPTOLLVMEmitterHelper.h @@ -0,0 +1,14 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H + +#include "PTO/Transforms/VPTOLLVMEmitter.h" + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLLVMEMITTERHELPER_H diff --git a/include/PTO/Transforms/VPTOLowering.h b/include/PTO/Transforms/VPTOLowering.h new file mode 100644 index 000000000..2f7a8332e --- /dev/null +++ b/include/PTO/Transforms/VPTOLowering.h @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VPTOLowering.h - VPTO buffer materialization contracts ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ +#define MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ + +#include "PTO/IR/PTO.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { + +Value materializeBufferPointer(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc); +LogicalResult convertVPTOEmissionBoundaryToPtr( + ModuleOp module, llvm::raw_ostream *diagOS = nullptr); + +} // namespace pto +} // namespace mlir + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_VPTOLOWERING_H_ diff --git a/include/pto-c/Dialect/PTO.h b/include/pto-c/Dialect/PTO.h index fac8f2f57..3e1fe4169 100644 --- a/include/pto-c/Dialect/PTO.h +++ b/include/pto-c/Dialect/PTO.h @@ -26,7 +26,10 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PTO, pto); // ---- !pto.ptr ---- bool mlirPTOTypeIsAPtrType(MlirType type); MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType); +MlirType mlirPTOPtrTypeGetWithMemorySpace(MlirContext ctx, MlirType elementType, + MlirAttribute memorySpace); MlirType mlirPTOPtrTypeGetElementType(MlirType type); +MlirAttribute mlirPTOPtrTypeGetMemorySpace(MlirType type); // ---- !pto.async_session / !pto.async_event ---- bool mlirPTOTypeIsAAsyncSessionType(MlirType type); @@ -123,6 +126,9 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOReduceOpAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTORoundModeAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsARoundModeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED MlirAttribute mlirPTOPrecisionModeAttrGet(MlirContext ctx, int32_t value); +MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAPrecisionModeAttr(MlirAttribute attr); +MLIR_CAPI_EXPORTED int32_t mlirPTOPrecisionModeAttrGetValue(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirPTOSaturationModeAttrGet(MlirContext ctx, int32_t value); MLIR_CAPI_EXPORTED bool mlirPTOAttrIsASaturationModeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED int32_t mlirPTOSaturationModeAttrGetValue(MlirAttribute attr); diff --git a/lib/Bindings/Python/CMakeLists.txt b/lib/Bindings/Python/CMakeLists.txt index e9e32ba98..b7628dd86 100644 --- a/lib/Bindings/Python/CMakeLists.txt +++ b/lib/Bindings/Python/CMakeLists.txt @@ -39,6 +39,7 @@ target_link_libraries(_pto PRIVATE MLIRSupport MLIRArithDialect MLIRMemRefDialect + MLIRSCFDialect MLIRDestinationStyleOpInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces @@ -47,6 +48,7 @@ target_link_libraries(_pto PRIVATE MLIRLoopLikeInterface MLIRViewLikeInterface MLIRFunctionInterfaces + MLIRLLVMDialect ) # 关键:放到 mlir/_mlir_libs 下(匹配 MLIR dialect python 的 import 习惯) @@ -60,6 +62,10 @@ if(APPLE) target_link_options(_pto PRIVATE "LINKER:-undefined,dynamic_lookup") endif() +if(NOT MLIR_PYTHON_PACKAGE_DIR) + message(FATAL_ERROR "MLIR_PYTHON_PACKAGE_DIR must be set when PTO_ENABLE_PYTHON_BINDING=ON") +endif() + install(TARGETS _pto LIBRARY DESTINATION "${MLIR_PYTHON_PACKAGE_DIR}/mlir/_mlir_libs" ) @@ -101,4 +107,3 @@ add_custom_command(TARGET _pto POST_BUILD "${CMAKE_BINARY_DIR}/python/mlir/_mlir_libs" VERBATIM ) - diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index e9b7e2535..a68d3fd14 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -126,6 +126,10 @@ static void bindPTOModule(pybind11::module &m) { .value("ODD", mlir::pto::RoundMode::ODD) .value("CAST_RINT", mlir::pto::RoundMode::CAST_RINT); + py::enum_(m, "PrecisionMode") + .value("DEFAULT", mlir::pto::PrecisionMode::DEFAULT) + .value("HIGH_PRECISION", mlir::pto::PrecisionMode::HIGH_PRECISION); + py::enum_(m, "SaturationMode") .value("ON", mlir::pto::SaturationMode::ON) .value("OFF", mlir::pto::SaturationMode::OFF); @@ -416,6 +420,33 @@ static void bindPTOModule(pybind11::module &m) { return mlirPTORoundModeAttrGetValue(self); }); + mlir_attribute_subclass( + m, "PrecisionModeAttr", + [](MlirAttribute a) { return mlirPTOAttrIsAPrecisionModeAttr(a); }) + .def_classmethod( + "get", + [](py::object cls, py::object value, MlirContext ctx) -> py::object { + int32_t v = 0; + if (py::isinstance(value)) { + v = value.cast(); + } else if (py::hasattr(value, "value")) { + v = value.attr("value").cast(); + } else { + throw std::runtime_error("PrecisionModeAttr.get expects int or PrecisionMode enum"); + } + + MlirAttribute a = mlirPTOPrecisionModeAttrGet(ctx, v); + if (mlirAttributeIsNull(a)) return py::none(); + return cls.attr("__call__")(a); + }, + py::arg("cls"), py::arg("value"), py::arg("context") = py::none()) + + .def_property_readonly( + "value", + [](MlirAttribute self) -> int32_t { + return mlirPTOPrecisionModeAttrGetValue(self); + }); + mlir_attribute_subclass( m, "SaturationModeAttr", [](MlirAttribute a) { return mlirPTOAttrIsASaturationModeAttr(a); }) @@ -644,20 +675,34 @@ static void bindPTOModule(pybind11::module &m) { [](MlirType type) -> bool { return mlirPTOTypeIsAPtrType(type); }) .def_classmethod( "get", - [](py::object cls, MlirType elementType, + [](py::object cls, MlirType elementType, py::object memorySpace, MlirContext context) -> py::object { MlirContext ctx = context; if (!ctx.ptr) ctx = mlirTypeGetContext(elementType); - MlirType t = mlirPTOPtrTypeGet(ctx, elementType); + MlirType t = {nullptr}; + if (memorySpace.is_none()) { + t = mlirPTOPtrTypeGet(ctx, elementType); + } else { + MlirAttribute memorySpaceAttr = + py::cast(memorySpace); + t = mlirPTOPtrTypeGetWithMemorySpace(ctx, elementType, + memorySpaceAttr); + } return cls.attr("__call__")(t); }, py::arg("cls"), py::arg("element_type"), + py::arg("memory_space") = py::none(), py::arg("context") = py::none()) .def_property_readonly( "element_type", [](MlirType self) -> MlirType { return mlirPTOPtrTypeGetElementType(self); + }) + .def_property_readonly( + "memory_space", + [](MlirType self) -> MlirAttribute { + return mlirPTOPtrTypeGetMemorySpace(self); }); mlir_type_subclass( diff --git a/lib/CAPI/Dialect/PTO.cpp b/lib/CAPI/Dialect/PTO.cpp index 2ca0fef71..412f76f4e 100644 --- a/lib/CAPI/Dialect/PTO.cpp +++ b/lib/CAPI/Dialect/PTO.cpp @@ -76,6 +76,14 @@ MlirType mlirPTOPtrTypeGet(MlirContext ctx, MlirType elementType) { return wrap(mlir::pto::PtrType::get(c, elem)); } +MlirType mlirPTOPtrTypeGetWithMemorySpace(MlirContext ctx, MlirType elementType, + MlirAttribute memorySpace) { + auto c = unwrap(ctx); + auto elem = unwrap(elementType); + auto space = mlir::cast(unwrap(memorySpace)); + return wrap(mlir::pto::PtrType::get(c, elem, space)); +} + MlirType mlirPTOPtrTypeGetElementType(MlirType type) { auto t = cast(unwrap(type));; return wrap(t.getElementType()); @@ -129,6 +137,11 @@ MlirType mlirPTOF4E2M1x2TypeGet(MlirContext ctx) { return wrap(mlir::pto::F4E2M1x2Type::get(unwrap(ctx))); } +MlirAttribute mlirPTOPtrTypeGetMemorySpace(MlirType type) { + auto t = cast(unwrap(type)); + return wrap(t.getMemorySpace()); +} + bool mlirPTOAttrIsAAddressSpaceAttr(MlirAttribute attr) { return mlir::isa(unwrap(attr)); } @@ -371,6 +384,21 @@ int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr) { return static_cast(a.getValue()); } +MlirAttribute mlirPTOPrecisionModeAttrGet(MlirContext ctx, int32_t value) { + auto *c = unwrap(ctx); + auto mode = static_cast(value); + return wrap(mlir::pto::PrecisionModeAttr::get(c, mode)); +} + +bool mlirPTOAttrIsAPrecisionModeAttr(MlirAttribute attr) { + return mlir::isa(unwrap(attr)); +} + +int32_t mlirPTOPrecisionModeAttrGetValue(MlirAttribute attr) { + auto a = mlir::cast(unwrap(attr)); + return static_cast(a.getValue()); +} + MlirAttribute mlirPTOSaturationModeAttrGet(MlirContext ctx, int32_t value) { auto *c = unwrap(ctx); auto mode = static_cast(value); diff --git a/lib/PTO/IR/CMakeLists.txt b/lib/PTO/IR/CMakeLists.txt index b055d8290..74b9e0bd6 100644 --- a/lib/PTO/IR/CMakeLists.txt +++ b/lib/PTO/IR/CMakeLists.txt @@ -14,6 +14,7 @@ # [关键] 库名重命名为 PTOIR,避免与 LLVM 里的 PTODialect/MLIRPTODialect 冲突 add_mlir_dialect_library(PTOIR PTO.cpp + VPTO.cpp PTOAttrs.cpp PTOSyncUtils.cpp PTOTypeDefs.cpp @@ -29,6 +30,7 @@ add_mlir_dialect_library(PTOIR MLIRIR MLIRFuncDialect MLIRMemRefDialect + MLIRSCFDialect MLIRControlFlowInterfaces MLIRInferTypeOpInterface MLIRSideEffectInterfaces diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 376b9c017..c009e6dee 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -184,6 +184,9 @@ static LogicalResult verifyArithmeticElemTypeForArch( Operation *op, Type elemTy, PTOArch targetArch, bool allowInt8OnA5, bool allowBf16OnA5, StringRef a2a3Error, StringRef a5Error); static bool isRowMajorTileBuf(Type ty); +static ParseResult parseLegacyOrAttrPipe(OpAsmParser &parser, PipeAttr &attr); +static ParseResult parseLegacyOrAttrEvent(OpAsmParser &parser, EventAttr &attr); +static ParseResult parseI32LiteralAttr(OpAsmParser &parser, IntegerAttr &attr); #define GET_ENUM_CLASSES #include "PTO/IR/PTOEnums.cpp.inc" @@ -389,6 +392,49 @@ static LogicalResult dispatchVerifierByArch(Operation *op, FnA2A3 &&verifyA2A3, } return failure(); } +static std::optional parsePtrAddressSpaceKeyword(StringRef keyword) { + return llvm::StringSwitch>(keyword) + .Case("gm", pto::AddressSpace::GM) + .Case("mat", pto::AddressSpace::MAT) + .Case("l1", pto::AddressSpace::MAT) + .Case("left", pto::AddressSpace::LEFT) + .Case("l0a", pto::AddressSpace::LEFT) + .Case("right", pto::AddressSpace::RIGHT) + .Case("l0b", pto::AddressSpace::RIGHT) + .Case("acc", pto::AddressSpace::ACC) + .Case("l0c", pto::AddressSpace::ACC) + .Case("vec", pto::AddressSpace::VEC) + .Case("ub", pto::AddressSpace::VEC) + .Case("bias", pto::AddressSpace::BIAS) + .Case("bt", pto::AddressSpace::BIAS) + .Case("scaling", pto::AddressSpace::SCALING) + .Case("fb", pto::AddressSpace::SCALING) + .Default(std::nullopt); +} + +static StringRef printPtrAddressSpaceKeyword(pto::AddressSpace space) { + switch (space) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return "gm"; + case pto::AddressSpace::MAT: + return "l1"; + case pto::AddressSpace::LEFT: + return "l0a"; + case pto::AddressSpace::RIGHT: + return "l0b"; + case pto::AddressSpace::ACC: + return "l0c"; + case pto::AddressSpace::VEC: + return "ub"; + case pto::AddressSpace::BIAS: + return "bt"; + case pto::AddressSpace::SCALING: + return "fb"; + default: + return {}; + } +} static ParseResult parseSyncEventOpCommon(OpAsmParser &parser, OperationState &result, @@ -498,17 +544,23 @@ static void printSyncEventOpCommon(OpAsmPrinter &p, Operation *op, mlir::Type elem; if (failed(parser.parseType(elem))) return mlir::Type(); + auto memorySpace = pto::AddressSpaceAttr::get(ctx, pto::AddressSpace::GM); if (succeeded(parser.parseOptionalComma())) { - // ptr no longer accepts an address space; consume the attr for recovery. - mlir::Attribute memorySpace; - (void)parser.parseAttribute(memorySpace); - parser.emitError(parser.getCurrentLocation(), - "!pto.ptr no longer accepts address space; use !pto.ptr"); - return mlir::Type(); + StringRef memorySpaceKeyword; + if (failed(parser.parseKeyword(&memorySpaceKeyword))) + return mlir::Type(); + auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); + if (!parsed) { + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr address space must be one of " + "`gm|ub|mat|l1|left|l0a|right|l0b|acc|l0c|vec|bias|bt|scaling|fb`"); + return mlir::Type(); + } + memorySpace = pto::AddressSpaceAttr::get(ctx, *parsed); } if (failed(parser.parseGreater())) return mlir::Type(); - return mlir::pto::PtrType::get(ctx, elem); + return mlir::pto::PtrType::get(ctx, elem, memorySpace); } if (head == "pto.tensor_view") { @@ -534,6 +586,41 @@ void TensorViewType::print(::mlir::AsmPrinter &printer) const { printShapeAndElem(printer, getShape(), getElementType()); } +mlir::Type PtrType::parse(::mlir::AsmParser &parser) { + Type elementType; + if (failed(parser.parseLess()) || failed(parser.parseType(elementType))) + return {}; + + auto memorySpace = + pto::AddressSpaceAttr::get(parser.getContext(), pto::AddressSpace::GM); + if (succeeded(parser.parseOptionalComma())) { + StringRef memorySpaceKeyword; + if (failed(parser.parseKeyword(&memorySpaceKeyword))) + return {}; + auto parsed = parsePtrAddressSpaceKeyword(memorySpaceKeyword); + if (!parsed) { + parser.emitError(parser.getCurrentLocation(), + "!pto.ptr address space must be one of " + "`gm|ub|mat|l1|left|l0a|right|l0b|acc|l0c|vec|bias|bt|scaling|fb`"); + return {}; + } + memorySpace = pto::AddressSpaceAttr::get(parser.getContext(), *parsed); + } + + if (failed(parser.parseGreater())) + return {}; + return PtrType::get(parser.getContext(), elementType, memorySpace); +} + +void PtrType::print(::mlir::AsmPrinter &printer) const { + printer << "<" << getElementType(); + StringRef memorySpaceKeyword = + printPtrAddressSpaceKeyword(getMemorySpace().getAddressSpace()); + if (!memorySpaceKeyword.empty()) + printer << ", " << memorySpaceKeyword; + printer << ">"; +} + //===----------------------------------------------------------------------===// // pto.tdivs custom asm to support both: // pto.tdivs ins(%src, %scalar : !pto.tile_buf<...>, f32) outs(%dst : !pto.tile_buf<...>) @@ -2104,6 +2191,43 @@ LogicalResult mlir::pto::LocalArraySetOp::verify() { return success(); } +LogicalResult mlir::pto::CastPtrOp::verify() { + Type inputType = getInput().getType(); + Type resultType = getResult().getType(); + + auto inputPtrType = dyn_cast(inputType); + auto resultPtrType = dyn_cast(resultType); + auto inputMemRefType = dyn_cast(inputType); + bool inputIsInteger = isa(inputType); + bool resultIsInteger = isa(resultType); + + if (!inputPtrType && !inputMemRefType && !inputIsInteger) + return emitOpError("input must be an integer, memref, or !pto.ptr<...>"); + if (!resultPtrType && !resultIsInteger) + return emitOpError("result must be an integer or !pto.ptr<...>"); + + if (inputIsInteger && resultIsInteger) + return emitOpError("integer-to-integer cast is not a ptr cast"); + + if (inputMemRefType && resultIsInteger) + return emitOpError("memref-to-integer cast is unsupported"); + + if (inputMemRefType && resultPtrType) { + auto memrefSpace = dyn_cast_or_null( + inputMemRefType.getMemorySpace()); + auto resultSpace = resultPtrType.getMemorySpace(); + if (memrefSpace && memrefSpace != resultSpace) + return emitOpError("memref-to-ptr cast must stay within the same PTO memory space"); + } + + if (inputPtrType && resultPtrType && + inputPtrType.getMemorySpace() != resultPtrType.getMemorySpace()) { + return emitOpError("ptr-to-ptr cast must stay within the same PTO memory space"); + } + + return success(); +} + @@ -2126,6 +2250,8 @@ void PTODialect::initialize() { AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); auto memRefType = dyn_cast(type); if (!memRefType) return {}; @@ -2137,7 +2263,7 @@ AddressSpaceAttr mlir::pto::getPTOAddressSpaceAttr(Type type) { bool mlir::pto::isScalarPtrOrMemRef(Type type) { if (auto pty = dyn_cast(type)) - return true; + return static_cast(pty); if (auto memTy = dyn_cast(type)) return isGmAddressSpaceAttr(memTy.getMemorySpace()); return false; @@ -2456,6 +2582,21 @@ LogicalResult TLoadOp::verify() { pad != static_cast(pto::PadValue::Zero)) return emitOpError("expects A5 i64/u64 tload dst pad to be null or zero"); } + + auto dstSpace = getPTOMemorySpaceEnum(dstTile); + if (dstSpace && *dstSpace == pto::AddressSpace::VEC) { + int32_t bl = dstTile.getBLayoutValueI32(); + int32_t sl = dstTile.getSLayoutValueI32(); + bool isND = (bl == static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::NoneBox)); + bool isDN = (bl == static_cast(pto::BLayout::ColMajor) && + sl == static_cast(pto::SLayout::NoneBox)); + bool isNZ = (bl == static_cast(pto::BLayout::ColMajor) && + sl == static_cast(pto::SLayout::RowMajor)); + if (!isND && !isDN && !isNZ) + return emitOpError("expects A5 tload vec dst layout to be ND, DN, or NZ"); + } + return success(); }; @@ -2927,6 +3068,19 @@ LogicalResult TStoreOp::verify() { return emitOpError("expects A5 vec tstore src element type to be i8/i16/i32/i64/f16/bf16/f32/f8/hif8/fp4"); if (getElemByteSize(srcElem) != getElemByteSize(dstElem)) return emitOpError("expects A5 vec tstore src and dst element types to have the same bitwidth"); + + int32_t bl = srcTile.getBLayoutValueI32(); + int32_t sl = srcTile.getSLayoutValueI32(); + bool isND = (bl == static_cast(pto::BLayout::RowMajor) && + sl == static_cast(pto::SLayout::NoneBox)); + bool isDN = (bl == static_cast(pto::BLayout::ColMajor) && + sl == static_cast(pto::SLayout::NoneBox)); + bool isNZ = (bl == static_cast(pto::BLayout::ColMajor) && + sl == static_cast(pto::SLayout::RowMajor)); + auto srcShape = srcTile.getShape(); + bool isSpecialCase = (srcShape.size() == 2 && (srcShape[0] == 1 || srcShape[1] == 1)); + if (!isSpecialCase && !isND && !isDN && !isNZ) + return emitOpError("expects A5 vec tstore src layout to be ND, DN, or NZ (or special case with 1 row/col)"); return success(); } @@ -2989,6 +3143,17 @@ static Type getElemTy(Type ty) { return Type(); } +static LogicalResult verifyPrecisionModeFloatOnly(Operation *op, + pto::PrecisionMode mode, + Type elem) { + if (mode != pto::PrecisionMode::HIGH_PRECISION) + return success(); + if (elem.isF16() || elem.isF32()) + return success(); + return op->emitOpError() + << "precision_mode = HIGH_PRECISION requires element type to be f16 or f32"; +} + static SmallVector getShapeVec(Type ty) { SmallVector s; if (auto mr = mlir::dyn_cast(ty)) @@ -3652,6 +3817,27 @@ static LogicalResult verifyPartialValidPattern(Operation *op, Type src0Ty, return success(); } +static LogicalResult verifyPartialValidPatternLoose(Operation *op, Type src0Ty, + Type src1Ty, Type dstTy) { + auto src0Valid = getValidShapeVec(src0Ty); + auto src1Valid = getValidShapeVec(src1Ty); + auto dstValid = getValidShapeVec(dstTy); + if (src0Valid.size() != 2 || src1Valid.size() != 2 || dstValid.size() != 2) + return op->emitOpError("expects src0, src1, and dst to have rank-2 valid_shape"); + + auto lessEqualKnown = [](int64_t lhs, int64_t rhs) { + return lhs == ShapedType::kDynamic || rhs == ShapedType::kDynamic || lhs <= rhs; + }; + + for (unsigned i = 0; i < 2; ++i) { + if (!lessEqualKnown(src0Valid[i], dstValid[i]) || + !lessEqualKnown(src1Valid[i], dstValid[i])) + return op->emitOpError( + "expects src0/src1 valid_shape to be less than or equal to dst valid_shape"); + } + return success(); +} + [[maybe_unused]] static bool hasKnownZeroValidRegion(Type ty) { auto valid = getValidShapeVec(ty); if (valid.size() != 2) @@ -3734,19 +3920,20 @@ verifyNumericScalarTileOpCommon(Operation *op, Type srcTy, Type dstTy, static FailureOr verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, - Type dstTy) { + Type dstTy) { if (failed(verifyTileBufCommon(op, src0Ty, "src0")) || failed(verifyTileBufCommon(op, src1Ty, "src1")) || failed(verifyTileBufCommon(op, dstTy, "dst"))) return failure(); Type e0 = getElemTy(src0Ty); Type e1 = getElemTy(src1Ty); - if (!e0 || !e1) { + Type ed = getElemTy(dstTy); + if (!e0 || !e1 || !ed) { op->emitOpError("failed to get element type for operands"); return failure(); } - if (e0 != e1) { - op->emitOpError("expects src0 and src1 to have the same element type"); + if (e0 != e1 || e0 != ed) { + op->emitOpError("expects src0, src1, and dst to have the same element type"); return failure(); } if (!isRowMajorTileBuf(src0Ty) || !isRowMajorTileBuf(src1Ty) || @@ -3925,7 +4112,6 @@ static LogicalResult verifyVecTileStorage(Operation *op, Type ty, StringRef name return op->emitOpError() << "expects " << name << " to be in the vec address space"; return success(); } - static LogicalResult verifyVecTileCommonA2A3(Operation *op, Type ty, StringRef name) { if (failed(verifyTileBufCommon(op, ty, name))) @@ -4226,7 +4412,7 @@ LogicalResult pto::TAddSOp::verify() { "expects A2/A3 tadds element type to be i32/i16/f16/f32", "expects A5 tadds element type to be i32/i16/i8/f16/bf16/f32", /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); + /*requireValidRowsEqualOnA5=*/true); } LogicalResult pto::TAxpyOp::verify() { @@ -4882,10 +5068,13 @@ LogicalResult pto::TColExpandAddOp::verify() { LogicalResult pto::TColExpandDivOp::verify() { auto verifyByArch = [&](PTOArch targetArch) -> LogicalResult { bool allowIntegerTypes = (targetArch == PTOArch::A5); - return verifyTColExpandBinaryLikeOp(getOperation(), getSrc0().getType(), - getSrc1().getType(), getDst().getType(), - targetArch, "tcolexpanddiv", - /*allowIntegerTypes=*/allowIntegerTypes); + if (failed(verifyTColExpandBinaryLikeOp( + getOperation(), getSrc0().getType(), getSrc1().getType(), + getDst().getType(), targetArch, "tcolexpanddiv", + /*allowIntegerTypes=*/allowIntegerTypes))) + return failure(); + return verifyPrecisionModeFloatOnly(getOperation(), getPrecisionMode(), + getElemTy(getSrc0().getType())); }; auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; @@ -5253,7 +5442,7 @@ mlir::LogicalResult mlir::pto::TDivSOp::verify() { !(elem.isInteger(32) || elem.isInteger(16) || elem.isInteger(8) || elem.isF16() || elem.isF32())) return emitOpError("expects A5 tdivs element type to be i32/i16/i8/f16/f32"); - return success(); + return verifyPrecisionModeFloatOnly(getOperation(), getPrecisionMode(), elem); }; auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; @@ -6210,7 +6399,7 @@ mlir::LogicalResult mlir::pto::TLogOp::verify() { auto elemTy = getElemTy(srcTy); if (!(elemTy.isF16() || elemTy.isF32())) return emitOpError() << "expects element type to be f16 or f32"; - return mlir::success(); + return verifyPrecisionModeFloatOnly(getOperation(), getPrecisionMode(), elemTy); } mlir::LogicalResult mlir::pto::TLReluOp::verify() { @@ -6285,7 +6474,7 @@ mlir::LogicalResult mlir::pto::TMinSOp::verify() { "expects A2/A3 tmins element type to be i32/i16/f16/f32", "expects A5 tmins element type to be i32/i16/i8/f16/bf16/f32", /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); + /*requireValidRowsEqualOnA5=*/true); } mlir::LogicalResult mlir::pto::TMovOp::verify() { @@ -6547,14 +6736,19 @@ static LogicalResult verifyBufSyncOp(Operation *op, Attribute opTypeAttr, if (!opTypeAttr) return op->emitOpError("expects 'op_type' attribute"); - auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); - if (failed(opTypeOr)) { - auto diag = - op->emitOpError("expects 'op_type' to be pipe_event_type/sync_op_type, got "); - diag << opTypeAttr; - return failure(); + pto::PIPE pipe = pto::PIPE::PIPE_UNASSIGNED; + if (auto pipeAttr = dyn_cast(opTypeAttr)) { + pipe = pipeAttr.getPipe(); + } else { + auto opTypeOr = parseSyncOpTypeLikeAttr(opTypeAttr); + if (failed(opTypeOr)) { + auto diag = op->emitOpError( + "expects 'op_type' to be pipe_event_type/sync_op_type/pipe, got "); + diag << opTypeAttr; + return failure(); + } + pipe = mapSyncOpTypeToPipe(*opTypeOr); } - pto::PIPE pipe = mapSyncOpTypeToPipe(*opTypeOr); if (!isConcreteSyncPipe(pipe)) return op->emitOpError("expects 'op_type' to map to a concrete pipe, not PIPE_ALL/PIPE_UNASSIGNED"); @@ -6582,6 +6776,261 @@ LogicalResult RlsBufOp::verify() { return verifyBufSyncOp(getOperation(), getOpTypeAttr(), getBufIdAttr(), getModeAttr()); } + +static ParseResult parseLegacyOrAttrMemBar(OpAsmParser &parser, + MemBarAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto kind = symbolizeMemBarKind(token); + if (!kind) + return parser.emitError(loc) << "invalid membar token: " << token; + attr = MemBarAttr::get(parser.getContext(), *kind); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto memBarAttr = dyn_cast(parsed); + if (!memBarAttr) + return parser.emitError(loc, "expected membar attribute"); + attr = memBarAttr; + return success(); +} + +static void printLegacyOrAttrMemBar(OpAsmPrinter &p, MemBarAttr kind, + ArrayRef attrs) { + p << ' ' << '"' << stringifyMemBarKind(kind.getKind()) << '"'; + p.printOptionalAttrDict(attrs, {"kind"}); +} + +static ParseResult parseLegacyOrAttrPipe(OpAsmParser &parser, PipeAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto pipe = symbolizePIPE(token); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << token; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + + if (succeeded(parser.parseOptionalLess())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseGreater()) + return failure(); + auto pipe = symbolizePIPE(keyword); + if (!pipe) + return parser.emitError(loc) << "invalid pipe token: " << keyword; + attr = PipeAttr::get(parser.getContext(), *pipe); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto pipeAttr = dyn_cast(parsed); + if (!pipeAttr) + return parser.emitError(loc, "expected pipe attribute"); + attr = pipeAttr; + return success(); +} + +static ParseResult parseLegacyOrAttrEvent(OpAsmParser &parser, EventAttr &attr) { + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + auto event = symbolizeEVENT(token); + if (!event) + return parser.emitError(loc) << "invalid event token: " << token; + attr = EventAttr::get(parser.getContext(), *event); + return success(); + } + + if (succeeded(parser.parseOptionalLess())) { + StringRef keyword; + if (parser.parseKeyword(&keyword) || parser.parseGreater()) + return failure(); + auto event = symbolizeEVENT(keyword); + if (!event) + return parser.emitError(loc) << "invalid event token: " << keyword; + attr = EventAttr::get(parser.getContext(), *event); + return success(); + } + + Attribute parsed; + if (failed(parser.parseAttribute(parsed))) + return failure(); + auto eventAttr = dyn_cast(parsed); + if (!eventAttr) + return parser.emitError(loc, "expected event attribute"); + attr = eventAttr; + return success(); +} + +static ParseResult parseI32LiteralAttr(OpAsmParser &parser, IntegerAttr &attr) { + auto loc = parser.getCurrentLocation(); + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) + return parser.emitError(loc, "expected 32-bit integer literal"); + attr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), value); + return success(); +} + +static void printLegacySyncTriplet(OpAsmPrinter &p, PipeAttr srcPipe, + PipeAttr dstPipe, EventAttr eventId, + ArrayRef attrs) { + p << "[<" << stringifyPIPE(srcPipe.getPipe()) << ">, <" + << stringifyPIPE(dstPipe.getPipe()) << ">, <" + << stringifyEVENT(eventId.getEvent()) << ">]"; + p.printOptionalAttrDict(attrs, {"src_pipe", "dst_pipe", "event_id"}); +} + +ParseResult SetFlagOp::parse(OpAsmParser &parser, OperationState &result) { + PipeAttr srcPipe; + PipeAttr dstPipe; + EventAttr eventId; + if (parser.parseLSquare() || parseLegacyOrAttrPipe(parser, srcPipe) || + parser.parseComma() || parseLegacyOrAttrPipe(parser, dstPipe) || + parser.parseComma() || parseLegacyOrAttrEvent(parser, eventId) || + parser.parseRSquare()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("src_pipe", srcPipe); + result.addAttribute("dst_pipe", dstPipe); + result.addAttribute("event_id", eventId); + return success(); +} + +void SetFlagOp::print(OpAsmPrinter &p) { + printLegacySyncTriplet(p, getSrcPipe(), getDstPipe(), getEventId(), + (*this)->getAttrs()); +} + +ParseResult WaitFlagOp::parse(OpAsmParser &parser, OperationState &result) { + PipeAttr srcPipe; + PipeAttr dstPipe; + EventAttr eventId; + if (parser.parseLSquare() || parseLegacyOrAttrPipe(parser, srcPipe) || + parser.parseComma() || parseLegacyOrAttrPipe(parser, dstPipe) || + parser.parseComma() || parseLegacyOrAttrEvent(parser, eventId) || + parser.parseRSquare()) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("src_pipe", srcPipe); + result.addAttribute("dst_pipe", dstPipe); + result.addAttribute("event_id", eventId); + return success(); +} + +void WaitFlagOp::print(OpAsmPrinter &p) { + printLegacySyncTriplet(p, getSrcPipe(), getDstPipe(), getEventId(), + (*this)->getAttrs()); +} + +ParseResult MemBarOp::parse(OpAsmParser &parser, OperationState &result) { + MemBarAttr kind; + if (parseLegacyOrAttrMemBar(parser, kind)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("kind", kind); + return success(); +} + +void MemBarOp::print(OpAsmPrinter &p) { + printLegacyOrAttrMemBar(p, getKind(), (*this)->getAttrs()); +} + +static ParseResult parseBufSyncOp(OpAsmParser &parser, OperationState &result) { + Attribute opTypeAttr; + IntegerAttr bufIdAttr; + IntegerAttr modeAttr; + + auto loc = parser.getCurrentLocation(); + std::string token; + if (succeeded(parser.parseOptionalString(&token))) { + if (auto pipe = symbolizePIPE(token)) + opTypeAttr = PipeAttr::get(parser.getContext(), *pipe); + else if (auto opType = symbolizeSyncOpType(token)) + opTypeAttr = PipeEventTypeAttr::get(parser.getContext(), *opType); + else + return parser.emitError(loc) << "invalid get_buf/rls_buf token: " << token; + + if (parser.parseComma() || parseI32LiteralAttr(parser, bufIdAttr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parseI32LiteralAttr(parser, modeAttr)) + return failure(); + } else { + modeAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), 0); + } + } else if (succeeded(parser.parseOptionalLSquare())) { + if (parser.parseAttribute(opTypeAttr) || parser.parseComma() || + parseI32LiteralAttr(parser, bufIdAttr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parseI32LiteralAttr(parser, modeAttr)) + return failure(); + } else { + modeAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), 0); + } + if (parser.parseRSquare()) + return failure(); + } else { + return parser.emitError(loc, "expected string pipe/op_type or '['"); + } + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + result.addAttribute("op_type", opTypeAttr); + result.addAttribute("buf_id", bufIdAttr); + result.addAttribute("mode", modeAttr); + return success(); +} + +static void printBufSyncOp(OpAsmPrinter &p, Attribute opTypeAttr, + IntegerAttr bufIdAttr, IntegerAttr modeAttr, + ArrayRef attrs) { + if (auto pipeAttr = dyn_cast(opTypeAttr)) { + p << " \"" << stringifyPIPE(pipeAttr.getPipe()) << "\", " + << bufIdAttr.getInt() << ", " << modeAttr.getInt(); + } else if (auto pipeEventType = dyn_cast(opTypeAttr)) { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } else if (auto syncOpType = dyn_cast(opTypeAttr)) { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } else { + p << "[" << opTypeAttr << ", " << bufIdAttr.getInt() << ", " + << modeAttr.getInt() << "]"; + } + p.printOptionalAttrDict(attrs, {"op_type", "buf_id", "mode"}); +} + +ParseResult GetBufOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBufSyncOp(parser, result); +} + +void GetBufOp::print(OpAsmPrinter &p) { + printBufSyncOp(p, getOpTypeAttr(), getBufIdAttr(), getModeAttr(), + (*this)->getAttrs()); +} + +ParseResult RlsBufOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBufSyncOp(parser, result); +} + +void RlsBufOp::print(OpAsmPrinter &p) { + printBufSyncOp(p, getOpTypeAttr(), getBufIdAttr(), getModeAttr(), + (*this)->getAttrs()); +} // ---- TOp ---- LogicalResult TGemvBiasOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { @@ -7243,7 +7692,7 @@ mlir::LogicalResult mlir::pto::TMulSOp::verify() { "expects A2/A3 tmuls element type to be i32/i16/f16/f32", "expects A5 tmuls element type to be i32/i16/i8/f16/bf16/f32", /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); + /*requireValidRowsEqualOnA5=*/true); } mlir::LogicalResult mlir::pto::TShlSOp::verify() { @@ -7529,6 +7978,8 @@ mlir::LogicalResult mlir::pto::TPartAddOp::verify() { auto d = getShapeVec(dstTy); if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + if (failed(verifyPartialValidPatternLoose(*this, src0Ty, src1Ty, dstTy))) + return failure(); return mlir::success(); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); @@ -7562,6 +8013,8 @@ mlir::LogicalResult mlir::pto::TPartMaxOp::verify() { if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || e0.isF16() || e0.isBF16() || e0.isF32())) return emitOpError("expects A5 tpartmax element type to be i32/i16/i8/f16/bf16/f32"); + if (failed(verifyPartialValidPatternLoose(*this, t0, t1, td))) + return failure(); return mlir::success(); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); @@ -7595,6 +8048,8 @@ mlir::LogicalResult mlir::pto::TPartMinOp::verify() { if (!(e0.isInteger(32) || e0.isInteger(16) || e0.isInteger(8) || e0.isF16() || e0.isBF16() || e0.isF32())) return emitOpError("expects A5 tpartmin element type to be i32/i16/i8/f16/bf16/f32"); + if (failed(verifyPartialValidPatternLoose(*this, t0, t1, td))) + return failure(); return mlir::success(); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); @@ -7721,6 +8176,8 @@ mlir::LogicalResult mlir::pto::TPartMulOp::verify() { if (s0.size() != 2 || s1.size() != 2 || d.size() != 2) return emitOpError() << "expects src0/src1/dst to be rank-2 (tile-shaped)"; + if (failed(verifyPartialValidPatternLoose(*this, src0Ty, src1Ty, dstTy))) + return failure(); return mlir::success(); }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); @@ -8650,7 +9107,7 @@ mlir::LogicalResult mlir::pto::TRowExpandDivOp::verify() { "expects A5 trowexpanddiv element type to be i8/i16/i32/f16/f32"); return emitOpError("expects element type to be f16 or f32"); } - return mlir::success(); + return verifyPrecisionModeFloatOnly(getOperation(), getPrecisionMode(), elem); }; auto verifyA2A3 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A3); }; auto verifyA5 = [&]() -> LogicalResult { return verifyByArch(PTOArch::A5); }; @@ -9284,10 +9741,11 @@ mlir::LogicalResult mlir::pto::TSelSOp::verify() { FailureOr elemOr = verifyCommon(); if (failed(elemOr)) return failure(); + Type tMask = getMask().getType(); Type tSrc = getSrc().getType(); Type tDst = getDst().getType(); - if (!isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) - return emitOpError("expects src and dst to use row-major layout"); + if (!isRowMajorTileBuf(tMask) || !isRowMajorTileBuf(tSrc) || !isRowMajorTileBuf(tDst)) + return emitOpError("expects mask, src, and dst to use row-major layout"); Type elem = *elemOr; bool ok = elem.isF16() || elem.isF32(); if (auto it = mlir::dyn_cast(elem)) @@ -9509,7 +9967,7 @@ mlir::LogicalResult mlir::pto::TSubSOp::verify() { "expects A2/A3 tsubs element type to be i32/i16/f16/f32", "expects A5 tsubs element type to be i32/i16/i8/f16/bf16/f32", /*requireValidRowsEqualOnA2A3=*/true, - /*requireValidRowsEqualOnA5=*/false); + /*requireValidRowsEqualOnA5=*/true); } @@ -10481,6 +10939,151 @@ static LogicalResult computeInnerShape(TileBufConfigAttr cfg, Type elemTy, return failure(); } +static LogicalResult +computeExpectedTileBufMemrefStrides(TileBufType tileTy, + SmallVectorImpl &expectedStrides) { + if (tileTy.getRank() != 2) + return failure(); + + ArrayRef shape = tileTy.getShape(); + if (shape.size() != 2) + return failure(); + if (shape[0] == ShapedType::kDynamic || shape[1] == ShapedType::kDynamic) + return failure(); + + auto cfg = tileTy.getConfigAttr(); + if (!cfg) + cfg = TileBufConfigAttr::getDefault(tileTy.getContext()); + + int64_t innerRows = 1, innerCols = 1; + bool boxed = false; + int32_t bl = 0, sl = 0; + if (failed(computeInnerShape(cfg, tileTy.getElementType(), innerRows, innerCols, + boxed, bl, sl))) + return failure(); + + expectedStrides.clear(); + if (!boxed) { + if (bl == 1) { + expectedStrides.push_back(1); + expectedStrides.push_back(shape[0]); + } else { + expectedStrides.push_back(shape[1]); + expectedStrides.push_back(1); + } + return success(); + } + + if (bl == 1) { + if (sl != 1) + return failure(); + expectedStrides.push_back(innerCols); + expectedStrides.push_back(shape[0]); + return success(); + } + + expectedStrides.push_back(shape[1]); + expectedStrides.push_back(innerRows); + return success(); +} + +mlir::LogicalResult mlir::pto::SimdTileToMemrefOp::verify() { + auto memTy = dyn_cast(getDst().getType()); + if (!memTy) + return emitOpError("expects result to be memref"); + + Type srcTy = getSrc().getType(); + if (auto tileTy = dyn_cast(srcTy)) { + if (memTy.getElementType() != tileTy.getElementType()) + return emitOpError( + "expects memref element type to match tile_buf element type"); + + if (memTy.getMemorySpace() != tileTy.getMemorySpace()) + return emitOpError( + "expects memref memory space to match tile_buf memory space"); + + if (memTy.getRank() != tileTy.getRank()) + return emitOpError("expects memref rank to match tile_buf rank"); + + ArrayRef tileShape = tileTy.getShape(); + ArrayRef validShape = tileTy.getValidShape(); + ArrayRef memShape = memTy.getShape(); + if (tileShape.size() != memShape.size()) + return emitOpError( + "expects memref shape rank to match tile_buf shape rank"); + + if (validShape.size() != memShape.size()) + return emitOpError( + "expects tile_buf valid shape rank to match memref shape rank"); + + for (unsigned i = 0; i < validShape.size(); ++i) { + int64_t expect = validShape[i]; + if (expect < 0) { + if (memShape[i] >= 0 && memShape[i] != tileShape[i]) { + return emitOpError() + << "expects memref dim " << i + << " to be dynamic or match physical tile dim " << tileShape[i] + << " because tile_buf valid dim is ?"; + } + continue; + } + + if (memShape[i] != expect) { + return emitOpError() << "expects memref dim " << i + << " to match tile_buf valid dim; got " + << memShape[i] << ", expected " << expect; + } + } + + SmallVector expectedStrides; + if (failed(computeExpectedTileBufMemrefStrides(tileTy, expectedStrides))) + return emitOpError("cannot infer expected strides from tile_buf layout"); + + SmallVector memStrides; + int64_t memOffset = ShapedType::kDynamic; + if (failed(getStridesAndOffset(memTy, memStrides, memOffset))) + return emitOpError("expects memref to use strided layout"); + if (memOffset != 0) + return emitOpError("expects memref offset to be 0"); + if (memStrides.size() != expectedStrides.size()) + return emitOpError("expects memref stride rank to match tile_buf rank"); + for (unsigned i = 0; i < expectedStrides.size(); ++i) { + if (memStrides[i] != expectedStrides[i]) { + return emitOpError() + << "expects memref strides to match tile_buf layout; got " + << memStrides[i] << " at dim " << i << ", expected " + << expectedStrides[i]; + } + } + return success(); + } + + auto srcMemTy = dyn_cast(srcTy); + if (!srcMemTy) + return emitOpError("expects src to be !pto.tile_buf or memref"); + + if (srcMemTy.getElementType() != memTy.getElementType()) + return emitOpError("expects src/result memref element types to match"); + + if (srcMemTy.getMemorySpace() != memTy.getMemorySpace()) + return emitOpError("expects src/result memref memory spaces to match"); + + if (srcMemTy.getRank() != memTy.getRank()) + return emitOpError("expects src/result memref ranks to match"); + + ArrayRef srcShape = srcMemTy.getShape(); + ArrayRef dstShape = memTy.getShape(); + for (unsigned i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] >= 0 && dstShape[i] >= 0 && srcShape[i] != dstShape[i]) { + return emitOpError() + << "expects compatible src/result memref shapes; dim " << i + << " mismatches (" << srcShape[i] << " vs " << dstShape[i] << ")"; + } + } + + return success(); +} + mlir::LogicalResult mlir::pto::SubViewOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); @@ -12929,5 +13532,6 @@ void TFreeOp::getEffects( // [Include 必须放在最后] #include "PTO/IR/PTOInterfaces.cpp.inc" +#include "PTO/IR/VPTOInterfaces.cpp.inc" #define GET_OP_CLASSES #include "PTO/IR/PTOOps.cpp.inc" diff --git a/lib/PTO/IR/PTOAttrs.cpp b/lib/PTO/IR/PTOAttrs.cpp index cc4915f83..fcf8be62a 100644 --- a/lib/PTO/IR/PTOAttrs.cpp +++ b/lib/PTO/IR/PTOAttrs.cpp @@ -87,9 +87,12 @@ LogicalResult TileBufConfigAttr::verify(function_ref emitE return emitError() << "s_fractal_size must be i32", failure(); int32_t s = (int32_t)sFractalSize.getInt(); - if (s != kFractalSize32 && s != kFractalSize16 && - s != kFractalSize512 && s != kFractalSize1024) - return emitError() << "unsupported s_fractal_size: " << s, failure(); + if (s != kFractalMxSize && s != kFractalABSize && s != kFractalCSize) + return emitError() << "unsupported s_fractal_size: " << s + << ", must be one of {" + << kFractalMxSize << ", " + << kFractalABSize << ", " + << kFractalCSize << "}", failure(); int32_t blv = getLayoutInt(bLayout, -1); if (blv != kBLayoutRowMajor && blv != kBLayoutColMajor) diff --git a/lib/PTO/IR/PTOTypeDefs.cpp b/lib/PTO/IR/PTOTypeDefs.cpp index a3d5ab596..1e1fe272c 100644 --- a/lib/PTO/IR/PTOTypeDefs.cpp +++ b/lib/PTO/IR/PTOTypeDefs.cpp @@ -442,15 +442,31 @@ static LogicalResult parseCompactTileBufFields(AsmParser &parser, static Type buildTileBufType(AsmParser &parser, const ParsedTileBufFields &fields) { MLIRContext *ctx = parser.getContext(); + auto emitError = [&]() -> InFlightDiagnostic { + return parser.emitError(parser.getNameLoc()); + }; - if (fields.rows < 0 || fields.cols < 0) { - parser.emitError(parser.getNameLoc(), "rows/cols must be non-negative"); + // 1. Shape positivity check + if (fields.rows <= 0 || fields.cols <= 0) { + emitError() << "tile_buf rows/cols must be positive"; + return Type(); + } + + // 2. ValidShape bounds check + int64_t vrow = fields.vrow < 0 ? ShapedType::kDynamic : fields.vrow; + int64_t vcol = fields.vcol < 0 ? ShapedType::kDynamic : fields.vcol; + if (vrow != ShapedType::kDynamic && vrow > fields.rows) { + emitError() << "tile_buf valid_row (" << vrow << ") exceeds row (" << fields.rows << ")"; + return Type(); + } + if (vcol != ShapedType::kDynamic && vcol > fields.cols) { + emitError() << "tile_buf valid_col (" << vcol << ") exceeds col (" << fields.cols << ")"; return Type(); } auto memorySpace = resolveTileBufMemorySpace(fields.locStr); if (!memorySpace.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown loc: ") << fields.locStr; + emitError() << "unknown loc: " << fields.locStr; return Type(); } @@ -459,22 +475,29 @@ static Type buildTileBufType(AsmParser &parser, auto pv = symbolizePadValue(fields.padInt); auto compact = symbolizeCompactMode(fields.compactInt); if (!bl.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown blayout: ") - << fields.blayoutStr; + emitError() << "unknown blayout: " << fields.blayoutStr; return Type(); } if (!sl.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown slayout: ") - << fields.slayoutStr; + emitError() << "unknown slayout: " << fields.slayoutStr; return Type(); } if (!pv.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown pad: ") << fields.padInt; + emitError() << "unknown pad: " << fields.padInt; return Type(); } if (!compact.has_value()) { - parser.emitError(parser.getNameLoc(), "unknown compact: ") - << fields.compactInt; + emitError() << "unknown compact: " << fields.compactInt; + return Type(); + } + + // 3. Fractal value check (only Mx/AB/C sizes allowed) + if (fields.fractal != kFractalMxSize && fields.fractal != kFractalABSize && fields.fractal != kFractalCSize) { + emitError() << "unsupported s_fractal_size: " << fields.fractal + << ", must be one of {" + << kFractalMxSize << ", " + << kFractalABSize << ", " + << kFractalCSize << "}"; return Type(); } @@ -482,6 +505,9 @@ static Type buildTileBufType(AsmParser &parser, resolveTileBufBLayout(parser.getContext(), memorySpace.value(), bl.value()); + // (32-byte alignment and boxed layout divisibility checks removed + // - not general hardware requirements; validation handled elsewhere) + auto blAttr = BLayoutAttr::get(ctx, effectiveBLayout); auto slAttr = SLayoutAttr::get(ctx, sl.value()); auto fractalAttr = diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp new file mode 100644 index 000000000..1d52a15b4 --- /dev/null +++ b/lib/PTO/IR/VPTO.cpp @@ -0,0 +1,6906 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VPTO.cpp - VPTO dialect -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +using namespace mlir; +using namespace mlir::pto; + +static llvm::cl::opt disableVPTOAlignChainVerification( + "vpto-disable-align-chain-verification", + llvm::cl::desc("Disable !pto.align linear-chain verifier checks"), + llvm::cl::init(false), llvm::cl::Hidden); + +static std::string formatVRegType(int64_t elementCount, Type elementType) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.vreg<" << elementCount << "x" << elementType << ">"; + return storage; +} + +static std::string formatMaskType(StringRef granularity) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.mask<" << granularity << ">"; + return storage; +} + +static LogicalResult verifyVRegTypeLike(Operation *op, Type type, + StringRef roleDescription) { + auto vecType = dyn_cast(type); + if (!vecType) + return op->emitOpError() << roleDescription << " must be !pto.vreg<...>"; + + return VRegType::verify( + [&]() { return op->emitOpError() << roleDescription << " "; }, + vecType.getElementCount(), vecType.getElementType()); +} + +static LogicalResult verifyMaskTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (!isa(type)) + return op->emitOpError() << roleDescription << " must be !pto.mask<...>"; + return success(); +} + +static LogicalResult verifyMaskTypeWithGranularityLike(Operation *op, Type type, + StringRef roleDescription, + StringRef granularity) { + auto maskType = dyn_cast(type); + if (!maskType) + return op->emitOpError() << roleDescription << " must be !pto.mask<...>"; + if (maskType.getGranularity() != granularity) { + return op->emitOpError() + << roleDescription << " must be " << formatMaskType(granularity); + } + return success(); +} + +static LogicalResult verifyVPTOScalarAccessTypes(Operation *op, Type ptrTy, + Type valueTy, + StringRef opNameForDiag) { + Type elemTy; + if (auto pty = dyn_cast(ptrTy)) { + elemTy = pty.getElementType(); + } else if (auto memTy = dyn_cast(ptrTy)) { + elemTy = memTy.getElementType(); + } else { + return op->emitOpError() << "expects " << opNameForDiag + << " pointer operand to be !pto.ptr or memref"; + } + + if (valueTy != elemTy) { + return op->emitOpError() << "expects " << opNameForDiag + << " value type to match pointer element type"; + } + return success(); +} + +static bool isMaskGranularityAdjacentWidening(StringRef inputGranularity, + StringRef resultGranularity) { + return (inputGranularity == "b8" && resultGranularity == "b16") || + (inputGranularity == "b16" && resultGranularity == "b32"); +} + +LogicalResult PTOLoadOp::verify() { + return verifyVPTOScalarAccessTypes(getOperation(), getPtr().getType(), + getValue().getType(), "load"); +} + +LogicalResult PTOStoreOp::verify() { + return verifyVPTOScalarAccessTypes(getOperation(), getPtr().getType(), + getValue().getType(), "store"); +} + +void PTOLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable()); +} + +void PTOStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable()); +} + +static LogicalResult verifyNotNestedInVecScope(Operation *op, + StringRef opNameForDiag) { + if (op->getParentOfType() || + op->getParentOfType()) { + return op->emitOpError() + << "must not be nested under pto.vecscope/pto.strict_vecscope; " + << opNameForDiag << " is a UB helper op rather than a vecscope op"; + } + return success(); +} + +static LogicalResult verifyNestedInVecScope(Operation *op, + StringRef opNameForDiag) { + if (op->getParentOfType() || op->getParentOfType()) + return success(); + return op->emitOpError() + << "must be nested under pto.vecscope/pto.strict_vecscope; " + << opNameForDiag << " is part of the vecscope control sequence"; +} + +static LogicalResult verifyAlignTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (!isa(type)) + return op->emitOpError() << roleDescription << " must be !pto.align"; + return success(); +} + +static bool isSupportedVdupPosition(std::optional position) { + return !position || *position == "LOWEST" || *position == "HIGHEST"; +} + +static bool isSupportedMovPadScalarType(Type type) { + if (auto intType = dyn_cast(type)) + return intType.isSignless() && + (intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32); + if (auto floatType = dyn_cast(type)) + return floatType.isF16() || floatType.isBF16() || floatType.isF32(); + return false; +} + +static bool isMxElementType(Type type) { return isa(type); } + +static std::optional getVdupMaskGranularity(Type elementType) { + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return StringRef("b8"); + case 16: + return StringRef("b16"); + case 32: + return StringRef("b32"); + default: + return std::nullopt; + } + } + if (elementType.isF16() || elementType.isBF16()) + return StringRef("b16"); + if (elementType.isF32()) + return StringRef("b32"); + return std::nullopt; +} + +static bool isSupportedVtrcRoundMode(StringRef mode) { + return mode == "R" || mode == "A" || mode == "F" || mode == "C" || + mode == "Z"; +} + +static bool isStoreAlignProducer(Operation *op) { + return isa(op); +} + +static bool isStoreAlignSink(Operation *op) { + return isa(op); +} + +static bool isLoadAlignProducer(Operation *op) { + return isa(op); +} + +static scf::IfOp getEnclosingBranchIf(Operation *op) { + for (Operation *cursor = op; cursor; cursor = cursor->getParentOp()) { + auto ifOp = dyn_cast(cursor); + if (!ifOp) + continue; + Region *parentRegion = op->getParentRegion(); + if (parentRegion == &ifOp.getThenRegion() || parentRegion == &ifOp.getElseRegion()) + return ifOp; + } + return nullptr; +} + +static bool isValueOwnedByRegion(Value value, Region *region) { + if (auto blockArg = dyn_cast(value)) + return blockArg.getParentRegion() == region; + if (Operation *def = value.getDefiningOp()) + return def->getParentRegion() == region; + return false; +} + +static FailureOr resolveStoreAlignRoot(Value value, Operation *user); +static FailureOr resolveLoadAlignRoot(Value value, Operation *user); + +static FailureOr resolveStoreAlignRootImpl( + Value current, llvm::SmallPtrSet visited) { + + while (true) { + if (!visited.insert(current.getAsOpaquePointer()).second) { + return failure(); + } + + if (auto blockArg = dyn_cast(current)) { + auto *owner = blockArg.getOwner(); + auto forOp = dyn_cast(owner->getParentOp()); + if (!forOp) + return failure(); + unsigned argNumber = blockArg.getArgNumber(); + unsigned ivCount = forOp.getNumInductionVars(); + if (argNumber < ivCount) + return failure(); + unsigned iterIdx = argNumber - ivCount; + if (iterIdx >= forOp.getInitArgs().size()) + return failure(); + current = forOp.getInitArgs()[iterIdx]; + continue; + } + + if (Operation *def = current.getDefiningOp()) { + if (isa(def)) + return current; + if (auto stateOp = dyn_cast(def)) { + current = stateOp.getAlignIn(); + continue; + } + if (auto stateOp = dyn_cast(def)) { + current = stateOp.getAlignIn(); + continue; + } + if (auto stateOp = dyn_cast(def)) { + current = stateOp.getAlignIn(); + continue; + } + if (auto forOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result) + return failure(); + unsigned resultIdx = result.getResultNumber(); + if (resultIdx >= forOp.getYieldedValues().size()) + return failure(); + current = forOp.getYieldedValues()[resultIdx]; + continue; + } + if (auto ifOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result || !ifOp.elseBlock()) + return failure(); + unsigned resultIdx = result.getResultNumber(); + auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); + auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); + if (!thenYield || !elseYield || resultIdx >= thenYield.getNumOperands() || + resultIdx >= elseYield.getNumOperands()) { + return failure(); + } + FailureOr thenRoot = + resolveStoreAlignRootImpl(thenYield.getOperand(resultIdx), visited); + FailureOr elseRoot = + resolveStoreAlignRootImpl(elseYield.getOperand(resultIdx), visited); + if (failed(thenRoot) || failed(elseRoot) || *thenRoot != *elseRoot) + return failure(); + return *thenRoot; + } + } + + return failure(); + } +} + +static FailureOr resolveStoreAlignRoot(Value value, Operation *user) { + (void)user; + return resolveStoreAlignRootImpl(value, {}); +} + +static LogicalResult verifyStoreAlignLoopThreading(Value align, Operation *user, + StringRef roleDescription) { + Operation *cursor = user; + while (auto forOp = cursor->getParentOfType()) { + Region *body = &forOp.getRegion(); + if (isValueOwnedByRegion(align, body)) + return success(); + if (!isValueOwnedByRegion(align, body)) { + return user->emitOpError() + << roleDescription + << " must be threaded through scf.for iter_args when used inside a " + "loop"; + } + cursor = forOp; + } + return success(); +} + +static FailureOr resolveSingleAlignIfResult(scf::IfOp ifOp) { + SmallVector alignResultIndices; + for (auto [index, type] : llvm::enumerate(ifOp.getResultTypes())) { + if (isa(type)) + alignResultIndices.push_back(index); + } + if (alignResultIndices.size() != 1) + return failure(); + return ifOp.getResult(alignResultIndices.front()); +} + +static LogicalResult verifyStoreAlignLinearUses(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (visited.insert(current.getAsOpaquePointer()).second) { + SmallVector nextValues; + SmallVector terminalUsers; + SmallVector branchUsers; + + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (isStoreAlignSink(owner)) { + terminalUsers.push_back(owner); + branchUsers.push_back(owner); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + branchUsers.push_back(owner); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + branchUsers.push_back(owner); + continue; + } + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getAlignOut()); + branchUsers.push_back(owner); + continue; + } + if (auto forOp = dyn_cast(owner)) { + unsigned firstInitArg = forOp.getNumControlOperands(); + if (use.getOperandNumber() < firstInitArg) + return user->emitOpError() + << "found unexpected scf.for control operand use for !pto.align"; + unsigned iterIdx = use.getOperandNumber() - firstInitArg; + if (iterIdx >= forOp.getRegionIterArgs().size()) + return user->emitOpError() + << "found invalid scf.for iter_args use for !pto.align"; + nextValues.push_back(forOp.getRegionIterArgs()[iterIdx]); + continue; + } + if (auto yieldOp = dyn_cast(owner)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) + return user->emitOpError() + << "found !pto.align yielded from non-scf.for loop"; + unsigned resultIdx = use.getOperandNumber(); + if (resultIdx >= forOp.getNumResults()) + return user->emitOpError() + << "found invalid scf.yield result mapping for !pto.align"; + nextValues.push_back(forOp.getResult(resultIdx)); + continue; + } + return user->emitOpError() + << "found unsupported !pto.align consumer " << owner->getName(); + } + + if (nextValues.size() + terminalUsers.size() > 1) { + scf::IfOp commonIf; + for (Operation *branchUser : branchUsers) { + scf::IfOp enclosingIf = getEnclosingBranchIf(branchUser); + if (!enclosingIf) { + commonIf = nullptr; + break; + } + if (!commonIf) + commonIf = enclosingIf; + else if (commonIf != enclosingIf) { + commonIf = nullptr; + break; + } + } + if (commonIf) { + FailureOr mergedValue = resolveSingleAlignIfResult(commonIf); + if (succeeded(mergedValue)) { + current = *mergedValue; + continue; + } + } + return user->emitOpError() + << "!pto.align value must form a single linear store-state chain"; + } + if (nextValues.empty()) + return success(); + current = nextValues.front(); + } + + return success(); +} + +static LogicalResult verifyStoreAlignChain(Value align, Operation *user, + StringRef roleDescription) { + if (disableVPTOAlignChainVerification) + return success(); + + if (failed(verifyAlignTypeLike(user, align.getType(), roleDescription))) + return failure(); + + if (failed(verifyStoreAlignLoopThreading(align, user, roleDescription))) + return failure(); + + FailureOr root = resolveStoreAlignRoot(align, user); + if (failed(root)) { + if (Operation *def = align.getDefiningOp()) { + if (!isa(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op, got " + << def->getName(); + } + } + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op"; + } + + Operation *def = (*root).getDefiningOp(); + if (!isStoreAlignProducer(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.init_align or a prior store-state op, got " + << def->getName(); + } + + return verifyStoreAlignLinearUses(*root, user); +} + +static FailureOr resolveLoadAlignRootImpl( + Value current, llvm::SmallPtrSet visited) { + + while (true) { + if (!visited.insert(current.getAsOpaquePointer()).second) + return failure(); + + if (auto blockArg = dyn_cast(current)) { + auto *owner = blockArg.getOwner(); + auto forOp = dyn_cast(owner->getParentOp()); + if (!forOp) + return failure(); + unsigned argNumber = blockArg.getArgNumber(); + unsigned ivCount = forOp.getNumInductionVars(); + if (argNumber < ivCount) + return failure(); + unsigned iterIdx = argNumber - ivCount; + if (iterIdx >= forOp.getInitArgs().size()) + return failure(); + current = forOp.getInitArgs()[iterIdx]; + continue; + } + + if (Operation *def = current.getDefiningOp()) { + if (isa(def)) + return current; + if (auto stateOp = dyn_cast(def)) { + current = stateOp.getAlign(); + continue; + } + if (auto forOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result) + return failure(); + unsigned resultIdx = result.getResultNumber(); + if (resultIdx >= forOp.getYieldedValues().size()) + return failure(); + current = forOp.getYieldedValues()[resultIdx]; + continue; + } + if (auto ifOp = dyn_cast(def)) { + auto result = dyn_cast(current); + if (!result || !ifOp.elseBlock()) + return failure(); + unsigned resultIdx = result.getResultNumber(); + auto thenYield = dyn_cast(ifOp.thenBlock()->getTerminator()); + auto elseYield = dyn_cast(ifOp.elseBlock()->getTerminator()); + if (!thenYield || !elseYield || resultIdx >= thenYield.getNumOperands() || + resultIdx >= elseYield.getNumOperands()) { + return failure(); + } + FailureOr thenRoot = + resolveLoadAlignRootImpl(thenYield.getOperand(resultIdx), visited); + FailureOr elseRoot = + resolveLoadAlignRootImpl(elseYield.getOperand(resultIdx), visited); + if (failed(thenRoot) || failed(elseRoot) || *thenRoot != *elseRoot) + return failure(); + return *thenRoot; + } + } + + return failure(); + } +} + +static FailureOr resolveLoadAlignRoot(Value value, Operation *user) { + (void)user; + return resolveLoadAlignRootImpl(value, {}); +} + +static LogicalResult verifyLoadAlignLoopThreading(Value align, Operation *user, + StringRef roleDescription) { + Operation *cursor = user; + while (auto forOp = cursor->getParentOfType()) { + Region *body = &forOp.getRegion(); + if (isValueOwnedByRegion(align, body)) + return success(); + if (!isValueOwnedByRegion(align, body)) { + return user->emitOpError() + << roleDescription + << " must be threaded through scf.for iter_args when used inside a " + "loop"; + } + cursor = forOp; + } + return success(); +} + +static LogicalResult verifyLoadAlignLinearUses(Value value, Operation *user) { + llvm::SmallPtrSet visited; + Value current = value; + + while (visited.insert(current.getAsOpaquePointer()).second) { + SmallVector nextValues; + SmallVector branchUsers; + + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (auto stateOp = dyn_cast(owner)) { + nextValues.push_back(stateOp.getUpdatedAlign()); + branchUsers.push_back(owner); + continue; + } + if (auto forOp = dyn_cast(owner)) { + unsigned firstInitArg = forOp.getNumControlOperands(); + if (use.getOperandNumber() < firstInitArg) { + return user->emitOpError() + << "found unexpected scf.for control operand use for !pto.align"; + } + unsigned iterIdx = use.getOperandNumber() - firstInitArg; + if (iterIdx >= forOp.getRegionIterArgs().size()) { + return user->emitOpError() + << "found invalid scf.for iter_args use for !pto.align"; + } + nextValues.push_back(forOp.getRegionIterArgs()[iterIdx]); + continue; + } + if (auto yieldOp = dyn_cast(owner)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) { + return user->emitOpError() + << "found !pto.align yielded from non-scf.for loop"; + } + unsigned resultIdx = use.getOperandNumber(); + if (resultIdx >= forOp.getNumResults()) { + return user->emitOpError() + << "found invalid scf.yield result mapping for !pto.align"; + } + nextValues.push_back(forOp.getResult(resultIdx)); + continue; + } + return user->emitOpError() + << "found unsupported !pto.align consumer " << owner->getName(); + } + + if (nextValues.size() > 1) { + scf::IfOp commonIf; + for (Operation *branchUser : branchUsers) { + scf::IfOp enclosingIf = getEnclosingBranchIf(branchUser); + if (!enclosingIf) { + commonIf = nullptr; + break; + } + if (!commonIf) + commonIf = enclosingIf; + else if (commonIf != enclosingIf) { + commonIf = nullptr; + break; + } + } + if (commonIf) { + FailureOr mergedValue = resolveSingleAlignIfResult(commonIf); + if (succeeded(mergedValue)) { + current = *mergedValue; + continue; + } + } + return user->emitOpError() + << "!pto.align value must form a single linear load-state chain"; + } + if (nextValues.empty()) + return success(); + current = nextValues.front(); + } + + return success(); +} + +static LogicalResult verifyLoadAlignChain(Value align, Operation *user, + StringRef roleDescription) { + if (disableVPTOAlignChainVerification) + return success(); + + if (failed(verifyAlignTypeLike(user, align.getType(), roleDescription))) + return failure(); + + if (failed(verifyLoadAlignLoopThreading(align, user, roleDescription))) + return failure(); + + FailureOr root = resolveLoadAlignRoot(align, user); + if (failed(root)) { + if (Operation *def = align.getDefiningOp()) { + if (!isa(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op, got " + << def->getName(); + } + } + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op"; + } + + Operation *def = (*root).getDefiningOp(); + if (!isLoadAlignProducer(def)) { + return user->emitOpError() + << roleDescription + << " must be produced by pto.vldas or a prior load-state op, got " + << def->getName(); + } + + return verifyLoadAlignLinearUses(*root, user); +} + +static bool isSupportedPredicatePattern(StringRef pattern) { + return pattern == "PAT_ALL" || pattern == "PAT_VL1" || pattern == "PAT_VL2" || + pattern == "PAT_VL3" || pattern == "PAT_VL4" || pattern == "PAT_VL8" || + pattern == "PAT_VL16" || pattern == "PAT_VL32" || + pattern == "PAT_VL64" || pattern == "PAT_VL128" || + pattern == "PAT_M3" || pattern == "PAT_M4" || pattern == "PAT_H" || + pattern == "PAT_Q" || pattern == "PAT_ALLF"; +} + +static bool isSupportedPredicateLoadDist(StringRef dist) { + return dist == "NORM" || dist == "US" || dist == "DS"; +} + +static bool isSupportedPredicateStoreDist(StringRef dist) { + return dist == "NORM" || dist == "PK"; +} + +static bool isSupportedPartToken(StringRef part) { + return part == "LOWER" || part == "HIGHER"; +} + +static bool isSupportedSprToken(StringRef spr) { return spr == "AR"; } + +static std::optional normalizeRoundModeToken(StringRef token) { + if (token == "R" || token == "ROUND_R") + return StringRef("R"); + if (token == "A" || token == "ROUND_A") + return StringRef("A"); + if (token == "F" || token == "ROUND_F") + return StringRef("F"); + if (token == "C" || token == "ROUND_C") + return StringRef("C"); + if (token == "Z" || token == "ROUND_Z") + return StringRef("Z"); + if (token == "O" || token == "ROUND_O") + return StringRef("O"); + return std::nullopt; +} + +static std::optional normalizeSaturationToken(StringRef token) { + if (token == "SAT" || token == "RS_ENABLE") + return StringRef("SAT"); + if (token == "NOSAT" || token == "RS_DISABLE") + return StringRef("NOSAT"); + return std::nullopt; +} + +static std::optional normalizeEvenOddPartToken(StringRef token) { + if (token == "EVEN" || token == "PART_EVEN") + return StringRef("EVEN"); + if (token == "ODD" || token == "PART_ODD") + return StringRef("ODD"); + return std::nullopt; +} + +static std::optional normalizePacked4PartToken(StringRef token) { + if (token == "P0" || token == "PART_P0") + return StringRef("P0"); + if (token == "P1" || token == "PART_P1") + return StringRef("P1"); + if (token == "P2" || token == "PART_P2") + return StringRef("P2"); + if (token == "P3" || token == "PART_P3") + return StringRef("P3"); + return std::nullopt; +} + +static std::optional normalizeVcvtPartToken(StringRef token) { + if (auto normalized = normalizeEvenOddPartToken(token)) + return normalized; + return normalizePacked4PartToken(token); +} + +namespace { + +enum class VcvtElemKind { + Invalid, + F16, + BF16, + F32, + S8, + U8, + S16, + U16, + S32, + U32, + S64, +}; + +struct VcvtContract { + bool requiresRnd; + bool requiresSat; + bool requiresPart; +}; + +enum class VcvtPartFamily { + EvenOdd, + Packed4, +}; + +static VcvtElemKind classifyVcvtElemType(Type type) { + if (type.isF16()) + return VcvtElemKind::F16; + if (type.isBF16()) + return VcvtElemKind::BF16; + if (type.isF32()) + return VcvtElemKind::F32; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + case 16: + return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + case 32: + return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; + case 64: + return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + default: + return VcvtElemKind::Invalid; + } + } + return VcvtElemKind::Invalid; +} + +static std::optional getVcvtElemBitWidth(VcvtElemKind kind) { + switch (kind) { + case VcvtElemKind::F16: + case VcvtElemKind::BF16: + case VcvtElemKind::S16: + case VcvtElemKind::U16: + return 16; + case VcvtElemKind::F32: + case VcvtElemKind::S32: + case VcvtElemKind::U32: + return 32; + case VcvtElemKind::S8: + case VcvtElemKind::U8: + return 8; + case VcvtElemKind::S64: + return 64; + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +static std::optional classifyVcvtPartFamily(unsigned srcBits, + unsigned dstBits) { + unsigned largerBits = std::max(srcBits, dstBits); + unsigned smallerBits = std::min(srcBits, dstBits); + if (largerBits == smallerBits * 2) + return VcvtPartFamily::EvenOdd; + if (largerBits == smallerBits * 4) + return VcvtPartFamily::Packed4; + return std::nullopt; +} + +static bool isValidVcvtPartForFamily(StringRef part, VcvtPartFamily family) { + switch (family) { + case VcvtPartFamily::EvenOdd: + return part == "EVEN" || part == "ODD"; + case VcvtPartFamily::Packed4: + return part == "P0" || part == "P1" || part == "P2" || part == "P3"; + } + return false; +} + +static std::optional lookupVcvtContract(VcvtElemKind src, + VcvtElemKind dst) { + switch (src) { + case VcvtElemKind::F32: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::BF16: + case VcvtElemKind::S16: + case VcvtElemKind::S64: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; + default: + return std::nullopt; + } + case VcvtElemKind::F16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; + case VcvtElemKind::S8: + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::BF16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/false}; + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U8: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::U16: + case VcvtElemKind::U32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S8: + switch (dst) { + case VcvtElemKind::F16: + case VcvtElemKind::S16: + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U16: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::U32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/false}; + case VcvtElemKind::F32: + case VcvtElemKind::U32: + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::U8: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::U32: + switch (dst) { + case VcvtElemKind::U8: + case VcvtElemKind::U16: + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S32: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/false}; + case VcvtElemKind::U8: + case VcvtElemKind::U16: + case VcvtElemKind::S16: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + case VcvtElemKind::S64: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/false, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::S64: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{/*requiresRnd=*/true, /*requiresSat=*/false, + /*requiresPart=*/true}; + case VcvtElemKind::S32: + return VcvtContract{/*requiresRnd=*/false, /*requiresSat=*/true, + /*requiresPart=*/true}; + default: + return std::nullopt; + } + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +} // namespace + +static std::optional getDistElementWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (type.isF16() || type.isBF16()) + return 16; + if (type.isF32()) + return 32; + if (type.isF64()) + return 64; + return std::nullopt; +} + +static bool isSupportedVldx2DistToken(StringRef dist) { + return dist == "BDINTLV" || dist == "DINTLV_B8" || dist == "DINTLV_B16" || + dist == "DINTLV_B32"; +} + +static bool isSupportedVldsDistToken(StringRef dist) { + return dist == "NORM" || dist == "BRC_B8" || dist == "BRC_B16" || + dist == "BRC_B32" || dist == "US_B8" || dist == "US_B16" || + dist == "DS_B8" || dist == "DS_B16" || dist == "UNPK_B8" || + dist == "UNPK_B16" || dist == "UNPK_B32" || dist == "BRC_BLK" || + dist == "E2B_B16" || dist == "E2B_B32" || dist == "UNPK4" || + dist == "SPLT4CHN" || dist == "SPLT2CHN_B8" || dist == "SPLT2CHN_B16"; +} + +static bool isSupportedVstsDistToken(StringRef dist) { + return dist == "NORM_B8" || dist == "NORM_B16" || dist == "NORM_B32" || + dist == "1PT_B8" || dist == "1PT_B16" || dist == "1PT_B32" || + dist == "PK_B16" || dist == "PK_B32" || dist == "PK_B64" || + dist == "PK4_B32" || dist == "MRG4CHN_B8" || dist == "MRG2CHN_B8" || + dist == "MRG2CHN_B16"; +} + +static bool isSupportedVstsx2DistToken(StringRef dist) { + return dist == "INTLV_B8" || dist == "INTLV_B16" || dist == "INTLV_B32"; +} + +static std::optional +getVstsMaskGranularityOverride(StringRef dist, Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return std::nullopt; + + if (dist == "MRG4CHN_B8") + return StringRef("b32"); + if (dist == "MRG2CHN_B8") + return StringRef("b16"); + if (dist == "MRG2CHN_B16") + return StringRef("b32"); + if (dist == "PK_B16") + return StringRef("b16"); + if (dist == "PK_B32") + return StringRef("b32"); + + return std::nullopt; +} + +static bool isSupportedPostMode(StringRef mode) { + return mode == "NO_POST_UPDATE" || mode == "POST_UPDATE"; +} + +static std::optional getOptionalPostModeAttr(Operation *op) { + if (auto mode = op->getAttrOfType("mode")) + return mode.getValue(); + return std::nullopt; +} + +static unsigned getIntOrFloatBitWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto floatType = dyn_cast(type)) + return floatType.getWidth(); + return 0; +} + +static bool isIntegerOrFloatLike(Type type) { + return isa(type) || isa(type); +} + +static std::optional getVRegStorageBitWidth(Type type) { + auto vecType = dyn_cast(type); + if (!vecType) + return std::nullopt; + unsigned elemWidth = getIntOrFloatBitWidth(vecType.getElementType()); + if (!elemWidth) + return std::nullopt; + return vecType.getElementCount() * static_cast(elemWidth); +} + +static LogicalResult verifyIntegerVRegTypeLike(Operation *op, Type type, + StringRef roleDescription) { + if (failed(verifyVRegTypeLike(op, type, roleDescription))) + return failure(); + auto vecType = cast(type); + if (!isa(vecType.getElementType())) + return op->emitOpError() + << roleDescription << " must use integer vector element type"; + return success(); +} + +enum class MemoryRole { + Unknown, + GM, + UB, + Other, +}; + +static MemoryRole classifyMemoryRole(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) { + if (auto ptrType = dyn_cast(type)) { + switch (ptrType.getMemorySpace().getAddressSpace()) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return MemoryRole::GM; + case pto::AddressSpace::VEC: + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + return MemoryRole::Other; + } + + Attribute memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return MemoryRole::Unknown; + + if (auto addrSpace = dyn_cast(memorySpace)) { + switch (addrSpace.getAddressSpace()) { + case pto::AddressSpace::GM: + case pto::AddressSpace::Zero: + return MemoryRole::GM; + case pto::AddressSpace::VEC: + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + + if (auto intAttr = dyn_cast(memorySpace)) { + switch (intAttr.getInt()) { + case static_cast(pto::AddressSpace::GM): + case static_cast(pto::AddressSpace::Zero): + return MemoryRole::GM; + case static_cast(pto::AddressSpace::VEC): + return MemoryRole::UB; + default: + return MemoryRole::Other; + } + } + + return MemoryRole::Other; +} + +static bool isBufferLike(Type type) { + return isa(type); +} + +static int64_t getBufferElementByteSize(Type type) { + Type elementType; + if (auto ptrType = dyn_cast(type)) { + elementType = ptrType.getElementType(); + } else if (auto memrefType = dyn_cast(type)) { + elementType = memrefType.getElementType(); + } else { + return 0; + } + + return getPTOStorageElemByteSize(elementType); +} + +static std::optional getBufferAddressSpace(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace().getAddressSpace(); + if (auto memrefType = dyn_cast(type)) { + if (auto space = + dyn_cast_or_null(memrefType.getMemorySpace())) + return space.getAddressSpace(); + if (auto intSpace = dyn_cast_or_null(memrefType.getMemorySpace())) + return static_cast(intSpace.getInt()); + } + return std::nullopt; +} + +template +static LogicalResult verifyCubeBridgeLoadLikeOp(BridgeLoadOp op, + AddressSpace expectedDstSpace, + StringRef dstName) { + if (!isBufferLike(op.getSource().getType()) || + !isBufferLike(op.getDestination().getType())) + return op.emitOpError("requires buffer-like source and destination"); + + if (getBufferAddressSpace(op.getSource().getType()) != AddressSpace::MAT) + return op.emitOpError("requires MAT source"); + if (getBufferAddressSpace(op.getDestination().getType()) != expectedDstSpace) { + return op.emitOpError() + << "requires " << dstName << " destination"; + } + + int64_t sourceElemBytes = getBufferElementByteSize(op.getSource().getType()); + int64_t destinationElemBytes = + getBufferElementByteSize(op.getDestination().getType()); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) { + return op.emitOpError( + "requires source and destination element types with known byte width"); + } + if (sourceElemBytes != destinationElemBytes) { + return op.emitOpError( + "requires source and destination element byte widths to match"); + } + + return success(); +} + +static bool hasAll(Value first, Value second, Value third) { + return static_cast(first) && static_cast(second) && + static_cast(third); +} + +static bool hasAny(Value first, Value second, Value third) { + return static_cast(first) || static_cast(second) || + static_cast(third); +} + +static ParseResult parseRequiredOperandWithComma( + OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand) { + if (parser.parseOperand(operand)) + return failure(); + return parser.parseComma(); +} + +static ParseResult parseDmaTripleGroup( + OpAsmParser &parser, StringRef keyword, + SmallVectorImpl &operands) { + if (parser.parseKeyword(keyword) || parser.parseLParen()) + return failure(); + for (int i = 0; i < 3; ++i) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand)) + return failure(); + operands.push_back(operand); + if (i != 2 && parser.parseComma()) + return failure(); + } + return parser.parseRParen(); +} + +static ParseResult parseOptionalDmaTripleGroupAlias( + OpAsmParser &parser, ArrayRef keywords, + StringRef &parsedKeyword, + SmallVectorImpl &operands) { + parsedKeyword = {}; + for (StringRef keyword : keywords) { + if (failed(parser.parseOptionalKeyword(keyword))) + continue; + parsedKeyword = keyword; + if (parser.parseLParen()) + return failure(); + for (int i = 0; i < 3; ++i) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand)) + return failure(); + operands.push_back(operand); + if (i != 2 && parser.parseComma()) + return failure(); + } + return parser.parseRParen(); + } + return success(); +} + +static bool isDmaLoopKeyword(StringRef keyword) { + if (keyword == "loop") + return true; + if (!keyword.consume_front("loop")) + return false; + if (keyword.empty()) + return false; + return llvm::all_of(keyword, llvm::isDigit); +} + +static ParseResult parseDmaTripleTypes(OpAsmParser &parser, + SmallVectorImpl &types) { + for (int i = 0; i < 3; ++i) { + Type type; + if (parser.parseType(type)) + return failure(); + types.push_back(type); + if (i != 2 && parser.parseComma()) + return failure(); + } + return success(); +} + +static ParseResult parseDmaPadTypes(OpAsmParser &parser, + SmallVectorImpl &types) { + Type valueType; + if (parser.parseType(valueType)) + return failure(); + types.push_back(valueType); + if (succeeded(parser.parseOptionalComma())) { + Type leftType; + Type rightType; + if (parser.parseType(leftType) || parser.parseComma() || + parser.parseType(rightType)) + return failure(); + types.push_back(leftType); + types.push_back(rightType); + } + return success(); +} + +static void printDmaTripleGroup(OpAsmPrinter &printer, StringRef keyword, + Value first, Value second, Value third) { + printer << " " << keyword << "(" << first << ", " << second << ", " << third + << ")"; +} + +static void printDmaTripleTypes(OpAsmPrinter &printer, StringRef keyword, + Type first, Type second, Type third) { + printer << ", " << keyword << " " << first << ", " << second << ", " << third; +} + +static void printDmaPadGroup(OpAsmPrinter &printer, Value value, Value left, + Value right) { + printer << " pad(" << value; + if (left || right) + printer << ", " << left << ", " << right; + printer << ")"; +} + +static void printDmaPadTypes(OpAsmPrinter &printer, Type valueType, + Type leftType, Type rightType) { + printer << ", pad " << valueType; + if (leftType || rightType) + printer << ", " << leftType << ", " << rightType; +} + +static FailureOr +parseCubeLoadFracModeKeyword(StringRef keyword) { + if (std::optional mode = symbolizeCubeLoadFracMode(keyword)) + return *mode; + return failure(); +} + +static ParseResult parseFixedKeywordOperandGroup( + OpAsmParser &parser, StringRef keyword, int operandCount, + SmallVectorImpl &operands) { + if (parser.parseKeyword(keyword) || parser.parseLParen()) + return failure(); + for (int i = 0; i < operandCount; ++i) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand)) + return failure(); + operands.push_back(operand); + if (i + 1 != operandCount && parser.parseComma()) + return failure(); + } + return parser.parseRParen(); +} + +static ParseResult parseFixedKeywordTypes(OpAsmParser &parser, StringRef keyword, + int typeCount, + SmallVectorImpl &types) { + if (parser.parseKeyword(keyword)) + return failure(); + for (int i = 0; i < typeCount; ++i) { + Type type; + if (parser.parseType(type)) + return failure(); + types.push_back(type); + if (i + 1 != typeCount && parser.parseComma()) + return failure(); + } + return success(); +} + +static ParseResult parseCubeLoadFracSrcLayoutGroup( + OpAsmParser &parser, + SmallVectorImpl &operands) { + if (parser.parseKeyword("src_layout") || parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand innerStride; + if (parser.parseOperand(innerStride)) + return failure(); + operands.push_back(innerStride); + if (succeeded(parser.parseOptionalComma())) { + OpAsmParser::UnresolvedOperand outerStride; + if (parser.parseOperand(outerStride)) + return failure(); + operands.push_back(outerStride); + } + return parser.parseRParen(); +} + +static ParseResult parseCubeLoadFracSrcLayoutTypes(OpAsmParser &parser, + SmallVectorImpl &types) { + if (parser.parseKeyword("src_layout") || parser.parseLParen()) + return failure(); + Type innerStrideType; + if (parser.parseType(innerStrideType)) + return failure(); + types.push_back(innerStrideType); + if (succeeded(parser.parseOptionalComma())) { + Type outerStrideType; + if (parser.parseType(outerStrideType)) + return failure(); + types.push_back(outerStrideType); + } + return parser.parseRParen(); +} + +static void printCubeLoadFracSrcLayoutGroup(OpAsmPrinter &printer, + Value srcInnerStride, + Value srcOuterStride) { + printer << ", src_layout(" << srcInnerStride; + if (srcOuterStride) + printer << ", " << srcOuterStride; + printer << ")"; +} + +static void printCubeLoadFracSrcLayoutTypes(OpAsmPrinter &printer, + Type srcInnerStrideType, + Type srcOuterStrideType) { + printer << ", src_layout(" << srcInnerStrideType; + if (srcOuterStrideType) + printer << ", " << srcOuterStrideType; + printer << ")"; +} + +static FailureOr parseAccStoreModeKeyword(StringRef keyword) { + if (std::optional mode = symbolizeAccStoreMode(keyword)) + return *mode; + return failure(); +} + +[[maybe_unused]] static ParseResult parseAccStoreModeGroup( + OpAsmParser &parser, StringRef &modeKeyword, + SmallVectorImpl &modeOperands) { + if (parser.parseKeyword(&modeKeyword)) + return failure(); + if (failed(parseAccStoreModeKeyword(modeKeyword))) + return parser.emitError(parser.getCurrentLocation(), + "expected one of 'nz2nd', 'nz2dn', or 'nz2nz'"); + auto parseModeOperandWithParens = [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseLParen() || parser.parseOperand(operand) || parser.parseRParen()) + return failure(); + modeOperands.push_back(operand); + return success(); + }; + auto parseModeOperandAfterLParen = [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand) || parser.parseRParen()) + return failure(); + modeOperands.push_back(operand); + return success(); + }; + + switch (*parseAccStoreModeKeyword(modeKeyword)) { + case AccStoreMode::Nz2nd: + return success(); + case AccStoreMode::Nz2dn: + (void)parser.parseOptionalComma(); + if (succeeded(parser.parseOptionalKeyword("loop0_src_stride"))) + return parseModeOperandWithParens(); + if (failed(parser.parseOptionalLParen())) + return success(); + return parseModeOperandAfterLParen(); + case AccStoreMode::Nz2nz: + (void)parser.parseOptionalComma(); + if (succeeded(parser.parseOptionalKeyword("split"))) + return parseModeOperandWithParens(); + if (failed(parser.parseOptionalLParen())) + return success(); + return parseModeOperandAfterLParen(); + } + return success(); +} + +[[maybe_unused]] static ParseResult +parseAccStoreModeTypes(OpAsmParser &parser, StringRef modeKeyword, + SmallVectorImpl &modeTypes) { + if (parser.parseKeyword(modeKeyword)) + return failure(); + auto parseModeTypeWithParens = [&]() -> ParseResult { + Type modeType; + if (parser.parseLParen() || parser.parseType(modeType) || parser.parseRParen()) + return failure(); + modeTypes.push_back(modeType); + return success(); + }; + auto parseModeTypeAfterLParen = [&]() -> ParseResult { + Type modeType; + if (parser.parseType(modeType) || parser.parseRParen()) + return failure(); + modeTypes.push_back(modeType); + return success(); + }; + + switch (*parseAccStoreModeKeyword(modeKeyword)) { + case AccStoreMode::Nz2nd: + return success(); + case AccStoreMode::Nz2dn: + (void)parser.parseOptionalComma(); + if (succeeded(parser.parseOptionalKeyword("loop0_src_stride"))) + return parseModeTypeWithParens(); + if (failed(parser.parseOptionalLParen())) + return success(); + return parseModeTypeAfterLParen(); + case AccStoreMode::Nz2nz: + (void)parser.parseOptionalComma(); + if (succeeded(parser.parseOptionalKeyword("split"))) + return parseModeTypeWithParens(); + if (failed(parser.parseOptionalLParen())) + return success(); + return parseModeTypeAfterLParen(); + } + return success(); +} + +[[maybe_unused]] static void printAccStoreModeGroup(OpAsmPrinter &printer, + AccStoreMode mode, + Value split, + Value loop0SrcStride) { + printer << ", " << pto::stringifyAccStoreMode(mode); + switch (mode) { + case AccStoreMode::Nz2nd: + return; + case AccStoreMode::Nz2dn: + if (loop0SrcStride) + printer << ", loop0_src_stride(" << loop0SrcStride << ")"; + return; + case AccStoreMode::Nz2nz: + if (split) + printer << ", split(" << split << ")"; + return; + } + llvm_unreachable("unexpected mte_l0c mode"); +} + +[[maybe_unused]] static void printAccStoreModeTypes(OpAsmPrinter &printer, + AccStoreMode mode, + Type splitType, + Type loop0SrcStrideType) { + printer << ", " << pto::stringifyAccStoreMode(mode); + switch (mode) { + case AccStoreMode::Nz2nd: + return; + case AccStoreMode::Nz2dn: + if (loop0SrcStrideType) + printer << ", loop0_src_stride(" << loop0SrcStrideType << ")"; + return; + case AccStoreMode::Nz2nz: + if (splitType) + printer << ", split(" << splitType << ")"; + return; + } + llvm_unreachable("unexpected mte_l0c mode"); +} + +[[maybe_unused]] static ParseResult parseMteL0cL1OptionalLoop3( + OpAsmParser &parser, + SmallVectorImpl &loop3CountOperands, + SmallVectorImpl &loop3SrcStrideOperands, + SmallVectorImpl &loop3DstStrideOperands) { + StringRef parsedKeyword; + SmallVector loop3Operands; + if (parseOptionalDmaTripleGroupAlias(parser, {"loop3"}, parsedKeyword, + loop3Operands)) + return failure(); + if (!parsedKeyword.empty()) { + loop3CountOperands.push_back(loop3Operands[0]); + loop3SrcStrideOperands.push_back(loop3Operands[1]); + loop3DstStrideOperands.push_back(loop3Operands[2]); + } + return success(); +} + +[[maybe_unused]] static ParseResult parseMteL0cL1OptionalFpc( + OpAsmParser &parser, + SmallVectorImpl &fpcOperands) { + if (failed(parser.parseOptionalKeyword("fpc"))) + return success(); + if (parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand) || parser.parseRParen()) + return failure(); + fpcOperands.push_back(operand); + return success(); +} + +[[maybe_unused]] static void printMteL0cL1OptionalFpc(OpAsmPrinter &printer, + Value fpc) { + if (fpc) + printer << ", fpc(" << fpc << ")"; +} + +[[maybe_unused]] static void +printMteL0cL1OptionalFpcType(OpAsmPrinter &printer, Type fpcType) { + if (fpcType) + printer << ", fpc(" << fpcType << ")"; +} + +[[maybe_unused]] static ParseResult parseMteL0cL1OptionalLoop3Types( + OpAsmParser &parser, SmallVectorImpl &loop3CountTypes, + SmallVectorImpl &loop3SrcStrideTypes, + SmallVectorImpl &loop3DstStrideTypes, StringRef opName) { + if (succeeded(parser.parseOptionalComma())) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (keyword != "loop3") + return parser.emitError(parser.getCurrentLocation(), "expected 'loop3'"); + SmallVector loop3GroupTypes; + if (parseDmaTripleTypes(parser, loop3GroupTypes)) + return failure(); + loop3CountTypes.push_back(loop3GroupTypes[0]); + loop3SrcStrideTypes.push_back(loop3GroupTypes[1]); + loop3DstStrideTypes.push_back(loop3GroupTypes[2]); + if (succeeded(parser.parseOptionalComma())) + return parser.emitError(parser.getCurrentLocation(), + (Twine(opName) + + " accepts at most one loop3 group") + .str()); + } + return success(); +} + +[[maybe_unused]] static LogicalResult verifyAccStoreLikeModeOperands( + Operation *op, AccStoreMode mode, Value split, Value loop0SrcStride, + Value loop3Count, Value loop3SrcStride, Value loop3DstStride, + StringRef nz2ndSplitError, StringRef nz2ndLoop0Error, + StringRef nz2dnSplitError, StringRef nz2nzLoop0Error, + StringRef nz2nzLoop3Error) { + bool hasLoop3Count = static_cast(loop3Count); + bool hasLoop3SrcStride = static_cast(loop3SrcStride); + bool hasLoop3DstStride = static_cast(loop3DstStride); + if ((hasLoop3Count != hasLoop3SrcStride) || + (hasLoop3Count != hasLoop3DstStride)) { + return op->emitOpError( + "requires loop3 count, src stride, and dst stride to appear together"); + } + + switch (mode) { + case AccStoreMode::Nz2nd: + if (split) + return op->emitOpError(nz2ndSplitError); + if (loop0SrcStride) + return op->emitOpError(nz2ndLoop0Error); + return success(); + case AccStoreMode::Nz2dn: + if (split) + return op->emitOpError(nz2dnSplitError); + return success(); + case AccStoreMode::Nz2nz: + if (loop0SrcStride) + return op->emitOpError(nz2nzLoop0Error); + if (loop3Count) + return op->emitOpError(nz2nzLoop3Error); + return success(); + } + llvm_unreachable("unexpected mte_l0c mode"); +} + +struct StructuredAccStoreAsmState { + std::optional unitFlag; + std::optional preQuantMode; + std::optional preReluMode; + std::optional mode; + std::optional atomicType; + std::optional atomicOp; + std::optional satMode; + + SmallVector preQuantOperands; + SmallVector preReluOperands; + SmallVector clipValueOperands; + SmallVector splitOperands; + SmallVector loop0SrcStrideOperands; + SmallVector loop3CountOperands; + SmallVector loop3SrcStrideOperands; + SmallVector loop3DstStrideOperands; + + SmallVector preQuantTypes; + SmallVector preReluTypes; + SmallVector clipValueTypes; + SmallVector splitTypes; + SmallVector loop0SrcStrideTypes; + SmallVector loop3CountTypes; + SmallVector loop3SrcStrideTypes; + SmallVector loop3DstStrideTypes; +}; + +enum class StructuredAccStoreClauseKind { + UnitFlag = 0, + PreQuant = 1, + PreRelu = 2, + Layout = 3, + Loop3 = 4, + Sat = 5, + Atomic = 6 +}; + +static bool isStructuredAccStoreVectorQuantMode(AccStoreQuantPreMode mode) { + switch (mode) { + case AccStoreQuantPreMode::QF322HIF8PreVec: + case AccStoreQuantPreMode::QF322HIF8PreHybridVec: + case AccStoreQuantPreMode::DEQS32IntVec: + case AccStoreQuantPreMode::REQ8Vec: + case AccStoreQuantPreMode::DEQF16Vec: + case AccStoreQuantPreMode::QF322FP8PreVec: + case AccStoreQuantPreMode::QF322F32PreVec: + case AccStoreQuantPreMode::QF162B8PreVec: + case AccStoreQuantPreMode::QF162S4PreVec: + case AccStoreQuantPreMode::REQ4Vec: + case AccStoreQuantPreMode::QF322B8PreVec: + case AccStoreQuantPreMode::QF322S4PreVec: + case AccStoreQuantPreMode::DEQS16Vec: + case AccStoreQuantPreMode::QF162S16PreVec: + case AccStoreQuantPreMode::QF322F16PreVec: + case AccStoreQuantPreMode::QF322BF16PreVec: + case AccStoreQuantPreMode::QS322BF16PreVec: + return true; + default: + return false; + } +} + +static bool isStructuredAccStoreScalingPayload(Value value) { + auto ptrType = dyn_cast_or_null(value.getType()); + return ptrType && + ptrType.getMemorySpace().getAddressSpace() == AddressSpace::SCALING; +} + +[[maybe_unused]] static bool isStructuredAccStoreScalingPayloadType(Type type) { + auto ptrType = dyn_cast_or_null(type); + return ptrType && + ptrType.getMemorySpace().getAddressSpace() == AddressSpace::SCALING; +} + +static Type getStructuredAccStoreScalingElementType(Value value) { + auto ptrType = dyn_cast_or_null(value.getType()); + if (!ptrType || + ptrType.getMemorySpace().getAddressSpace() != AddressSpace::SCALING) + return {}; + return ptrType.getElementType(); +} + +[[maybe_unused]] static bool isStructuredAccStoreIntegerPayload(Value value) { + return value.getType().isSignlessInteger(); +} + +static bool isStructuredAccStoreClipPayloadForUInt8(Type type) { + auto intType = dyn_cast(type); + if (!intType || intType.getWidth() != 16) + return false; + return intType.isUnsigned() || intType.isSignless(); +} + +static bool isStructuredAccStoreClipPayloadForSignedInt(Type type) { + auto intType = dyn_cast(type); + if (!intType) + return false; + unsigned width = intType.getWidth(); + if (width != 4 && width != 8 && width != 16) + return false; + return intType.isSigned() || intType.isSignless(); +} + +static bool isStructuredAccStoreFloatScalarPayloadType(Type type) { + return type.isF16() || type.isF32() || type.isBF16(); +} + +static bool isStructuredAccStoreFloatScalarPayload(Value value) { + return isStructuredAccStoreFloatScalarPayloadType(value.getType()); +} + +[[maybe_unused]] static bool isStructuredAccStoreIntegerPayloadType(Type type) { + return type.isSignlessInteger(); +} + +static bool isStructuredAccStoreClipSupportedElementType(Type type) { + if (auto floatType = dyn_cast(type)) + return floatType.isF16(); + auto intType = dyn_cast(type); + if (!intType) + return false; + if (intType.isUnsignedInteger(8)) + return true; + if (intType.isSignlessInteger(4) || intType.isSignlessInteger(8) || + intType.isSignlessInteger(16)) + return true; + if (intType.isSignedInteger(4) || intType.isSignedInteger(8) || + intType.isSignedInteger(16)) + return true; + return false; +} + +static LogicalResult verifyStructuredAccStoreClipPayload(Operation *op, + Type destinationElementType, + Value clipValue) { + if (!clipValue) + return success(); + + Type clipType = clipValue.getType(); + if (destinationElementType.isF16()) { + if (!clipType.isF16()) + return op->emitOpError("clip for f16 destination requires f16 payload"); + return success(); + } + + auto intType = dyn_cast(destinationElementType); + if (!intType) + return op->emitOpError() + << "clip requires destination element type to be f16, ui8, or signed 4/8/16-bit integer, got " + << destinationElementType; + + if (intType.isUnsignedInteger(8)) { + if (!isStructuredAccStoreClipPayloadForUInt8(clipType)) + return op->emitOpError("clip for ui8 destination requires ui16/signless i16 payload"); + return success(); + } + + if (intType.isSignlessInteger(4) || intType.isSignlessInteger(8) || + intType.isSignlessInteger(16) || intType.isSignedInteger(4) || + intType.isSignedInteger(8) || intType.isSignedInteger(16)) { + if (!isStructuredAccStoreClipPayloadForSignedInt(clipType)) + return op->emitOpError("clip for signed 4/8/16-bit destination requires signed/signless i4/i8/i16 payload"); + return success(); + } + + return op->emitOpError() + << "clip requires destination element type to be f16, ui8, or signed 4/8/16-bit integer, got " + << destinationElementType; +} + +static bool isStructuredAccStoreFloatPreQuantMode(AccStoreQuantPreMode mode) { + switch (mode) { + case AccStoreQuantPreMode::F32F16: + case AccStoreQuantPreMode::QF322HIF8PreVec: + case AccStoreQuantPreMode::QF322HIF8PreScalar: + case AccStoreQuantPreMode::QF322HIF8PreHybridVec: + case AccStoreQuantPreMode::QF322HIF8PreHybridScalar: + case AccStoreQuantPreMode::QF322FP8PreVec: + case AccStoreQuantPreMode::QF322FP8PreScalar: + case AccStoreQuantPreMode::QF322F32PreVec: + case AccStoreQuantPreMode::QF322F32PreScalar: + case AccStoreQuantPreMode::F32BF16: + case AccStoreQuantPreMode::QF162B8PreVec: + case AccStoreQuantPreMode::QF162B8PreScalar: + case AccStoreQuantPreMode::QF162S4PreVec: + case AccStoreQuantPreMode::QF162S4PreScalar: + case AccStoreQuantPreMode::QF322B8PreVec: + case AccStoreQuantPreMode::QF322B8PreScalar: + case AccStoreQuantPreMode::QF322S4PreVec: + case AccStoreQuantPreMode::QF322S4PreScalar: + case AccStoreQuantPreMode::QF322F16PreVec: + case AccStoreQuantPreMode::QF322F16PreScalar: + case AccStoreQuantPreMode::QF322BF16PreVec: + case AccStoreQuantPreMode::QF322BF16PreScalar: + return true; + default: + return false; + } +} + +static bool isStructuredAccStoreInt32PreQuantMode(AccStoreQuantPreMode mode) { + switch (mode) { + case AccStoreQuantPreMode::DEQS32IntVec: + case AccStoreQuantPreMode::DEQS32IntScalar: + case AccStoreQuantPreMode::REQ8Vec: + case AccStoreQuantPreMode::REQ8Scalar: + case AccStoreQuantPreMode::DEQF16Vec: + case AccStoreQuantPreMode::DEQF16Scalar: + case AccStoreQuantPreMode::DEQS16Vec: + case AccStoreQuantPreMode::DEQS16Scalar: + case AccStoreQuantPreMode::QF162S16PreVec: + case AccStoreQuantPreMode::QF162S16PreScalar: + case AccStoreQuantPreMode::QS322BF16PreVec: + case AccStoreQuantPreMode::QS322BF16PreScalar: + return true; + default: + return false; + } +} + +static ParseResult parseStructuredAccStoreUnitFlag(OpAsmParser &parser, + StructuredAccStoreAsmState &state) { + if (state.unitFlag) + return parser.emitError(parser.getCurrentLocation(), "duplicate unit_flag clause"); + StringRef keyword; + if (parser.parseLParen() || parser.parseKeyword(&keyword) || parser.parseRParen()) + return failure(); + if (keyword == "check_only") + state.unitFlag = AccStoreUnitFlagCtrl::CheckOnly; + else if (keyword == "check_and_clear") + state.unitFlag = AccStoreUnitFlagCtrl::CheckAndClear; + else + return parser.emitError(parser.getCurrentLocation(), + "expected 'check_only' or 'check_and_clear'"); + return success(); +} + +static ParseResult parseStructuredAccStorePreQuant( + OpAsmParser &parser, StructuredAccStoreAsmState &state) { + if (state.preQuantMode) + return parser.emitError(parser.getCurrentLocation(), "duplicate pre_quant clause"); + OpAsmParser::UnresolvedOperand payload; + StringRef modeKeyword; + if (parser.parseLParen() || parser.parseOperand(payload) || parser.parseComma() || + parser.parseKeyword("mode") || parser.parseEqual() || + parser.parseKeyword(&modeKeyword) || parser.parseRParen()) + return failure(); + auto mode = symbolizeAccStoreQuantPreMode(modeKeyword); + if (!mode) + return parser.emitError(parser.getCurrentLocation(), "invalid pre_quant mode"); + state.preQuantOperands.push_back(payload); + state.preQuantMode = *mode; + return success(); +} + +static ParseResult parseStructuredAccStorePreRelu( + OpAsmParser &parser, StructuredAccStoreAsmState &state) { + if (state.preReluMode) + return parser.emitError(parser.getCurrentLocation(), "duplicate pre_relu clause"); + StringRef modeKeyword; + bool hasPayload = false; + OpAsmParser::UnresolvedOperand payload; + if (parser.parseLParen()) + return failure(); + if (failed(parser.parseOptionalKeyword("mode"))) { + hasPayload = true; + if (parser.parseOperand(payload) || parser.parseComma() || + parser.parseKeyword("mode")) + return failure(); + } + if (parser.parseEqual() || parser.parseKeyword(&modeKeyword)) + return failure(); + auto mode = symbolizeReluPreMode(modeKeyword); + if (!mode) + return parser.emitError(parser.getCurrentLocation(), "invalid pre_relu mode"); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseKeyword("clip") || parser.parseEqual()) + return failure(); + if (!state.clipValueOperands.empty()) + return parser.emitError(parser.getCurrentLocation(), + "duplicate clip payload in pre_relu clause"); + OpAsmParser::UnresolvedOperand clipValue; + if (parser.parseOperand(clipValue)) + return failure(); + state.clipValueOperands.push_back(clipValue); + } + if (parser.parseRParen()) + return failure(); + + if (hasPayload) + state.preReluOperands.push_back(payload); + state.preReluMode = *mode; + return success(); +} + +static ParseResult parseStructuredAccStoreLayout( + OpAsmParser &parser, StructuredAccStoreAsmState &state, StringRef keyword) { + auto mode = parseAccStoreModeKeyword(keyword); + if (failed(mode)) + return parser.emitError(parser.getCurrentLocation(), + "expected one of 'nz2nd', 'nz2dn', or 'nz2nz'"); + if (state.mode) + return parser.emitError(parser.getCurrentLocation(), "duplicate layout clause"); + state.mode = *mode; + if (*mode == AccStoreMode::Nz2dn) { + if (succeeded(parser.parseOptionalLParen())) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand) || parser.parseRParen()) + return failure(); + state.loop0SrcStrideOperands.push_back(operand); + } + } else if (*mode == AccStoreMode::Nz2nz) { + if (succeeded(parser.parseOptionalLParen())) { + OpAsmParser::UnresolvedOperand operand; + if (parser.parseOperand(operand) || parser.parseRParen()) + return failure(); + state.splitOperands.push_back(operand); + } + } + return success(); +} + +static ParseResult parseStructuredAccStoreLoop3( + OpAsmParser &parser, StructuredAccStoreAsmState &state) { + if (!state.loop3CountOperands.empty()) + return parser.emitError(parser.getCurrentLocation(), "duplicate loop3 clause"); + OpAsmParser::UnresolvedOperand count; + OpAsmParser::UnresolvedOperand srcStride; + OpAsmParser::UnresolvedOperand dstStride; + if (parser.parseLParen() || parser.parseOperand(count) || parser.parseComma() || + parser.parseOperand(srcStride) || parser.parseComma() || + parser.parseOperand(dstStride) || parser.parseRParen()) + return failure(); + state.loop3CountOperands.push_back(count); + state.loop3SrcStrideOperands.push_back(srcStride); + state.loop3DstStrideOperands.push_back(dstStride); + return success(); +} + +static ParseResult parseStructuredAccStoreAtomic( + OpAsmParser &parser, StructuredAccStoreAsmState &state) { + if (state.atomicType || state.atomicOp) + return parser.emitError(parser.getCurrentLocation(), "duplicate atomic clause"); + StringRef typeKeyword; + StringRef opKeyword; + if (parser.parseLParen() || parser.parseKeyword("type") || parser.parseEqual() || + parser.parseKeyword(&typeKeyword) || parser.parseComma() || + parser.parseKeyword("op") || parser.parseEqual() || + parser.parseKeyword(&opKeyword) || parser.parseRParen()) + return failure(); + auto type = symbolizeAccStoreAtomicType(typeKeyword); + auto op = symbolizeAccStoreAtomicOp(opKeyword); + if (!type) + return parser.emitError(parser.getCurrentLocation(), "invalid atomic type"); + if (!op) + return parser.emitError(parser.getCurrentLocation(), "invalid atomic op"); + state.atomicType = *type; + state.atomicOp = *op; + return success(); +} + +static ParseResult parseStructuredAccStoreClauses( + OpAsmParser &parser, StructuredAccStoreAsmState &state) { + int lastClause = -1; + bool seenClause = false; + while (true) { + if (seenClause) { + if (failed(parser.parseOptionalComma())) + return success(); + } + StringRef keyword; + if (parser.parseKeyword(&keyword)) { + if (!seenClause) + return success(); + return failure(); + } + seenClause = true; + + StructuredAccStoreClauseKind kind; + if (keyword == "unit_flag") + kind = StructuredAccStoreClauseKind::UnitFlag; + else if (keyword == "pre_quant") + kind = StructuredAccStoreClauseKind::PreQuant; + else if (keyword == "pre_relu") + kind = StructuredAccStoreClauseKind::PreRelu; + else if (keyword == "nz2nd" || keyword == "nz2dn" || keyword == "nz2nz") + kind = StructuredAccStoreClauseKind::Layout; + else if (keyword == "loop3") + kind = StructuredAccStoreClauseKind::Loop3; + else if (keyword == "sat" || keyword == "nosat") + kind = StructuredAccStoreClauseKind::Sat; + else if (keyword == "atomic") + kind = StructuredAccStoreClauseKind::Atomic; + else + return parser.emitError(parser.getCurrentLocation(), "unknown mte_l0c clause"); + + if (static_cast(kind) < lastClause) { + return parser.emitError(parser.getCurrentLocation(), + "mte_l0c clauses must follow canonical order"); + } + lastClause = static_cast(kind); + + ParseResult parseResult = success(); + switch (kind) { + case StructuredAccStoreClauseKind::UnitFlag: + parseResult = parseStructuredAccStoreUnitFlag(parser, state); + break; + case StructuredAccStoreClauseKind::PreQuant: + parseResult = parseStructuredAccStorePreQuant(parser, state); + break; + case StructuredAccStoreClauseKind::PreRelu: + parseResult = parseStructuredAccStorePreRelu(parser, state); + break; + case StructuredAccStoreClauseKind::Layout: + parseResult = parseStructuredAccStoreLayout(parser, state, keyword); + break; + case StructuredAccStoreClauseKind::Loop3: + parseResult = parseStructuredAccStoreLoop3(parser, state); + break; + case StructuredAccStoreClauseKind::Sat: + if (state.satMode) + return parser.emitError(parser.getCurrentLocation(), "duplicate sat/nosat clause"); + if (keyword == "nosat") { + state.satMode = AccStoreSatMode::NoSat; + break; + } + if (succeeded(parser.parseOptionalLParen())) { + StringRef satOption; + if (parser.parseKeyword(&satOption) || satOption != "preserve_nan") + return parser.emitError(parser.getCurrentLocation(), + "expected preserve_nan"); + if (parser.parseRParen()) + return failure(); + state.satMode = AccStoreSatMode::SatPreserveNan; + } else { + state.satMode = AccStoreSatMode::Sat; + } + break; + case StructuredAccStoreClauseKind::Atomic: + parseResult = parseStructuredAccStoreAtomic(parser, state); + break; + } + if (failed(parseResult)) + return failure(); + } +} + +static ParseResult parseStructuredOptionalType(OpAsmParser &parser, + SmallVectorImpl &types) { + Type type; + if (parser.parseType(type)) + return failure(); + types.push_back(type); + return success(); +} + +static LogicalResult verifyStructuredAccStoreLike( + Operation *op, Type srcType, Type dstType, Value preQuant, Value preRelu, + Value clipValue, + Value split, Value loop0SrcStride, Value loop3Count, Value loop3SrcStride, + Value loop3DstStride, + std::optional unitFlag, + std::optional preQuantMode, + std::optional preReluMode, std::optional mode, + std::optional atomicType, + std::optional atomicOp, bool allowAtomic) { + auto getBufferElementType = [](Type type) -> Type { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; + }; + Type sourceElementType = getBufferElementType(srcType); + Type destinationElementType = getBufferElementType(dstType); + + if (static_cast(preQuant) != static_cast(preQuantMode)) + return op->emitOpError("pre_quant requires payload and mode together"); + if (preQuantMode) { + if (isStructuredAccStoreVectorQuantMode(*preQuantMode)) { + if (!isStructuredAccStoreScalingPayload(preQuant)) + return op->emitOpError("vector pre_quant mode requires scaling pointer payload"); + if (!isStructuredAccStoreFloatScalarPayloadType( + getStructuredAccStoreScalingElementType(preQuant))) + return op->emitOpError( + "vector pre_quant mode requires scaling pointer element type to be f16, bf16, or f32"); + } else if (!isStructuredAccStoreFloatScalarPayload(preQuant)) { + return op->emitOpError( + "scalar pre_quant mode requires f16/bf16/f32 payload"); + } + + auto emitIncompatibleQuantModeError = [&]() -> LogicalResult { + return op->emitOpError() + << "pre_quant mode " << stringifyAccStoreQuantPreMode(*preQuantMode) + << " is incompatible with source element type " << sourceElementType + << " and destination element type " << destinationElementType; + }; + + if (isa(sourceElementType)) { + if (!isStructuredAccStoreFloatPreQuantMode(*preQuantMode)) + return emitIncompatibleQuantModeError(); + } else if (sourceElementType.isSignlessInteger(32)) { + if (!isStructuredAccStoreInt32PreQuantMode(*preQuantMode)) + return emitIncompatibleQuantModeError(); + } else { + return op->emitOpError() + << "pre_quant requires source element type to be f32 or i32, got " + << sourceElementType; + } + } + + if (clipValue && !isStructuredAccStoreClipSupportedElementType(destinationElementType)) + return op->emitOpError() + << "clip requires destination element type to be f16, ui8, or signed 4/8/16-bit integer, got " + << destinationElementType; + if (failed(verifyStructuredAccStoreClipPayload(op, destinationElementType, + clipValue))) + return failure(); + + if (!preReluMode) { + if (preRelu) + return op->emitOpError("pre_relu payload requires pre_relu mode"); + if (clipValue) + return op->emitOpError("clip requires pre_relu clause"); + } else { + switch (*preReluMode) { + case ReluPreMode::NoRelu: + if (preRelu) + return op->emitOpError("mode does not accept pre_relu payload"); + break; + case ReluPreMode::NormalRelu: + if (preRelu) + return op->emitOpError("mode does not accept pre_relu payload"); + break; + case ReluPreMode::ScalarRelu: + if (!preRelu) + return op->emitOpError("scalar_relu requires payload"); + if (!isStructuredAccStoreFloatScalarPayload(preRelu)) + return op->emitOpError("scalar_relu requires f16/bf16/f32 payload"); + break; + case ReluPreMode::VectorRelu: + if (!preRelu) + return op->emitOpError("vector_relu requires payload"); + if (!isStructuredAccStoreScalingPayload(preRelu)) + return op->emitOpError("vector_relu requires scaling pointer payload"); + if (!isStructuredAccStoreFloatScalarPayloadType( + getStructuredAccStoreScalingElementType(preRelu))) + return op->emitOpError( + "vector_relu requires scaling pointer element type to be f16, bf16, or f32"); + break; + case ReluPreMode::Pwl: + return op->emitOpError("pwl is not supported for target_profile mte_l0c_l1"); + } + } + + bool hasLoop3 = static_cast(loop3Count) || static_cast(loop3SrcStride) || + static_cast(loop3DstStride); + if (hasLoop3 && !(loop3Count && loop3SrcStride && loop3DstStride)) + return op->emitOpError("loop3 requires count, src stride, and dst stride together"); + + if (!mode) { + if (split) + return op->emitOpError("split requires nz2nz"); + if (loop0SrcStride) + return op->emitOpError("loop0_src_stride requires nz2dn"); + if (loop3Count) + return op->emitOpError("loop3 requires nz2nd or nz2dn"); + } else { + switch (*mode) { + case AccStoreMode::Nz2nd: + if (split) + return op->emitOpError("nz2nd does not accept split"); + if (loop0SrcStride) + return op->emitOpError("nz2nd does not accept loop0_src_stride"); + break; + case AccStoreMode::Nz2dn: { + if (!loop0SrcStride) + return op->emitOpError("nz2dn requires loop0_src_stride"); + if (split) + return op->emitOpError("nz2dn does not accept split"); + APInt loop0Value; + if (unitFlag && *unitFlag != AccStoreUnitFlagCtrl::Off && + (!matchPattern(loop0SrcStride, m_ConstantInt(&loop0Value)) || + !loop0Value.isOne())) { + return op->emitOpError( + "unit_flag must be off when nz2dn loop0_src_stride is not 1"); + } + break; + } + case AccStoreMode::Nz2nz: + if (loop0SrcStride) + return op->emitOpError("nz2nz does not accept loop0_src_stride"); + if (loop3Count) + return op->emitOpError("loop3 requires nz2nd or nz2dn"); + if (!isa(destinationElementType) || + !cast(destinationElementType).isF32()) + return op->emitOpError("nz2nz requires destination element type to be f32"); + break; + } + } + + if (static_cast(atomicType) != static_cast(atomicOp)) + return op->emitOpError("atomic requires type and op together"); + if ((atomicType || atomicOp) && !allowAtomic) + return op->emitOpError("atomic is only supported for mte_l0c_gm"); + + return success(); +} + +static void printStructuredAccStoreClauses( + OpAsmPrinter &printer, std::optional unitFlag, + Value preQuant, + std::optional preQuantMode, Value preRelu, + std::optional preReluMode, Value clipValue, + std::optional mode, Value split, Value loop0SrcStride, + Value loop3Count, Value loop3SrcStride, Value loop3DstStride, + std::optional satMode, + std::optional atomicType, + std::optional atomicOp) { + if (unitFlag && *unitFlag != AccStoreUnitFlagCtrl::Off) { + printer << ", unit_flag(" + << (*unitFlag == AccStoreUnitFlagCtrl::CheckOnly ? "check_only" + : "check_and_clear") + << ")"; + } + if (preQuantMode) { + printer << ", pre_quant(" << preQuant << ", mode = " + << stringifyAccStoreQuantPreMode(*preQuantMode) << ")"; + } + if (preReluMode) { + printer << ", pre_relu("; + if (preRelu) + printer << preRelu << ", "; + printer << "mode = " << stringifyReluPreMode(*preReluMode); + if (clipValue) + printer << ", clip = " << clipValue; + printer << ")"; + } + if (mode) { + switch (*mode) { + case AccStoreMode::Nz2nd: + printer << ", nz2nd"; + break; + case AccStoreMode::Nz2dn: + printer << ", nz2dn"; + if (loop0SrcStride) + printer << "(" << loop0SrcStride << ")"; + break; + case AccStoreMode::Nz2nz: + printer << ", nz2nz"; + if (split) + printer << "(" << split << ")"; + break; + } + } + if (loop3Count) { + printer << ", loop3(" << loop3Count << ", " << loop3SrcStride << ", " + << loop3DstStride << ")"; + } + if (satMode) { + switch (*satMode) { + case AccStoreSatMode::Sat: + printer << ", sat"; + break; + case AccStoreSatMode::NoSat: + printer << ", nosat"; + break; + case AccStoreSatMode::SatPreserveNan: + printer << ", sat(preserve_nan)"; + break; + } + } + if (atomicType && atomicOp) { + printer << ", atomic(type = " << stringifyAccStoreAtomicType(*atomicType) + << ", op = " << stringifyAccStoreAtomicOp(*atomicOp) << ")"; + } +} + +static void printStructuredAccStoreOptionalTypes( + OpAsmPrinter &printer, Value preQuant, Value preRelu, Value clipValue, + Value split, Value loop0SrcStride, Value loop3Count, Value loop3SrcStride, + Value loop3DstStride) { + if (preQuant) + printer << ", " << preQuant.getType(); + if (preRelu) + printer << ", " << preRelu.getType(); + if (clipValue) + printer << ", " << clipValue.getType(); + if (split) + printer << ", " << split.getType(); + if (loop0SrcStride) + printer << ", " << loop0SrcStride.getType(); + if (loop3Count) + printer << ", " << loop3Count.getType() << ", " << loop3SrcStride.getType() + << ", " << loop3DstStride.getType(); +} + +static ParseResult parseStructuredAccStoreTailTypes( + OpAsmParser &parser, StructuredAccStoreAsmState &state) { + if (!state.preQuantOperands.empty() && + (parser.parseComma() || + parseStructuredOptionalType(parser, state.preQuantTypes))) + return failure(); + if (!state.preReluOperands.empty() && + (parser.parseComma() || + parseStructuredOptionalType(parser, state.preReluTypes))) + return failure(); + if (!state.clipValueOperands.empty() && + (parser.parseComma() || + parseStructuredOptionalType(parser, state.clipValueTypes))) + return failure(); + if (!state.splitOperands.empty() && + (parser.parseComma() || + parseStructuredOptionalType(parser, state.splitTypes))) + return failure(); + if (!state.loop0SrcStrideOperands.empty() && + (parser.parseComma() || + parseStructuredOptionalType(parser, state.loop0SrcStrideTypes))) + return failure(); + if (!state.loop3CountOperands.empty() && + (parser.parseComma() || + parseStructuredOptionalType(parser, state.loop3CountTypes) || + parser.parseComma() || + parseStructuredOptionalType(parser, state.loop3SrcStrideTypes) || + parser.parseComma() || + parseStructuredOptionalType(parser, state.loop3DstStrideTypes))) + return failure(); + return success(); +} + +template +static void setStructuredAccStoreSegmentSizes(OperationState &result, + ArrayRef segmentSizes) { + auto &segments = result.getOrAddProperties() + .operandSegmentSizes; + llvm::copy(segmentSizes, segments.begin()); +} + +template +static void addStructuredAccStoreAttrs(OperationState &result, + Builder &builder, + const StructuredAccStoreAsmState &state) { + if (state.mode) + result.addAttribute("mode", AccStoreModeAttr::get(builder.getContext(), + *state.mode)); + if (state.unitFlag) + result.addAttribute("unit_flag", + AccStoreUnitFlagCtrlAttr::get(builder.getContext(), + *state.unitFlag)); + if (state.preQuantMode) + result.addAttribute("pre_quant_mode", + AccStoreQuantPreModeAttr::get(builder.getContext(), + *state.preQuantMode)); + if (state.preReluMode) + result.addAttribute("pre_relu_mode", + ReluPreModeAttr::get(builder.getContext(), + *state.preReluMode)); + if (state.atomicType) + result.addAttribute("atomic_type", + AccStoreAtomicTypeAttr::get(builder.getContext(), + *state.atomicType)); + if (state.atomicOp) + result.addAttribute("atomic_op", + AccStoreAtomicOpAttr::get(builder.getContext(), + *state.atomicOp)); + if (state.satMode) + result.addAttribute("sat_mode", + AccStoreSatModeAttr::get(builder.getContext(), + *state.satMode)); +} + +[[maybe_unused]] static ParseResult resolveStructuredMteL0cL1OptionalOperands( + OpAsmParser &parser, StructuredAccStoreAsmState &state, + SmallVectorImpl &resolvedOperands, OperationState &result) { + auto location = parser.getCurrentLocation(); + if (parser.resolveOperands(state.preQuantOperands, state.preQuantTypes, + location, result.operands) || + parser.resolveOperands(state.preReluOperands, state.preReluTypes, + location, result.operands) || + parser.resolveOperands(state.clipValueOperands, state.clipValueTypes, + location, result.operands) || + parser.resolveOperands(state.splitOperands, state.splitTypes, location, + result.operands) || + parser.resolveOperands(state.loop0SrcStrideOperands, + state.loop0SrcStrideTypes, location, + result.operands) || + parser.resolveOperands(state.loop3CountOperands, state.loop3CountTypes, + location, result.operands) || + parser.resolveOperands(state.loop3SrcStrideOperands, + state.loop3SrcStrideTypes, location, + result.operands) || + parser.resolveOperands(state.loop3DstStrideOperands, + state.loop3DstStrideTypes, location, + result.operands)) + return failure(); + + auto extractResolved = [&](SmallVectorImpl &ops, + SmallVectorImpl &types) -> Value { + if (ops.empty()) + return {}; + return result.operands[resolvedOperands.size()]; + }; + resolvedOperands.push_back(extractResolved(state.preQuantOperands, + state.preQuantTypes)); + resolvedOperands.push_back(extractResolved(state.preReluOperands, + state.preReluTypes)); + resolvedOperands.push_back(extractResolved(state.clipValueOperands, + state.clipValueTypes)); + resolvedOperands.push_back(extractResolved(state.splitOperands, + state.splitTypes)); + resolvedOperands.push_back(extractResolved(state.loop0SrcStrideOperands, + state.loop0SrcStrideTypes)); + resolvedOperands.push_back(extractResolved(state.loop3CountOperands, + state.loop3CountTypes)); + resolvedOperands.push_back(extractResolved(state.loop3SrcStrideOperands, + state.loop3SrcStrideTypes)); + resolvedOperands.push_back(extractResolved(state.loop3DstStrideOperands, + state.loop3DstStrideTypes)); + return success(); +} + +template +static LogicalResult verifyCopyGmToUbufOp(CopyOp op, bool expectSourceGM) { + if (!isBufferLike(op.getSource().getType()) || + !isBufferLike(op.getDestination().getType())) + return op.emitOpError( + "requires typed !pto.ptr or memref source and destination"); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + bool directionMatches = true; + if (expectSourceGM) { + directionMatches &= sourceRole != MemoryRole::UB; + directionMatches &= destinationRole != MemoryRole::GM; + } else { + directionMatches &= sourceRole != MemoryRole::GM; + directionMatches &= destinationRole != MemoryRole::UB; + } + + if (!directionMatches) { + return op.emitOpError() + << "requires " + << (expectSourceGM ? "GM source and UB destination" + : "UB source and GM destination"); + } + + int64_t sourceElemBytes = getBufferElementByteSize(op.getSource().getType()); + int64_t destinationElemBytes = + getBufferElementByteSize(op.getDestination().getType()); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return op.emitOpError("requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return op.emitOpError("requires source and destination element byte widths to match"); + + return success(); +} + +template +static LogicalResult verifyOptionalDmaLoopGroup(DmaOp op, Value count, + Value srcStride, + Value dstStride, + StringRef name) { + if (hasAny(count, srcStride, dstStride) && !hasAll(count, srcStride, dstStride)) + return op.emitOpError() << "requires " << name + << " group to provide count, src stride, and dst stride together"; + return success(); +} + +static LogicalResult verifyDmaLoadStoreLoopGroups(Operation *op, + ValueRange loopCounts, + ValueRange loopSrcStrides, + ValueRange loopDstStrides) { + if (loopCounts.size() != loopSrcStrides.size() || + loopCounts.size() != loopDstStrides.size()) + return op->emitOpError() + << "requires each loop group to provide count, src stride, and dst stride together"; + return success(); +} + +template +static LogicalResult verifyCopyUbufToGmOp(CopyOp op, bool expectSourceGM) { + if (!isBufferLike(op.getSource().getType()) || + !isBufferLike(op.getDestination().getType())) + return op.emitOpError( + "requires typed !pto.ptr or memref source and destination"); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + bool directionMatches = true; + if (expectSourceGM) { + directionMatches &= sourceRole != MemoryRole::UB; + directionMatches &= destinationRole != MemoryRole::GM; + } else { + directionMatches &= sourceRole != MemoryRole::GM; + directionMatches &= destinationRole != MemoryRole::UB; + } + + if (!directionMatches) { + return op.emitOpError() + << "requires " + << (expectSourceGM ? "GM source and UB destination" + : "UB source and GM destination"); + } + + int64_t sourceElemBytes = getBufferElementByteSize(op.getSource().getType()); + int64_t destinationElemBytes = + getBufferElementByteSize(op.getDestination().getType()); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return op.emitOpError("requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return op.emitOpError("requires source and destination element byte widths to match"); + + return success(); +} + +template +static LogicalResult verifyCopyCbufToUbufLikeOp(CopyOp op) { + if (!isBufferLike(op.getSource().getType()) || + !isBufferLike(op.getDestination().getType())) + return op.emitOpError( + "requires typed !pto.ptr or memref source and destination"); + + if (classifyMemoryRole(op.getSource().getType()) != MemoryRole::Other || + classifyMemoryRole(op.getDestination().getType()) != MemoryRole::UB) + return op.emitOpError("requires CBUF source and UB destination"); + + int64_t sourceElemBytes = getBufferElementByteSize(op.getSource().getType()); + int64_t destinationElemBytes = + getBufferElementByteSize(op.getDestination().getType()); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return op.emitOpError("requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return op.emitOpError("requires source and destination element byte widths to match"); + + return success(); +} + +Type VRegType::parse(AsmParser &parser) { + SmallVector shape; + Type elementType; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true)) || + shape.size() != 1 || failed(parser.parseType(elementType)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), shape.front(), + elementType); +} + +void VRegType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x"; + printer.printType(getElementType()); + printer << ">"; +} + +LogicalResult VRegType::verify(function_ref emitError, + int64_t elementCount, Type elementType) { + if (elementCount <= 0) + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected a positive element count"; + + auto intOrFloat = mlir::dyn_cast(elementType); + unsigned elementBitWidth = 0; + if (intOrFloat) { + elementBitWidth = intOrFloat.getWidth(); + } else if (auto floatType = mlir::dyn_cast(elementType)) { + elementBitWidth = floatType.getWidth(); + } else { + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected an integer or floating-point element type"; + } + + if (elementCount * static_cast(elementBitWidth) != 2048) + return emitError() << "'" << formatVRegType(elementCount, elementType) + << "' expected exactly 256 bytes"; + + return success(); +} + +LogicalResult VecScopeOp::verify() { + Region &bodyRegion = getBody(); + if (bodyRegion.empty()) + return emitOpError("expects a non-empty body region"); + + Block &body = bodyRegion.front(); + if (body.getNumArguments() != 0) + return emitOpError() << "expects body block to have no arguments, got " + << body.getNumArguments(); + + return success(); +} + +LogicalResult StrictVecScopeOp::verify() { + Region &bodyRegion = getBody(); + if (bodyRegion.empty()) + return emitOpError("expects a non-empty body region"); + + Block &body = bodyRegion.front(); + if (body.getNumArguments() != getCaptures().size()) + return emitOpError() << "expects body block to have " + << getCaptures().size() + << " arguments to match explicit captures, got " + << body.getNumArguments(); + + for (auto [idx, pair] : + llvm::enumerate(llvm::zip(body.getArguments(), getCaptures()))) { + BlockArgument blockArg = std::get<0>(pair); + Value capture = std::get<1>(pair); + if (blockArg.getType() != capture.getType()) + return emitOpError() << "expects body block argument #" << idx + << " to have type " << capture.getType() + << ", got " << blockArg.getType(); + } + return success(); +} + +bool MaskType::isSupportedGranularity(StringRef granularity) { + return granularity == "b8" || granularity == "b16" || + granularity == "b32"; +} + +Type MaskType::parse(AsmParser &parser) { + auto loc = parser.getCurrentLocation(); + StringRef granularity; + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&granularity)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), granularity); +} + +void MaskType::print(AsmPrinter &printer) const { + printer << "<" << getGranularity() << ">"; +} + +LogicalResult +MaskType::verify(function_ref emitError, + StringRef granularity) { + if (!isSupportedGranularity(granularity)) + return emitError() << "'" << formatMaskType(granularity) + << "' expected granularity to be one of b8, b16, b32"; + return success(); +} + +void CopyGmToUbufOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyGmToUbufOp::verify() { + return verifyCopyGmToUbufOp(*this, true); +} + +void MteGmUbOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value l2CacheCtl, Value lenBurst, + pto::DmaLoopConfig nburst, + llvm::ArrayRef loops, + std::optional pad) { + state.addOperands({source, destination, l2CacheCtl, lenBurst, nburst.count, + nburst.srcStride, nburst.dstStride}); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.count); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.srcStride); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.dstStride); + bool hasPadCounts = pad && pad->leftCount && pad->rightCount; + assert((!pad || static_cast(pad->leftCount) == + static_cast(pad->rightCount)) && + "mte_gm_ub pad config must provide both left and right counts, or omit both"); + if (pad) { + state.addOperands(pad->value); + if (hasPadCounts) + state.addOperands({pad->leftCount, pad->rightCount}); + } + + state.addAttribute( + getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, 1, 1, 1, 1, + static_cast(loops.size()), + static_cast(loops.size()), + static_cast(loops.size()), + pad ? 1 : 0, hasPadCounts ? 1 : 0, hasPadCounts ? 1 : 0})); +} + +void MteGmUbOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value l2CacheCtl, Value lenBurst, + pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2, + std::optional pad) { + SmallVector loops; + if (loop1) + loops.push_back(*loop1); + if (loop2) + loops.push_back(*loop2); + build(builder, state, source, destination, l2CacheCtl, lenBurst, nburst, + loops, pad); +} + +ParseResult MteGmUbOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, l2CacheCtl, lenBurst; + SmallVector nburstOperands; + SmallVector loopCountOperands; + SmallVector loopSrcStrideOperands; + SmallVector loopDstStrideOperands; + SmallVector padOperands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parseRequiredOperandWithComma(parser, l2CacheCtl) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands)) + return failure(); + while (true) { + if (succeeded(parser.parseOptionalKeyword("pad"))) { + if (parser.parseLParen()) + return failure(); + OpAsmParser::UnresolvedOperand value; + if (parser.parseOperand(value)) + return failure(); + padOperands.push_back(value); + if (succeeded(parser.parseOptionalComma())) { + OpAsmParser::UnresolvedOperand left; + OpAsmParser::UnresolvedOperand right; + if (parser.parseOperand(left) || parser.parseComma() || + parser.parseOperand(right)) + return failure(); + padOperands.push_back(left); + padOperands.push_back(right); + } + if (parser.parseRParen()) + return failure(); + break; + } + + StringRef parsedKeyword; + SmallVector loopGroupOperands; + if (parseOptionalDmaTripleGroupAlias(parser, {"loop", "loop1", "loop2"}, + parsedKeyword, loopGroupOperands)) + return failure(); + if (parsedKeyword.empty()) + break; + loopCountOperands.push_back(loopGroupOperands[0]); + loopSrcStrideOperands.push_back(loopGroupOperands[1]); + loopDstStrideOperands.push_back(loopGroupOperands[2]); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, l2CacheCtlType, lenBurstType; + SmallVector nburstTypes, loopCountTypes, loopSrcStrideTypes, + loopDstStrideTypes, padTypes; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(l2CacheCtlType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + while (succeeded(parser.parseOptionalComma())) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (isDmaLoopKeyword(keyword)) { + SmallVector loopGroupTypes; + if (parseDmaTripleTypes(parser, loopGroupTypes)) + return failure(); + loopCountTypes.push_back(loopGroupTypes[0]); + loopSrcStrideTypes.push_back(loopGroupTypes[1]); + loopDstStrideTypes.push_back(loopGroupTypes[2]); + continue; + } + if (keyword == "pad") { + if (!padTypes.empty() || parseDmaPadTypes(parser, padTypes)) + return failure(); + continue; + } + return parser.emitError(parser.getCurrentLocation(), + "expected one of 'loop' or 'pad'"); + } + + int32_t loopGroupCount = static_cast(loopCountOperands.size()); + if (loopCountOperands.size() != loopSrcStrideOperands.size() || + loopCountOperands.size() != loopDstStrideOperands.size() || + loopCountTypes.size() != loopSrcStrideTypes.size() || + loopCountTypes.size() != loopDstStrideTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires each loop group to provide count, src stride, and dst stride"); + if (loopCountOperands.size() != loopCountTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires loop operand and type groups to match"); + + auto &segments = + result.getOrAddProperties().operandSegmentSizes; + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, 1, + loopGroupCount, loopGroupCount, loopGroupCount, + static_cast(padOperands.size() ? 1 : 0), + static_cast(padOperands.size() == 3 ? 1 : 0), + static_cast(padOperands.size() == 3 ? 1 : 0)}, + segments.begin()); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(l2CacheCtl, l2CacheCtlType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(loopCountOperands, loopCountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopSrcStrideOperands, loopSrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopDstStrideOperands, loopDstStrideTypes, + parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(padOperands, padTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + return success(); +} + +void MteGmUbOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " + << getL2CacheCtl() << ", " << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), + getNburstDstStride()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleGroup(printer, "loop", count, srcStride, dstStride); + if (getPadValue()) + printDmaPadGroup(printer, getPadValue(), getLeftPaddingCount(), + getRightPaddingCount()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getL2CacheCtl().getType() << ", " << getLenBurst().getType() + << ", " << getNBurst().getType() << ", " << getNburstSrcStride().getType() + << ", " + << getNburstDstStride().getType(); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleTypes(printer, "loop", count.getType(), srcStride.getType(), + dstStride.getType()); + if (getPadValue()) + printDmaPadTypes(printer, getPadValue().getType(), + getLeftPaddingCount() ? getLeftPaddingCount().getType() : Type{}, + getRightPaddingCount() ? getRightPaddingCount().getType() : Type{}); +} + +void MteGmUbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult MteGmUbOp::verify() { + if (failed(verifyCopyGmToUbufOp(*this, true))) + return failure(); + if (failed(verifyDmaLoadStoreLoopGroups( + getOperation(), getLoopCounts(), getLoopSrcStrides(), + getLoopDstStrides()))) + return failure(); + if (!getPadValue() && (getLeftPaddingCount() || getRightPaddingCount())) + return emitOpError() << "requires pad group to provide a pad value"; + if (getPadValue() && static_cast(getLeftPaddingCount()) != + static_cast(getRightPaddingCount())) + return emitOpError() + << "requires pad group to provide both left and right counts, or omit both"; + if (Value padValue = getPadValue()) { + Type valueType = padValue.getType(); + if (!isSupportedMovPadScalarType(valueType)) + return emitOpError() + << "expects pad value to be i8/i16/i32 or f16/bf16/f32 scalar, but got " + << valueType; + } + return success(); +} + +LogicalResult SetMovPadValOp::verify() { + Type valueType = getValue().getType(); + if (isSupportedMovPadScalarType(valueType)) + return success(); + return emitOpError() + << "expects i8/i16/i32 or f16/bf16/f32 scalar operand, but got " + << valueType; +} +void MadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +static LogicalResult verifyMadPointerKinds(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy, + std::optional biasTy = + std::nullopt) { + auto lhsType = dyn_cast(lhsTy); + auto rhsType = dyn_cast(rhsTy); + auto dstType = dyn_cast(dstTy); + if (!lhsType || !rhsType || !dstType) + return op->emitOpError("requires typed !pto.ptr lhs/rhs/dst operands"); + + const auto lhsAS = lhsType.getMemorySpace().getAddressSpace(); + const auto rhsAS = rhsType.getMemorySpace().getAddressSpace(); + const auto dstAS = dstType.getMemorySpace().getAddressSpace(); + + const bool isStrongCube = + lhsAS == pto::AddressSpace::LEFT && rhsAS == pto::AddressSpace::RIGHT && + dstAS == pto::AddressSpace::ACC; + if (!isStrongCube) + return op->emitOpError("requires l0a/l0b/l0c-typed lhs/rhs/dst pointers"); + + if (!biasTy) + return success(); + + auto biasType = dyn_cast(*biasTy); + if (!biasType) + return op->emitOpError("requires typed !pto.ptr bias operand"); + if (biasType.getMemorySpace().getAddressSpace() != pto::AddressSpace::BIAS) { + return op->emitOpError("requires bias pointer in !pto.ptr<..., bt>"); + } + if (biasType.getElementType() != dstType.getElementType()) { + return op->emitOpError("requires bias element type to match dst element type"); + } + return success(); +} + +void MadAccOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getDstMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + + +void MadBiasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBiasMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +static LogicalResult verifyMadMxCommon(Operation *op, Type lhsTy, Type rhsTy, + Type dstTy, + std::optional biasTy = + std::nullopt) { + if (failed(verifyMadPointerKinds(op, lhsTy, rhsTy, dstTy, biasTy))) + return failure(); + + auto lhsType = cast(lhsTy); + auto rhsType = cast(rhsTy); + auto dstType = cast(dstTy); + const auto lhsAS = lhsType.getMemorySpace().getAddressSpace(); + const auto rhsAS = rhsType.getMemorySpace().getAddressSpace(); + const auto dstAS = dstType.getMemorySpace().getAddressSpace(); + const bool isStrongCube = + lhsAS == pto::AddressSpace::LEFT && rhsAS == pto::AddressSpace::RIGHT && + dstAS == pto::AddressSpace::ACC; + if (!isStrongCube) + return op->emitOpError("requires l0a/l0b/l0c-typed lhs/rhs/dst pointers"); + + if (!isMxElementType(lhsType.getElementType()) || + !isMxElementType(rhsType.getElementType())) { + return op->emitOpError( + "requires MX lhs/rhs element types (currently f8E4M3FN)"); + } + return success(); +} + +void MadMxOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + + +void MadMxAccOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getDstMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + + +void MadMxBiasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBiasMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +static std::optional +parseMadUnitFlagModeToken(StringRef token) { + if (token == "check_only") + return pto::MadUnitFlagMode::CheckOnly; + if (token == "check_and_set") + return pto::MadUnitFlagMode::CheckAndSet; + return std::nullopt; +} + +static StringRef stringifyMadUnitFlagModeToken(pto::MadUnitFlagMode mode) { + switch (mode) { + case pto::MadUnitFlagMode::CheckOnly: + return "check_only"; + case pto::MadUnitFlagMode::CheckAndSet: + return "check_and_set"; + } + llvm_unreachable("unexpected mad unit flag mode"); +} + +static std::optional parseTf32ModeToken(StringRef token) { + if (token == "round_even") + return pto::Tf32Mode::RoundEven; + if (token == "round_away") + return pto::Tf32Mode::RoundAway; + return std::nullopt; +} + +static StringRef stringifyTf32ModeToken(pto::Tf32Mode mode) { + switch (mode) { + case pto::Tf32Mode::RoundEven: + return "round_even"; + case pto::Tf32Mode::RoundAway: + return "round_away"; + } + llvm_unreachable("unexpected tf32 mode"); +} + +static StringRef stringifyMadSatModeToken(pto::MadSatMode mode) { + switch (mode) { + case pto::MadSatMode::Sat: + return "sat"; + case pto::MadSatMode::NoSat: + return "nosat"; + } + llvm_unreachable("unexpected mad sat mode"); +} + +static LogicalResult verifyMadSemanticClauses(Operation *op, Type lhsTy, + Type rhsTy, Type dstTy, + std::optional biasTy, + std::optional tf32Mode, + std::optional satMode, + bool hasNDir) { + if (failed(verifyMadPointerKinds(op, lhsTy, rhsTy, dstTy, biasTy))) + return failure(); + + auto lhsType = dyn_cast(lhsTy); + auto rhsType = dyn_cast(rhsTy); + auto dstType = dyn_cast(dstTy); + if (!lhsType || !rhsType || !dstType) + return op->emitOpError("requires typed !pto.ptr lhs/rhs/dst operands"); + + if (tf32Mode) { + if (!(lhsType.getElementType().isF32() && rhsType.getElementType().isF32() && + dstType.getElementType().isF32())) { + return op->emitOpError( + "requires tf32_mode only for f32 lhs/rhs/dst element types"); + } + } + if (pto::isPTOHiFloat8Type(lhsType.getElementType()) != + pto::isPTOHiFloat8Type(rhsType.getElementType())) { + return op->emitOpError( + "requires lhs/rhs to both use hif8 or both use non-hif8 element types"); + } + if (satMode) { + auto isFloatLike = [](Type type) { + if (isa(type)) + return true; + return pto::isPTOLowPrecisionType(type); + }; + if (!(isFloatLike(lhsType.getElementType()) && + isFloatLike(rhsType.getElementType()) && + isFloatLike(dstType.getElementType()))) { + return op->emitOpError( + "requires sat/nosat only for floating lhs/rhs/dst element types"); + } + } + (void)hasNDir; + return success(); +} + +template +static ParseResult parseMadSemanticOpCommon(OpAsmParser &parser, + OperationState &result, + bool hasBias, + bool parseTf32ModeClause) { + OpAsmParser::UnresolvedOperand lhs, rhs, dst, bias; + OpAsmParser::UnresolvedOperand m, n, k; + StringRef unitFlagKeyword; + StringRef tf32Keyword; + NamedAttrList attrs; + + if (parseRequiredOperandWithComma(parser, lhs) || + parseRequiredOperandWithComma(parser, rhs) || + parseRequiredOperandWithComma(parser, dst) || + (hasBias && parseRequiredOperandWithComma(parser, bias)) || + parseRequiredOperandWithComma(parser, m) || + parseRequiredOperandWithComma(parser, n) || + parser.parseOperand(k)) + return failure(); + + auto parseUnitFlagClause = [&]() -> ParseResult { + if (failed(parser.parseOptionalKeyword("unit_flag"))) + return success(); + if (parser.parseLParen() || parser.parseKeyword(&unitFlagKeyword) || + parser.parseRParen()) + return failure(); + auto mode = parseMadUnitFlagModeToken(unitFlagKeyword); + if (!mode) + return parser.emitError(parser.getCurrentLocation()) + << "expected unit_flag(check_only|check_and_set)"; + attrs.set("unit_flag_mode", + pto::MadUnitFlagModeAttr::get(parser.getContext(), *mode)); + return success(); + }; + auto parseDisableGemvClause = [&]() -> ParseResult { + if (succeeded(parser.parseOptionalKeyword("disable_gemv"))) { + attrs.set("disable_gemv", UnitAttr::get(parser.getContext())); + } + return success(); + }; + auto parseSatClause = [&]() -> ParseResult { + if (succeeded(parser.parseOptionalKeyword("sat"))) { + attrs.set("sat_mode", + pto::MadSatModeAttr::get(parser.getContext(), + pto::MadSatMode::Sat)); + return success(); + } + if (succeeded(parser.parseOptionalKeyword("nosat"))) { + attrs.set("sat_mode", + pto::MadSatModeAttr::get(parser.getContext(), + pto::MadSatMode::NoSat)); + } + return success(); + }; + auto parseTf32Clause = [&]() -> ParseResult { + if (!parseTf32ModeClause) + return success(); + if (failed(parser.parseOptionalKeyword("tf32_mode"))) + return success(); + if (parser.parseLParen() || parser.parseKeyword(&tf32Keyword) || + parser.parseRParen()) + return failure(); + auto mode = parseTf32ModeToken(tf32Keyword); + if (!mode) + return parser.emitError(parser.getCurrentLocation()) + << "expected tf32_mode(round_even|round_away)"; + attrs.set("tf32_mode", pto::Tf32ModeAttr::get(parser.getContext(), *mode)); + return success(); + }; + auto parseNDirClause = [&]() -> ParseResult { + if (succeeded(parser.parseOptionalKeyword("n_dir"))) { + attrs.set("n_dir", UnitAttr::get(parser.getContext())); + } + return success(); + }; + + if (failed(parseUnitFlagClause()) || failed(parseDisableGemvClause()) || + failed(parseSatClause()) || failed(parseTf32Clause()) || + failed(parseNDirClause())) + return failure(); + + if (parser.parseOptionalAttrDict(attrs) || parser.parseColon()) + return failure(); + + Type lhsType, rhsType, dstType, mType, nType, kType, biasType; + if (parser.parseType(lhsType) || parser.parseComma() || + parser.parseType(rhsType) || parser.parseComma() || + parser.parseType(dstType) || parser.parseComma()) + return failure(); + if (hasBias) { + if (parser.parseType(biasType) || parser.parseComma()) + return failure(); + } + if (parser.parseType(mType) || parser.parseComma() || parser.parseType(nType) || + parser.parseComma() || parser.parseType(kType)) + return failure(); + + result.addAttributes(attrs); + if (hasBias) { + if (parser.resolveOperand(lhs, lhsType, result.operands) || + parser.resolveOperand(rhs, rhsType, result.operands) || + parser.resolveOperand(dst, dstType, result.operands) || + parser.resolveOperand(bias, biasType, result.operands) || + parser.resolveOperand(m, mType, result.operands) || + parser.resolveOperand(n, nType, result.operands) || + parser.resolveOperand(k, kType, result.operands)) + return failure(); + } else { + if (parser.resolveOperand(lhs, lhsType, result.operands) || + parser.resolveOperand(rhs, rhsType, result.operands) || + parser.resolveOperand(dst, dstType, result.operands) || + parser.resolveOperand(m, mType, result.operands) || + parser.resolveOperand(n, nType, result.operands) || + parser.resolveOperand(k, kType, result.operands)) + return failure(); + } + return success(); +} + +static void printMadSemanticClauses(OpAsmPrinter &printer, Operation *op, + bool allowTf32Mode) { + if (auto unitFlagMode = op->getAttrOfType( + "unit_flag_mode")) { + printer << " unit_flag(" + << stringifyMadUnitFlagModeToken(unitFlagMode.getValue()) << ")"; + } + if (op->hasAttr("disable_gemv")) + printer << " disable_gemv"; + if (auto satMode = op->getAttrOfType("sat_mode")) + printer << ' ' << stringifyMadSatModeToken(satMode.getValue()); + if (allowTf32Mode) { + if (auto tf32Mode = op->getAttrOfType("tf32_mode")) { + printer << " tf32_mode(" << stringifyTf32ModeToken(tf32Mode.getValue()) + << ")"; + } + } + if (op->hasAttr("n_dir")) + printer << " n_dir"; +} + +static ArrayRef getMadSemanticElidedAttrs(bool allowTf32Mode) { + static constexpr StringRef kWithTf32[] = {"unit_flag_mode", "disable_gemv", + "sat_mode", "tf32_mode", "n_dir"}; + static constexpr StringRef kWithoutTf32[] = {"unit_flag_mode", + "disable_gemv", "sat_mode", + "n_dir"}; + return allowTf32Mode ? ArrayRef(kWithTf32) + : ArrayRef(kWithoutTf32); +} + +template +static void printMadSemanticOpNoBias(OpAsmPrinter &printer, OpT op, + bool allowTf32Mode) { + printer << ' ' << op.getLhs() << ", " << op.getRhs() << ", " << op.getDst() + << ", " << op.getM() << ", " << op.getN() << ", " << op.getK(); + printMadSemanticClauses(printer, op, allowTf32Mode); + printer.printOptionalAttrDict(op->getAttrs(), + getMadSemanticElidedAttrs(allowTf32Mode)); + printer << " : " << op.getLhs().getType() << ", " << op.getRhs().getType() + << ", " << op.getDst().getType() << ", " << op.getM().getType() + << ", " << op.getN().getType() << ", " << op.getK().getType(); +} + +template +static void printMadSemanticOpWithBias(OpAsmPrinter &printer, OpT op, + bool allowTf32Mode) { + printer << ' ' << op.getLhs() << ", " << op.getRhs() << ", " << op.getDst() + << ", " << op.getBias() << ", " << op.getM() << ", " << op.getN() + << ", " << op.getK(); + printMadSemanticClauses(printer, op, allowTf32Mode); + printer.printOptionalAttrDict(op->getAttrs(), + getMadSemanticElidedAttrs(allowTf32Mode)); + printer << " : " << op.getLhs().getType() << ", " << op.getRhs().getType() + << ", " << op.getDst().getType() << ", " << op.getBias().getType() + << ", " << op.getM().getType() << ", " << op.getN().getType() + << ", " << op.getK().getType(); +} + +LogicalResult MadOp::verify() { + std::optional tf32Mode; + if (auto tf32ModeAttr = + (*this)->getAttrOfType("tf32_mode")) + tf32Mode = tf32ModeAttr.getValue(); + return verifyMadSemanticClauses(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), std::nullopt, tf32Mode, + getSatMode(), + (*this)->hasAttr("n_dir")); +} + +ParseResult MadOp::parse(OpAsmParser &parser, OperationState &result) { + return parseMadSemanticOpCommon(parser, result, /*hasBias=*/false, + /*parseTf32ModeClause=*/true); +} + +void MadOp::print(OpAsmPrinter &printer) { + printMadSemanticOpNoBias(printer, *this, /*allowTf32Mode=*/true); +} + +bool MadOp::isMadMxFamily() { return false; } +bool MadOp::hasBiasOperand() { return false; } +bool MadOp::readsAccumulator() { return false; } +bool MadOp::supportsTf32Mode() { return true; } +Value MadOp::getBiasOrNull() { return {}; } + +LogicalResult MadAccOp::verify() { + std::optional tf32Mode; + if (auto tf32ModeAttr = + (*this)->getAttrOfType("tf32_mode")) + tf32Mode = tf32ModeAttr.getValue(); + return verifyMadSemanticClauses(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), std::nullopt, tf32Mode, + getSatMode(), + (*this)->hasAttr("n_dir")); +} + +ParseResult MadAccOp::parse(OpAsmParser &parser, OperationState &result) { + return parseMadSemanticOpCommon(parser, result, /*hasBias=*/false, + /*parseTf32ModeClause=*/true); +} + +void MadAccOp::print(OpAsmPrinter &printer) { + printMadSemanticOpNoBias(printer, *this, /*allowTf32Mode=*/true); +} + +bool MadAccOp::isMadMxFamily() { return false; } +bool MadAccOp::hasBiasOperand() { return false; } +bool MadAccOp::readsAccumulator() { return true; } +bool MadAccOp::supportsTf32Mode() { return true; } +Value MadAccOp::getBiasOrNull() { return {}; } + +LogicalResult MadBiasOp::verify() { + std::optional tf32Mode; + if (auto tf32ModeAttr = + (*this)->getAttrOfType("tf32_mode")) + tf32Mode = tf32ModeAttr.getValue(); + return verifyMadSemanticClauses(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), getBias().getType(), + tf32Mode, getSatMode(), + (*this)->hasAttr("n_dir")); +} + +ParseResult MadBiasOp::parse(OpAsmParser &parser, OperationState &result) { + return parseMadSemanticOpCommon(parser, result, /*hasBias=*/true, + /*parseTf32ModeClause=*/true); +} + +void MadBiasOp::print(OpAsmPrinter &printer) { + printMadSemanticOpWithBias(printer, *this, /*allowTf32Mode=*/true); +} + +bool MadBiasOp::isMadMxFamily() { return false; } +bool MadBiasOp::hasBiasOperand() { return true; } +bool MadBiasOp::readsAccumulator() { return false; } +bool MadBiasOp::supportsTf32Mode() { return true; } +Value MadBiasOp::getBiasOrNull() { return getBias(); } + +LogicalResult MadMxOp::verify() { + if (failed(verifyMadMxCommon(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()))) + return failure(); + return verifyMadSemanticClauses(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), std::nullopt, std::nullopt, + getSatMode(), + (*this)->hasAttr("n_dir")); +} + +ParseResult MadMxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseMadSemanticOpCommon(parser, result, /*hasBias=*/false, + /*parseTf32ModeClause=*/false); +} + +void MadMxOp::print(OpAsmPrinter &printer) { + printMadSemanticOpNoBias(printer, *this, /*allowTf32Mode=*/false); +} + +bool MadMxOp::isMadMxFamily() { return true; } +bool MadMxOp::hasBiasOperand() { return false; } +bool MadMxOp::readsAccumulator() { return false; } +bool MadMxOp::supportsTf32Mode() { return false; } +Value MadMxOp::getBiasOrNull() { return {}; } +Attribute MadMxOp::getTf32ModeAttr() { return {}; } + +LogicalResult MadMxAccOp::verify() { + if (failed(verifyMadMxCommon(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()))) + return failure(); + return verifyMadSemanticClauses(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), std::nullopt, std::nullopt, + getSatMode(), + (*this)->hasAttr("n_dir")); +} + +ParseResult MadMxAccOp::parse(OpAsmParser &parser, OperationState &result) { + return parseMadSemanticOpCommon(parser, result, /*hasBias=*/false, + /*parseTf32ModeClause=*/false); +} + +void MadMxAccOp::print(OpAsmPrinter &printer) { + printMadSemanticOpNoBias(printer, *this, /*allowTf32Mode=*/false); +} + +bool MadMxAccOp::isMadMxFamily() { return true; } +bool MadMxAccOp::hasBiasOperand() { return false; } +bool MadMxAccOp::readsAccumulator() { return true; } +bool MadMxAccOp::supportsTf32Mode() { return false; } +Value MadMxAccOp::getBiasOrNull() { return {}; } +Attribute MadMxAccOp::getTf32ModeAttr() { return {}; } + +LogicalResult MadMxBiasOp::verify() { + if (failed(verifyMadMxCommon(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), getBias().getType()))) + return failure(); + return verifyMadSemanticClauses(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), getBias().getType(), + std::nullopt, getSatMode(), + (*this)->hasAttr("n_dir")); +} + +ParseResult MadMxBiasOp::parse(OpAsmParser &parser, OperationState &result) { + return parseMadSemanticOpCommon(parser, result, /*hasBias=*/true, + /*parseTf32ModeClause=*/false); +} + +void MadMxBiasOp::print(OpAsmPrinter &printer) { + printMadSemanticOpWithBias(printer, *this, /*allowTf32Mode=*/false); +} + +bool MadMxBiasOp::isMadMxFamily() { return true; } +bool MadMxBiasOp::hasBiasOperand() { return true; } +bool MadMxBiasOp::readsAccumulator() { return false; } +bool MadMxBiasOp::supportsTf32Mode() { return false; } +Value MadMxBiasOp::getBiasOrNull() { return getBias(); } +Attribute MadMxBiasOp::getTf32ModeAttr() { return {}; } + +void MadRawOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +LogicalResult MadRawOp::verify() { + return verifyMadPointerKinds(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()); +} + +bool MadRawOp::isMadMxFamily() { return false; } +bool MadRawOp::hasBiasOperand() { return false; } +Value MadRawOp::getBiasOrNull() { return {}; } + +void MadBiasRawOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBiasMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +LogicalResult MadBiasRawOp::verify() { + return verifyMadPointerKinds(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), getBias().getType()); +} + +bool MadBiasRawOp::isMadMxFamily() { return false; } +bool MadBiasRawOp::hasBiasOperand() { return true; } +Value MadBiasRawOp::getBiasOrNull() { return getBias(); } + +void MadMxRawOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +LogicalResult MadMxRawOp::verify() { + return verifyMadMxCommon(*this, getLhs().getType(), getRhs().getType(), + getDst().getType()); +} + +bool MadMxRawOp::isMadMxFamily() { return true; } +bool MadMxRawOp::hasBiasOperand() { return false; } +Value MadMxRawOp::getBiasOrNull() { return {}; } + +void MadMxBiasRawOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBiasMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable()); +} + +LogicalResult MadMxBiasRawOp::verify() { + return verifyMadMxCommon(*this, getLhs().getType(), getRhs().getType(), + getDst().getType(), getBias().getType()); +} + +bool MadMxBiasRawOp::isMadMxFamily() { return true; } +bool MadMxBiasRawOp::hasBiasOperand() { return true; } +Value MadMxBiasRawOp::getBiasOrNull() { return getBias(); } + +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + +LogicalResult VbrOp::verify() { + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + + auto resultVecType = cast(getResult().getType()); + Type elementType = getValue().getType(); + if (isa(elementType)) + return emitOpError("value must be a scalar matching the result element type"); + Type resultElementType = resultVecType.getElementType(); + if (!isCompatibleScalarForSemanticType(resultElementType, elementType)) + return emitOpError("value type must match result element type"); + return success(); +} + +template +static LogicalResult verifyWideningReductionVecOp(ReductionOp op, + StringRef opName) { + if (failed(verifyVRegTypeLike(op, op.getInput().getType(), "input")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result"))) + return failure(); + + auto inputType = dyn_cast(op.getInput().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!inputType || !resultType) + return failure(); + + Type inputElemType = inputType.getElementType(); + Type expectedResultElemType = inputElemType; + int64_t expectedResultLanes = inputType.getElementCount(); + if (auto inputInt = dyn_cast(inputElemType)) { + if (inputInt.getWidth() < 8 || inputInt.getWidth() > 32) + return op.emitOpError( + "requires 8-bit, 16-bit, or 32-bit integer vector element type"); + if (inputInt.getWidth() == 8) { + expectedResultElemType = + IntegerType::get(op.getContext(), 16, inputInt.getSignedness()); + expectedResultLanes = inputType.getElementCount() / 2; + } + if (inputInt.getWidth() == 16) { + expectedResultElemType = + IntegerType::get(op.getContext(), 32, inputInt.getSignedness()); + expectedResultLanes = inputType.getElementCount() / 2; + } + } else if (!inputElemType.isF16() && !inputElemType.isF32()) { + return op.emitOpError("requires i16/i32/f16/f32 vector element type"); + } + + if (resultType.getElementCount() == expectedResultLanes && + resultType.getElementType() == expectedResultElemType) + return success(); + + return op.emitOpError() << opName << " expects result type !pto.vreg<" + << expectedResultLanes << "x" + << expectedResultElemType + << " for input element type " << inputElemType; +} + +LogicalResult VcaddOp::verify() { + return verifyWideningReductionVecOp(*this, "vcadd"); +} + +LogicalResult VcmaxOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VcminOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("input and result must have the same vector type"); + return success(); +} + +LogicalResult VciOp::verify() { + auto resultType = dyn_cast(getResult().getType()); + if (!resultType) + return emitOpError("result must be !pto.vreg<...>"); + Type resultElemType = resultType.getElementType(); + bool supportedInteger = false; + if (auto intType = dyn_cast(resultElemType)) + supportedInteger = intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + bool supportedFloat = resultElemType.isF16() || resultElemType.isF32(); + if (!supportedInteger && !supportedFloat) + return emitOpError("result element type must be integer or f16/f32"); + if (!isCompatibleScalarForSemanticType(resultElemType, getIndex().getType())) + return emitOpError("index type must match result element type"); + return success(); +} + +void Vgather2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vgather2Op::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + if (!isa(offsetsType.getElementType())) + return emitOpError("offset vector must use integer element type"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + return success(); +} + +LogicalResult CopyUbufToUbufOp::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError("requires pointer-like source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed source and destination"); + return success(); +} + +void CopyCbufToUbufOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyCbufToUbufOp::verify() { + return verifyCopyCbufToUbufLikeOp(*this); +} + +void CopyUbufToCbufOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyUbufToCbufOp::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError("requires pointer-like source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::Other) + return emitOpError("requires UB-backed source and CBUF-backed destination"); + return success(); +} + +void MteUbUbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult MteUbUbOp::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError("requires pointer-like source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed source and destination"); + return success(); +} + +void MteUbL1Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult MteUbL1Op::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError("requires pointer-like source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::Other) + return emitOpError("requires UB-backed source and CBUF-backed destination"); + return success(); +} + +void VgatherbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VgatherbOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + if (failed(verifyMaskTypeWithGranularityLike(getOperation(), getMask().getType(), + "mask type", "b32"))) + return failure(); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + return success(); +} + +void Vgather2BcOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vgather2BcOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + + auto offsetsType = dyn_cast(getOffsets().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!offsetsType || !resultType) + return emitOpError("offsets and result must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != resultType.getElementCount()) + return emitOpError("offset and result vectors must have the same element count"); + return success(); +} + +LogicalResult VbitsortOp::verify() { + if (!isBufferLike(getDestination().getType()) || !isBufferLike(getSource().getType()) || + !isBufferLike(getIndices().getType())) + return emitOpError("requires pointer-like destination/source/indices"); + if (classifyMemoryRole(getDestination().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getIndices().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed destination/source/indices"); + if (!getRepeatTimes().getType().isIndex()) + return emitOpError("repeat_times must be index"); + if (failed(verifyNotNestedInVecScope(*this, "pto.vbitsort"))) + return failure(); + return success(); +} + +void VbitsortOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getIndicesMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult Vmrgsort4Op::verify() { + if (!isBufferLike(getDestination().getType()) || !isBufferLike(getSource0().getType()) || + !isBufferLike(getSource1().getType()) || !isBufferLike(getSource2().getType()) || + !isBufferLike(getSource3().getType())) + return emitOpError("requires pointer-like destination and sources"); + if (classifyMemoryRole(getDestination().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource0().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource1().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource2().getType()) != MemoryRole::UB || + classifyMemoryRole(getSource3().getType()) != MemoryRole::UB) + return emitOpError("requires UB-backed destination and sources"); + auto dstPtrType = dyn_cast(getDestination().getType()); + auto src0PtrType = dyn_cast(getSource0().getType()); + auto src1PtrType = dyn_cast(getSource1().getType()); + auto src2PtrType = dyn_cast(getSource2().getType()); + auto src3PtrType = dyn_cast(getSource3().getType()); + if (!dstPtrType || !src0PtrType || !src1PtrType || !src2PtrType || + !src3PtrType) + return emitOpError("requires ptr-backed destination and sources"); + + Type elemType = dstPtrType.getElementType(); + if (src0PtrType.getElementType() != elemType || + src1PtrType.getElementType() != elemType || + src2PtrType.getElementType() != elemType || + src3PtrType.getElementType() != elemType) + return emitOpError( + "requires destination and all sources to have the same element type"); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 element type"); + if (failed(verifyNotNestedInVecScope(*this, "pto.vmrgsort4"))) + return failure(); + return success(); +} + +LogicalResult VmaxOp::verify() { + if (failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getLhs().getType() != getRhs().getType() || + getLhs().getType() != getResult().getType()) + return emitOpError("lhs, rhs, and result must have the same vector type"); + return success(); +} + +LogicalResult VminOp::verify() { + if (failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result"))) + return failure(); + if (getLhs().getType() != getRhs().getType() || + getLhs().getType() != getResult().getType()) + return emitOpError("lhs, rhs, and result must have the same vector type"); + return success(); +} + +void VldsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +template +static LogicalResult verifyVldsCommon(LoadOp op) { + if (!isBufferLike(op.getSource().getType())) + return op.emitOpError("requires a pointer-like source"); + + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + + MemoryRole sourceRole = classifyMemoryRole(op.getSource().getType()); + if (sourceRole == MemoryRole::GM) + return op.emitOpError("requires a UB-backed source"); + + if (op.getDistAttr()) { + StringRef dist = *op.getDist(); + if (!isSupportedVldsDistToken(dist)) + return op.emitOpError( + "supports only NORM, BRC_B8/B16/B32, US_B8/B16, DS_B8/B16, " + "UNPK_B8/B16/B32, BRC_BLK, E2B_B16/B32, UNPK4, SPLT4CHN, and " + "SPLT2CHN_B8/B16 load distributions"); + } + + return success(); +} + +LogicalResult VldsOp::verify() { + if (failed(verifyVldsCommon(*this))) + return failure(); + if (std::optional mode = getOptionalPostModeAttr(getOperation()); + mode && !isSupportedPostMode(*mode)) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} +void VldsPostOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldsPostOp::verify() { + if (failed(verifyVldsCommon(*this))) + return failure(); + if (getUpdatedSource().getType() != getSource().getType()) + return emitOpError("requires updated source result to match source type"); + return success(); +} + +void VldasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldasOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyAlignTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + return success(); +} + +LogicalResult InitAlignOp::verify() { + return verifyAlignTypeLike(*this, getResult().getType(), "result type"); +} + +LogicalResult SprclrOp::verify() { + if (!isSupportedSprToken(getSpr())) + return emitOpError("requires spr to be \"AR\""); + if (failed(verifyNestedInVecScope(*this, "pto.sprclr"))) + return failure(); + return success(); +} + +void VldusOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VldusOp::verify() { + if (failed(verifyLoadAlignChain(getAlign(), *this, "align type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type")) || + failed(verifyAlignTypeLike(*this, getUpdatedAlign().getType(), + "updated align type"))) + return failure(); + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + return success(); +} + +void UvldOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult UvldOp::verify() { + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a buffer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + + auto sourceMemRef = dyn_cast(getSource().getType()); + if (!sourceMemRef) + return success(); + + Type sourceElementType = sourceMemRef.getElementType(); + Type vectorElementType = cast(getResult().getType()).getElementType(); + if (sourceElementType != vectorElementType) + return emitOpError( + "requires source element type to match vector element type"); + return success(); +} + +LogicalResult VdupOp::verify() { + auto resultType = dyn_cast(getResult().getType()); + if (!resultType) + return emitOpError("result must be !pto.vreg<...>"); + + std::optional granularity = + getVdupMaskGranularity(resultType.getElementType()); + if (!granularity) + return emitOpError("result element type must use b8, b16, or b32 mask granularity"); + if (failed(verifyMaskTypeWithGranularityLike( + getOperation(), getMask().getType(), "mask type", *granularity))) + return failure(); + + if (!isSupportedVdupPosition(getPosition())) + return emitOpError("position must be LOWEST or HIGHEST"); + + Type inputType = getInput().getType(); + if (auto inputVecType = dyn_cast(inputType)) { + if (inputVecType != resultType) + return emitOpError("vector input must match result vector type"); + return success(); + } + + if (getPosition()) + return emitOpError("position is only supported for vector input"); + + Type resultElementType = resultType.getElementType(); + if (!isCompatibleScalarForSemanticType(resultElementType, inputType)) + return emitOpError("scalar input must match result element type"); + + return success(); +} + +LogicalResult TensorViewAddrOp::verify() { + Type srcType = getSrc().getType(); + Type dstType = getDst().getType(); + + Type elementType; + int64_t expectedRank = -1; + auto gmSpace = pto::AddressSpaceAttr::get(getContext(), pto::AddressSpace::GM); + + if (auto tvType = dyn_cast(srcType)) { + elementType = tvType.getElementType(); + expectedRank = tvType.getRank(); + } else if (auto partType = dyn_cast(srcType)) { + elementType = partType.getElementType(); + expectedRank = partType.getRank(); + } else if (auto memrefType = dyn_cast(srcType)) { + elementType = memrefType.getElementType(); + expectedRank = memrefType.getRank(); + auto srcSpace = dyn_cast_or_null(memrefType.getMemorySpace()); + if (srcSpace && srcSpace != gmSpace) + return emitOpError("memref source must stay in gm memory space"); + } else { + return emitOpError( + "source must be a tensor_view, partition_tensor_view, or memref"); + } + + if (auto dstMemRefType = dyn_cast(dstType)) { + if (dstMemRefType.getElementType() != elementType) + return emitOpError( + "memref result element type must match source element type"); + if (dstMemRefType.getRank() != expectedRank) + return emitOpError("memref result rank must match source rank"); + auto dstSpace = + dyn_cast_or_null(dstMemRefType.getMemorySpace()); + if (dstSpace && dstSpace != gmSpace) + return emitOpError("memref result must stay in gm memory space"); + return success(); + } + + auto dstPtrType = dyn_cast(dstType); + if (!dstPtrType) + return emitOpError("result must be a memref or !pto.ptr<...>"); + if (dstPtrType.getElementType() != elementType) + return emitOpError( + "pointer result element type must match source element type"); + if (dstPtrType.getMemorySpace() != gmSpace) + return emitOpError("pointer result must stay in gm memory space"); + return success(); +} + +LogicalResult TileBufAddrOp::verify() { + Type dstType = getDst().getType(); + Type elementType; + Attribute srcMemorySpace; + int64_t srcRank = 0; + + if (auto srcTileType = dyn_cast(getSrc().getType())) { + elementType = srcTileType.getElementType(); + srcMemorySpace = srcTileType.getMemorySpace(); + srcRank = static_cast(srcTileType.getShape().size()); + } else if (auto srcMemRefType = dyn_cast(getSrc().getType())) { + // PTOViewToMemref may lower tile_buf producers to memref + pto.bind_tile + // before the shared materialization bridge restores tile handles. + // Hand-written pto.tile_buf_addr may therefore temporarily see a tile-bound + // memref operand in that intermediate form. + elementType = srcMemRefType.getElementType(); + srcMemorySpace = srcMemRefType.getMemorySpace(); + srcRank = srcMemRefType.getRank(); + } else { + return emitOpError("source must be a !pto.tile_buf<...> or memref"); + } + + auto srcSpace = dyn_cast_or_null(srcMemorySpace); + + if (auto dstMemRefType = dyn_cast(dstType)) { + if (dstMemRefType.getElementType() != elementType) + return emitOpError( + "memref result element type must match tile element type"); + if (dstMemRefType.getRank() != srcRank) + return emitOpError("memref result rank must match tile rank"); + auto dstSpace = + dyn_cast_or_null(dstMemRefType.getMemorySpace()); + if (srcSpace && dstSpace && srcSpace != dstSpace) + return emitOpError("memref result must stay within the tile memory space"); + return success(); + } + + auto dstPtrType = dyn_cast(dstType); + if (!dstPtrType) + return emitOpError("result must be a memref or !pto.ptr<...>"); + if (dstPtrType.getElementType() != elementType) + return emitOpError( + "pointer result element type must match tile element type"); + if (srcSpace && dstPtrType.getMemorySpace() != srcSpace) + return emitOpError("pointer result must stay within the tile memory space"); + return success(); +} + +LogicalResult PsetB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b8"))) + return failure(); + + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PsetB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b16"))) + return failure(); + + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PsetB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b32"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b8"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b16"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +LogicalResult PgeB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getResult().getType(), + "result type", "b32"))) + return failure(); + if (!isSupportedPredicatePattern(getPattern())) + return emitOpError("requires a supported PAT_* predicate pattern"); + return success(); +} + +template +static LogicalResult verifyPredicateLaneCountOp(PltOp op, + StringRef granularity) { + if (failed(verifyMaskTypeWithGranularityLike(op, op.getMask().getType(), + "mask type", granularity))) + return failure(); + Type scalarType = op.getScalar().getType(); + auto scalarIntType = dyn_cast(scalarType); + if (!scalarIntType || scalarIntType.getWidth() != 32) + return op.emitOpError("requires scalar to be i32"); + if (op.getScalarOut().getType() != scalarType) + return op.emitOpError("requires scalar_out to match scalar type"); + return success(); +} + +LogicalResult PltB8Op::verify() { return verifyPredicateLaneCountOp(*this, "b8"); } +LogicalResult PltB16Op::verify() { + return verifyPredicateLaneCountOp(*this, "b16"); +} +LogicalResult PltB32Op::verify() { + return verifyPredicateLaneCountOp(*this, "b32"); +} + +LogicalResult PpackOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getPart() != "LOWER") + return emitOpError("currently supports only LOWER part"); + return success(); +} + +LogicalResult PunpackOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getPart() != "LOWER") + return emitOpError("currently supports only LOWER part"); + auto inputMaskType = cast(getInput().getType()); + auto resultMaskType = cast(getResult().getType()); + StringRef inputGranularity = inputMaskType.getGranularity(); + StringRef resultGranularity = resultMaskType.getGranularity(); + if (inputGranularity != resultGranularity && + !isMaskGranularityAdjacentWidening(inputGranularity, resultGranularity)) { + return emitOpError( + "requires result mask granularity to match the input or widen by one step"); + } + return success(); +} + +LogicalResult PbitcastOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + +LogicalResult PnotOp::verify() { + if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + +LogicalResult PselOp::verify() { + if (failed(verifyMaskTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyMaskTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyBinaryMaskOp(BinaryMaskOp op) { + if (failed(verifyMaskTypeLike(op, op.getSrc0().getType(), "src0 type")) || + failed(verifyMaskTypeLike(op, op.getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + return success(); +} + +LogicalResult PandOp::verify() { return verifyBinaryMaskOp(*this); } +LogicalResult PorOp::verify() { return verifyBinaryMaskOp(*this); } +LogicalResult PxorOp::verify() { return verifyBinaryMaskOp(*this); } + +void PldsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult PldsOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + MemoryRole sourceRole = classifyMemoryRole(getSource().getType()); + if (sourceRole == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedPredicateLoadDist(getDist())) + return emitOpError("requires predicate load dist to be NORM, US, or DS"); + return success(); +} + +void PldiOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult PldiOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!matchPattern(getOffset(), m_Constant())) + return emitOpError("requires offset to be a constant index immediate"); + if (!isSupportedPredicateLoadDist(getDist())) + return emitOpError("requires predicate load dist to be NORM, US, or DS"); + return success(); +} + +template +static LogicalResult verifyElementwiseVecScalarOpLike(OpTy op) { + auto inputType = dyn_cast(op.getInput().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!inputType || !resultType) + return op.emitOpError("input and result must be !pto.vreg<...>"); + if (inputType != resultType) + return op.emitOpError("input and result vector types must match"); + + Type elemType = inputType.getElementType(); + Type scalarType = op.getScalar().getType(); + if (scalarType == elemType) + return success(); + + auto elemInt = dyn_cast(elemType); + auto scalarInt = dyn_cast(scalarType); + if (!elemInt || !scalarInt || elemInt.getWidth() != scalarInt.getWidth()) + return op.emitOpError("scalar type must match vector element type"); + + if (elemInt.isSigned() && (scalarInt.isSigned() || scalarInt.isSignless())) + return success(); + if (elemInt.isUnsigned() && + (scalarInt.isUnsigned() || scalarInt.isSignless())) + return success(); + if (elemInt.isSignless() && scalarInt.isSignless()) + return success(); + + return op.emitOpError( + "integer scalar type must match vector element width and use matching signedness or signless i"); +} + +template +static LogicalResult verifyVecScalarOpLike(OpTy op) { + if (failed(verifyElementwiseVecScalarOpLike(op))) + return failure(); + return success(); +} + +template +static LogicalResult verifyVecScalarMaskedOpLike(OpTy op) { + if (failed(verifyElementwiseVecScalarOpLike(op))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + return success(); +} + +template +static LogicalResult verifyCarryVecOp(CarryOp op) { + if (failed(verifyIntegerVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyIntegerVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyIntegerVRegTypeLike(op, op.getResult().getType(), + "result type")) || + failed(verifyMaskTypeLike(op, op.getCarry().getType(), "carry type"))) + return failure(); + + auto lhsType = cast(op.getLhs().getType()); + auto rhsType = cast(op.getRhs().getType()); + auto resultType = cast(op.getResult().getType()); + auto lhsElemType = cast(lhsType.getElementType()); + if (lhsType != rhsType || lhsType != resultType) + return op.emitOpError("requires lhs, rhs, and result to have matching vector types"); + if (lhsElemType.getWidth() != 32) + return op.emitOpError("currently requires 32-bit integer vector elements"); + return success(); +} + +template +static LogicalResult verifyCarryVecOpWithInput(CarryWithInputOp op) { + if (failed(verifyCarryVecOp(op)) || + failed(verifyMaskTypeLike(op, op.getCarryIn().getType(), + "carry_in type"))) + return failure(); + return success(); +} + +LogicalResult VmulsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VaddsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VmaxsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VminsOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VlreluOp::verify() { return verifyVecScalarMaskedOpLike(*this); } +LogicalResult VshlsOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (inputType != resultType) + return emitOpError("input and result vector types must match"); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector and integer scalar"); + auto scalarType = dyn_cast(getScalar().getType()); + if (!scalarType || !scalarType.isSignlessInteger(16)) + return emitOpError("requires signless i16 scalar"); + return success(); +} +LogicalResult VshrsOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (inputType != resultType) + return emitOpError("input and result vector types must match"); + if (!isa(inputType.getElementType())) + return emitOpError("requires integer vector and integer scalar"); + auto scalarType = dyn_cast(getScalar().getType()); + if (!scalarType || !scalarType.isSignlessInteger(16)) + return emitOpError("requires signless i16 scalar"); + return success(); +} + +LogicalResult VabsOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "operand type"))) + return failure(); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getInput().getType() != getResult().getType()) + return emitOpError("requires matching register vector shape"); + return success(); +} + +template +static LogicalResult verifyUnaryVecOp(UnaryOp op) { + if (failed(verifyVRegTypeLike(op, op.getInput().getType(), "operand type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getInput().getType() != op.getResult().getType()) + return op.emitOpError("requires matching register vector shape"); + return success(); +} + +LogicalResult VexpOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VlnOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VsqrtOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VnegOp::verify() { return verifyUnaryVecOp(*this); } +LogicalResult VreluOp::verify() { + if (failed(verifyUnaryVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + Type elemType = inputType.getElementType(); + if (auto intType = dyn_cast(elemType)) { + if (intType.getWidth() != 32 || intType.isUnsigned()) + return emitOpError("requires si32/i32/f16/f32 vector element type"); + return success(); + } + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires si32/i32/f16/f32 vector element type"); + return success(); +} +LogicalResult VnotOp::verify() { return verifyUnaryVecOp(*this); } + +template +static LogicalResult verifyBinaryVecOp(BinaryOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type"))) + return failure(); + if (failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type"))) + return failure(); + if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires matching register vector shapes"); + return success(); +} + +LogicalResult VaddOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VsubOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VmulOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VdivOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VandOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VorOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VxorOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VshlOp::verify() { + if (failed(verifyBinaryVecOp(*this))) + return failure(); + auto lhsType = cast(getLhs().getType()); + if (!isa(lhsType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} +LogicalResult VshrOp::verify() { + if (failed(verifyBinaryVecOp(*this))) + return failure(); + auto lhsType = cast(getLhs().getType()); + if (!isa(lhsType.getElementType())) + return emitOpError("requires integer vector element type"); + return success(); +} +LogicalResult VaddcOp::verify() { return verifyCarryVecOp(*this); } +LogicalResult VsubcOp::verify() { return verifyCarryVecOp(*this); } +LogicalResult VaddcsOp::verify() { return verifyCarryVecOpWithInput(*this); } +LogicalResult VsubcsOp::verify() { return verifyCarryVecOpWithInput(*this); } + +template +static LogicalResult verifyReductionVecOp(ReductionOp op) { + return verifyUnaryVecOp(op); +} + +template +static LogicalResult verifyGroupReductionVecOp(ReductionOp op) { + if (failed(verifyReductionVecOp(op))) + return failure(); + auto inputType = cast(op.getInput().getType()); + Type elemType = inputType.getElementType(); + if (auto intType = dyn_cast(elemType)) { + if (intType.getWidth() < 16 || intType.getWidth() > 32) + return op.emitOpError( + "requires 16-bit or 32-bit integer vector element type"); + return success(); + } + if (!elemType.isF16() && !elemType.isF32()) + return op.emitOpError("requires i16/i32/f16/f32 vector element type"); + return success(); +} + +LogicalResult VcgaddOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcgmaxOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcgminOp::verify() { return verifyGroupReductionVecOp(*this); } +LogicalResult VcpaddOp::verify() { + if (failed(verifyReductionVecOp(*this))) + return failure(); + auto inputType = cast(getInput().getType()); + Type elemType = inputType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +template +static LogicalResult verifyLaneSelectOp(SelectOp op) { + if (failed(verifyVRegTypeLike(op, op.getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(op, op.getSrc1().getType(), "src1 type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + + auto src0Type = cast(op.getSrc0().getType()); + auto src1Type = cast(op.getSrc1().getType()); + auto resultType = cast(op.getResult().getType()); + if (src0Type != resultType) + return op.emitOpError("requires src0 and result to have identical vector types"); + if (src1Type.getElementCount() != src0Type.getElementCount()) + return op.emitOpError("requires src0/src1 to have identical element counts"); + auto src1ElemType = dyn_cast(src1Type.getElementType()); + if (!src1ElemType) + return op.emitOpError("requires src1 to use integer vector elements"); + if (src1ElemType.getWidth() != getIntOrFloatBitWidth(src0Type.getElementType())) + return op.emitOpError("requires src1 integer element width to match src0 element width"); + return success(); +} + +template +static LogicalResult verifyPairVecResults(PairOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getLow().getType(), "low result type")) || + failed(verifyVRegTypeLike(op, op.getHigh().getType(), "high result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getLow().getType() || + op.getLhs().getType() != op.getHigh().getType()) + return op.emitOpError("requires operands and results to share one vector type"); + return success(); +} + +template +static LogicalResult verifyPartVecOp(PartOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires operands and result to share one vector type"); + if (!isSupportedPartToken(op.getPart())) + return op.emitOpError("requires part to be LOWER or HIGHER"); + return success(); +} + +LogicalResult VselOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc0().getType() != getSrc1().getType() || + getSrc0().getType() != getResult().getType()) + return emitOpError("requires src0, src1, and result to have identical vector types"); + return success(); +} + +LogicalResult VselrOp::verify() { return verifyLaneSelectOp(*this); } +LogicalResult Vselrv2Op::verify() { return verifyLaneSelectOp(*this); } + +LogicalResult VsqzOp::verify() { return verifyUnaryVecOp(*this); } + +LogicalResult VusqzOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc().getType() != getResult().getType()) + return emitOpError("requires src and result to share one vector type"); + auto srcType = cast(getSrc().getType()); + auto elemType = dyn_cast(srcType.getElementType()); + if (!elemType) + return emitOpError("requires signed integer vector element type"); + if (elemType.isUnsigned()) + return emitOpError("requires signed integer vector element type"); + unsigned width = elemType.getWidth(); + if (width != 8 && width != 16 && width != 32) + return emitOpError("requires s8/s16/s32 vector element type"); + return success(); +} + +LogicalResult VpackOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!isSupportedPartToken(getPart())) + return emitOpError("requires part to be LOWER or HIGHER"); + auto srcType = cast(getSrc().getType()); + auto resultType = cast(getResult().getType()); + Type srcElemType = srcType.getElementType(); + Type resultElemType = resultType.getElementType(); + if (!isa(srcElemType) || !isa(resultElemType)) + return emitOpError("currently requires integer source and result element types"); + if (resultType.getElementCount() != srcType.getElementCount() * 2) + return emitOpError( + "requires result element count to be twice the source element count"); + unsigned srcWidth = getIntOrFloatBitWidth(srcElemType); + unsigned resultWidth = getIntOrFloatBitWidth(resultElemType); + if (!srcWidth || resultWidth * 2 != srcWidth) + return emitOpError( + "requires result element width to be half the source element width"); + auto srcIntType = cast(srcElemType); + auto resultIntType = cast(resultElemType); + if (!resultIntType.isUnsigned()) + return emitOpError("requires unsigned result element type"); + if (!((srcIntType.getWidth() == 32 && resultIntType.getWidth() == 16) || + (srcIntType.getWidth() == 16 && resultIntType.getWidth() == 8))) + return emitOpError( + "currently supports only s32/u32 -> u16 and s16/u16 -> u8"); + return success(); +} + +template +static LogicalResult verifyUnpackVecOp(UnpackOp op) { + if (failed(verifyVRegTypeLike(op, op.getSrc().getType(), "src type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + auto srcType = cast(op.getSrc().getType()); + auto resultType = cast(op.getResult().getType()); + Type srcElemType = srcType.getElementType(); + Type resultElemType = resultType.getElementType(); + if (!isa(srcElemType) || !isa(resultElemType)) + return op.emitOpError( + "currently requires integer source and result element types"); + if (srcType.getElementCount() != resultType.getElementCount() * 2) + return op.emitOpError( + "requires source element count to be twice the result element count"); + unsigned srcWidth = getIntOrFloatBitWidth(srcElemType); + unsigned resultWidth = getIntOrFloatBitWidth(resultElemType); + if (!srcWidth || srcWidth * 2 != resultWidth) + return op.emitOpError( + "requires result element width to be twice the source element width"); + return success(); +} + +LogicalResult VsunpackOp::verify() { return verifyUnpackVecOp(*this); } +LogicalResult VzunpackOp::verify() { return verifyUnpackVecOp(*this); } + +static bool isSupportedCmpMode(StringRef mode) { + return mode == "eq" || mode == "ne" || mode == "lt" || mode == "le" || + mode == "gt" || mode == "ge"; +} + +LogicalResult VcmpOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getSrc0().getType() != getSrc1().getType()) + return emitOpError("requires src0 and src1 to have identical vector types"); + if (!isSupportedCmpMode(getCmpMode())) + return emitOpError("requires cmp_mode to be one of eq/ne/lt/le/gt/ge"); + return success(); +} + +LogicalResult VcmpsOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc().getType(), "src type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + auto srcType = cast(getSrc().getType()); + Type srcElementType = srcType.getElementType(); + Type scalarType = getScalar().getType(); + if (!isCompatibleScalarForSemanticType(srcElementType, scalarType)) + return emitOpError("requires scalar type to match source element type"); + if (!isSupportedCmpMode(getCmpMode())) + return emitOpError("requires cmp_mode to be one of eq/ne/lt/le/gt/ge"); + return success(); +} + +ParseResult VtrcOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand input; + OpAsmParser::UnresolvedOperand mask; + std::string roundModeToken; + NamedAttrList attrs; + Type inputType, maskType, resultType; + + if (parser.parseOperand(input) || parser.parseComma() || + parser.parseOperand(mask) || parser.parseComma() || + parser.parseKeywordOrString(&roundModeToken) || + parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(inputType) || parser.parseComma() || + parser.parseType(maskType) || parser.parseArrow() || + parser.parseType(resultType)) + return failure(); + + auto normalized = normalizeRoundModeToken(roundModeToken); + if (!normalized || !isSupportedVtrcRoundMode(*normalized)) + return parser.emitError(parser.getCurrentLocation()) + << "round mode must be one of R/A/F/C/Z or " + "ROUND_R/ROUND_A/ROUND_F/ROUND_C/ROUND_Z"; + + attrs.set("round_mode", parser.getBuilder().getStringAttr(*normalized)); + result.addAttributes(attrs); + if (parser.resolveOperand(input, inputType, result.operands) || + parser.resolveOperand(mask, maskType, result.operands)) + return failure(); + result.addTypes(resultType); + return success(); +} + +void VtrcOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInput() << ", " << getMask() << ", "; + Builder builder(getContext()); + auto normalized = normalizeRoundModeToken(getRoundMode()); + printer.printAttributeWithoutType( + builder.getStringAttr(normalized.value_or(getRoundMode()))); + printer.printOptionalAttrDict((*this)->getAttrs(), {"round_mode"}); + printer << " : " << getInput().getType() << ", " << getMask().getType() + << " -> " << getResult().getType(); +} + +LogicalResult VtrcOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (inputType != resultType) + return emitOpError("requires input and result to have identical vreg type"); + auto elemType = inputType.getElementType(); + if (!(elemType.isF16() || elemType.isF32() || elemType.isBF16())) + return emitOpError("requires f16/f32/bf16 vector element type"); + auto expectedGranularity = getVdupMaskGranularity(elemType); + if (!expectedGranularity) + return emitOpError("requires element type with supported predicate granularity"); + if (failed(verifyMaskTypeWithGranularityLike(*this, getMask().getType(), + "mask type", + *expectedGranularity))) + return failure(); + auto normalized = normalizeRoundModeToken(getRoundMode()); + if (!normalized || !isSupportedVtrcRoundMode(*normalized)) + return emitOpError("round mode must be one of R/A/F/C/Z"); + return success(); +} + +ParseResult VcvtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand input; + OpAsmParser::UnresolvedOperand mask; + NamedAttrList attrs; + Type inputType, maskType, resultType; + + if (parser.parseOperand(input) || parser.parseComma() || + parser.parseOperand(mask) || parser.parseOptionalAttrDict(attrs) || + parser.parseColonType(inputType) || parser.parseComma() || + parser.parseType(maskType) || parser.parseArrow() || + parser.parseType(resultType)) + return failure(); + + Attribute legacyRndAttr = attrs.get("round_mode"); + Attribute rndAttr = attrs.get("rnd"); + if (legacyRndAttr && rndAttr) + return parser.emitError(parser.getCurrentLocation()) + << "rnd and round_mode cannot be specified together"; + + auto normalizeNamedStringAttr = + [&](StringRef sourceName, StringRef canonicalName, + auto normalizeFn) -> ParseResult { + Attribute rawAttr = attrs.get(sourceName); + if (!rawAttr) + return success(); + auto strAttr = dyn_cast(rawAttr); + if (!strAttr) + return parser.emitError(parser.getCurrentLocation()) + << sourceName << " must be a string literal"; + auto normalized = normalizeFn(strAttr.getValue()); + if (!normalized) + return parser.emitError(parser.getCurrentLocation()) + << sourceName << " has unsupported value '" << strAttr.getValue() + << "'"; + attrs.erase(sourceName); + attrs.set(canonicalName, parser.getBuilder().getStringAttr(*normalized)); + return success(); + }; + + if (failed(normalizeNamedStringAttr("round_mode", "rnd", + normalizeRoundModeToken)) || + failed(normalizeNamedStringAttr("rnd", "rnd", normalizeRoundModeToken)) || + failed(normalizeNamedStringAttr("sat", "sat", normalizeSaturationToken)) || + failed(normalizeNamedStringAttr("part", "part", normalizeVcvtPartToken))) + return failure(); + + result.addAttributes(attrs); + if (parser.resolveOperand(input, inputType, result.operands) || + parser.resolveOperand(mask, maskType, result.operands)) + return failure(); + result.addTypes(resultType); + return success(); +} + +void VcvtOp::print(OpAsmPrinter &printer) { + printer << ' ' << getInput() << ", " << getMask(); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getInput().getType() << ", " << getMask().getType() + << " -> " << getResult().getType(); +} + +LogicalResult VcvtOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + + VcvtElemKind inputElemKind = classifyVcvtElemType(inputType.getElementType()); + VcvtElemKind resultElemKind = classifyVcvtElemType(resultType.getElementType()); + auto contract = lookupVcvtContract(inputElemKind, resultElemKind); + if (!contract) + return emitOpError("unsupported vcvt source/result element type pair"); + + auto inputElemBits = getVcvtElemBitWidth(inputElemKind); + auto resultElemBits = getVcvtElemBitWidth(resultElemKind); + if (!inputElemBits || !resultElemBits) + return emitOpError("could not determine vcvt element bit width"); + unsigned maskBitWidth = std::min(*inputElemBits, 32u); + StringRef expectedMaskGranularity = maskBitWidth == 8 ? "b8" + : maskBitWidth == 16 ? "b16" + : maskBitWidth == 32 ? "b32" + : ""; + if (expectedMaskGranularity.empty()) + return emitOpError("could not determine vcvt mask granularity"); + if (failed(verifyMaskTypeWithGranularityLike( + *this, getMask().getType(), "mask type", expectedMaskGranularity))) + return failure(); + if (inputType.getElementCount() * static_cast(*inputElemBits) != + resultType.getElementCount() * static_cast(*resultElemBits)) { + return emitOpError("requires source and result vectors to carry the same " + "total number of bits"); + } + + if (getRndAttr()) { + StringRef roundMode = *getRnd(); + if (!normalizeRoundModeToken(roundMode)) + return emitOpError("rnd must be one of R/A/F/C/Z/O"); + } + if (static_cast(getRndAttr()) != contract->requiresRnd) { + return contract->requiresRnd ? emitOpError("requires rnd attr for this vcvt type pair") + : emitOpError("rnd attr is not valid for this vcvt type pair"); + } + + if (getSatAttr()) { + StringRef sat = *getSat(); + if (!normalizeSaturationToken(sat)) + return emitOpError("sat must be SAT or NOSAT"); + } + if (static_cast(getSatAttr()) != contract->requiresSat) { + return contract->requiresSat ? emitOpError("requires sat attr for this vcvt type pair") + : emitOpError("sat attr is not valid for this vcvt type pair"); + } + + if (getPartAttr()) { + StringRef part = *getPart(); + auto normalizedPart = normalizeVcvtPartToken(part); + if (!normalizedPart) + return emitOpError("part must be one of EVEN/ODD/P0/P1/P2/P3"); + auto partFamily = classifyVcvtPartFamily(*inputElemBits, *resultElemBits); + if (!partFamily) + return emitOpError("part attr is not supported for this vcvt width relation"); + if (!isValidVcvtPartForFamily(*normalizedPart, *partFamily)) { + switch (*partFamily) { + case VcvtPartFamily::EvenOdd: + return emitOpError("part must be EVEN or ODD for 8/16 and 16/32 vcvt forms"); + case VcvtPartFamily::Packed4: + return emitOpError("part must be P0, P1, P2, or P3 for 8/32 vcvt forms"); + } + } + } + if (static_cast(getPartAttr()) != contract->requiresPart) { + return contract->requiresPart ? emitOpError("requires part attr for this vcvt type pair") + : emitOpError("part attr is not valid for this vcvt type pair"); + } + + return success(); +} + +LogicalResult VbitcastOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + auto resultType = dyn_cast(getResult().getType()); + if (!inputType || !resultType) + return emitOpError("input and result must be !pto.vreg<...>"); + + auto getStorageBits = [](VRegType type) -> std::optional { + Type elementType = type.getElementType(); + if (auto intType = dyn_cast(elementType)) + return type.getElementCount() * static_cast(intType.getWidth()); + if (auto floatType = dyn_cast(elementType)) + return type.getElementCount() * + static_cast(floatType.getWidth()); + return std::nullopt; + }; + + auto inputBits = getStorageBits(inputType); + auto resultBits = getStorageBits(resultType); + if (!inputBits || !resultBits) + return emitOpError("requires integer or floating-point vreg element type"); + if (*inputBits != *resultBits) { + return emitOpError("requires source and result vectors to carry the same " + "total number of bits"); + } + + return success(); +} + +LogicalResult PdintlvB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b8"))) + return failure(); + return success(); +} + +LogicalResult PdintlvB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b16"))) + return failure(); + return success(); +} + +LogicalResult PdintlvB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b32"))) + return failure(); + return success(); +} + +LogicalResult PintlvB8Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b8")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b8"))) + return failure(); + return success(); +} + +LogicalResult PintlvB16Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b16")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b16"))) + return failure(); + return success(); +} + +LogicalResult PintlvB32Op::verify() { + if (failed(verifyMaskTypeWithGranularityLike(*this, getLhs().getType(), + "lhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getRhs().getType(), + "rhs type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getLow().getType(), + "low type", "b32")) || + failed(verifyMaskTypeWithGranularityLike(*this, getHigh().getType(), + "high type", "b32"))) + return failure(); + return success(); +} + +LogicalResult VintlvOp::verify() { return verifyPairVecResults(*this); } +LogicalResult VdintlvOp::verify() { return verifyPairVecResults(*this); } +LogicalResult Vintlvv2Op::verify() { return verifyPartVecOp(*this); } +LogicalResult Vdintlvv2Op::verify() { return verifyPartVecOp(*this); } + +LogicalResult VmullOp::verify() { + if (failed(verifyPairVecResults(*this)) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + auto lhsType = cast(getLhs().getType()); + auto lhsElemType = dyn_cast(lhsType.getElementType()); + if (!lhsElemType) + return emitOpError("requires integer vector element type"); + if (lhsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit integer vector elements"); + return success(); +} + +LogicalResult VmulaOp::verify() { + if (failed(verifyVRegTypeLike(*this, getAcc().getType(), "acc type")) || + failed(verifyVRegTypeLike(*this, getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(*this, getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (getAcc().getType() != getLhs().getType() || + getAcc().getType() != getRhs().getType() || + getAcc().getType() != getResult().getType()) + return emitOpError("requires acc, lhs, rhs, and result to share one vector type"); + return success(); +} + +template +static LogicalResult verifyBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires lhs, rhs, and result to share one vector type"); + return success(); +} + +template +static LogicalResult verifyFloatBinaryVecNoMaskOp(BinaryVecNoMaskOp op) { + if (failed(verifyBinaryVecNoMaskOp(op))) + return failure(); + auto lhsType = cast(op.getLhs().getType()); + Type elemType = lhsType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return op.emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +template +static LogicalResult verifyFloatBinaryVecMaskOp(BinaryVecMaskOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyMaskTypeLike(op, op.getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + if (op.getLhs().getType() != op.getRhs().getType() || + op.getLhs().getType() != op.getResult().getType()) + return op.emitOpError("requires lhs, rhs, and result to share one vector type"); + auto lhsType = cast(op.getLhs().getType()); + Type elemType = lhsType.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return op.emitOpError("requires f16 or f32 vector element type"); + return success(); +} + +LogicalResult VpreluOp::verify() { return verifyFloatBinaryVecMaskOp(*this); } +LogicalResult VexpdifOp::verify() { + if (failed(verifyVRegTypeLike(*this, getInput().getType(), "input type")) || + failed(verifyVRegTypeLike(*this, getMax().getType(), "max type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + + auto inputType = cast(getInput().getType()); + auto maxType = cast(getMax().getType()); + auto resultType = cast(getResult().getType()); + if (inputType != maxType) + return emitOpError("requires input and max to share one vector type"); + + Type inputElemType = inputType.getElementType(); + if (!inputElemType.isF16() && !inputElemType.isF32()) + return emitOpError("requires f16 or f32 input vector element type"); + auto expectedGranularity = getVdupMaskGranularity(inputElemType); + if (!expectedGranularity) + return emitOpError("requires input element type with supported predicate granularity"); + if (failed(verifyMaskTypeWithGranularityLike(*this, getMask().getType(), + "mask type", + *expectedGranularity))) + return failure(); + if (!resultType.getElementType().isF32()) + return emitOpError("requires f32 result vector element type"); + + auto inputBits = getVRegStorageBitWidth(inputType); + auto resultBits = getVRegStorageBitWidth(resultType); + if (!inputBits || !resultBits || *inputBits != *resultBits) + return emitOpError( + "requires source and result to preserve total vector storage width"); + + StringRef part = getPart(); + if (part != "EVEN" && part != "ODD") + return emitOpError("part must be EVEN or ODD"); + return success(); +} + +LogicalResult VaxpyOp::verify() { + if (failed(verifyVRegTypeLike(*this, getSrc0().getType(), "src0 type")) || + failed(verifyVRegTypeLike(*this, getSrc1().getType(), "src1 type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + auto src0Type = cast(getSrc0().getType()); + auto src1Type = cast(getSrc1().getType()); + auto resultType = cast(getResult().getType()); + if (src0Type != src1Type || src0Type != resultType) + return emitOpError("requires src0, src1, and result to share one vector type"); + Type elemType = src0Type.getElementType(); + if (!elemType.isF16() && !elemType.isF32()) + return emitOpError("requires f16 or f32 vector element type"); + auto expectedGranularity = getVdupMaskGranularity(elemType); + if (!expectedGranularity) + return emitOpError("requires element type with supported predicate granularity"); + if (failed(verifyMaskTypeWithGranularityLike(*this, getMask().getType(), + "mask type", + *expectedGranularity))) + return failure(); + if (getAlpha().getType() != elemType) + return emitOpError("requires alpha type to match vector element type"); + return success(); +} + +template +static LogicalResult verifyFusedConvVecOp(ConvOp op) { + if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type")) || + failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type")) || + failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) + return failure(); + auto lhsType = cast(op.getLhs().getType()); + auto rhsType = cast(op.getRhs().getType()); + auto resultType = cast(op.getResult().getType()); + if (lhsType != rhsType) + return op.emitOpError("requires lhs and rhs to share one vector type"); + if (!isIntegerOrFloatLike(lhsType.getElementType()) || + !isIntegerOrFloatLike(resultType.getElementType())) + return op.emitOpError( + "requires integer or floating-point vector element types"); + auto lhsBits = getVRegStorageBitWidth(lhsType); + auto resultBits = getVRegStorageBitWidth(resultType); + if (!lhsBits || !resultBits || *lhsBits != *resultBits) + return op.emitOpError( + "requires source and result to preserve total vector storage width"); + return success(); +} + +LogicalResult VaddreluconvOp::verify() { + return verifyFusedConvVecOp(*this); +} +LogicalResult VmulconvOp::verify() { return verifyFusedConvVecOp(*this); } + +void Vldsx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult Vldsx2Op::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (failed(verifyVRegTypeLike(*this, getLow().getType(), "low result type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high result type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high results to share one vector type"); + if (!isSupportedVldx2DistToken(getDist())) + return emitOpError("requires a supported x2 load distribution token"); + return success(); +} + +void VstsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +template +static LogicalResult verifyVstsCommon(StoreOp op) { + if (failed(verifyVRegTypeLike(op, op.getValue().getType(), "value type"))) + return failure(); + + if (!isBufferLike(op.getDestination().getType())) + return op.emitOpError("requires a pointer-like destination"); + + MemoryRole destinationRole = classifyMemoryRole(op.getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return op.emitOpError("requires a UB-backed destination"); + + if (std::optional dist = op.getDist(); + dist && !isSupportedVstsDistToken(*dist)) { + return op.emitOpError("requires a supported store distribution token"); + } + if (std::optional dist = op.getDist()) { + if (std::optional granularity = getVstsMaskGranularityOverride( + *dist, cast(op.getValue().getType()).getElementType())) { + if (failed(verifyMaskTypeWithGranularityLike(op, op.getMask().getType(), + "mask type", *granularity))) + return failure(); + } else if (failed(verifyMaskTypeLike(op, op.getMask().getType(), + "mask type"))) { + return failure(); + } + } else if (failed(verifyMaskTypeLike(op, op.getMask().getType(), + "mask type"))) { + return failure(); + } + + return success(); +} + +LogicalResult VstsOp::verify() { + if (failed(verifyVstsCommon(*this))) + return failure(); + if (std::optional mode = getOptionalPostModeAttr(getOperation()); + mode && !isSupportedPostMode(*mode)) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} +void VstsPostOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstsPostOp::verify() { + if (failed(verifyVstsCommon(*this))) + return failure(); + if (getUpdatedDestination().getType() != getDestination().getType()) + return emitOpError( + "requires updated destination result to match destination type"); + return success(); +} + +void Vstsx2Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLowMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getHighMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult Vstsx2Op::verify() { + if (failed(verifyVRegTypeLike(*this, getLow().getType(), "low value type")) || + failed(verifyVRegTypeLike(*this, getHigh().getType(), "high value type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (getLow().getType() != getHigh().getType()) + return emitOpError("requires low/high values to share one vector type"); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedVstsx2DistToken(getDist())) + return emitOpError("requires a supported x2 store distribution token"); + return success(); +} + +void VscatterOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VscatterOp::verify() { + if (failed(verifyVRegTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + auto offsetsType = dyn_cast(getOffsets().getType()); + auto valueType = dyn_cast(getValue().getType()); + if (!offsetsType || !valueType) + return emitOpError("value and offsets must be !pto.vreg<...>"); + auto offsetsElemType = dyn_cast(offsetsType.getElementType()); + if (!offsetsElemType) + return emitOpError("offset vector must use integer element type"); + if (offsetsElemType.getWidth() != 32) + return emitOpError("currently requires 32-bit offset vector elements"); + if (offsetsType.getElementCount() != valueType.getElementCount()) + return emitOpError("offset and value vectors must have the same element count"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + MemoryRole destinationRole = classifyMemoryRole(getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + +void VsldbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VsldbOp::verify() { + if (!isBufferLike(getSource().getType())) + return emitOpError("requires a pointer-like source"); + if (classifyMemoryRole(getSource().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed source"); + if (failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type")) || + failed(verifyVRegTypeLike(*this, getResult().getType(), "result type"))) + return failure(); + if (!getBlockStride().getType().isSignlessInteger(16)) + return emitOpError("requires block_stride to be i16"); + if (!getRepeatStride().getType().isSignlessInteger(16)) + return emitOpError("requires repeat_stride to be i16"); + return success(); +} + +void PstsOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void PstiOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult PstiOp::verify() { + if (failed(verifyMaskTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!matchPattern(getOffset(), m_Constant())) + return emitOpError("requires offset to be a constant index immediate"); + if (!isSupportedPredicateStoreDist(getDist())) + return emitOpError("requires predicate store dist to be NORM or PK"); + return success(); +} + +LogicalResult PstsOp::verify() { + if (failed(verifyMaskTypeLike(*this, getValue().getType(), "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + MemoryRole destinationRole = classifyMemoryRole(getDestination().getType()); + if (destinationRole == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getOffset().getType().isIndex()) + return emitOpError("requires index offset"); + if (!isSupportedPredicateStoreDist(getDist())) + return emitOpError("requires predicate store dist to be NORM or PK"); + return success(); +} + +void VsstbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VsstbOp::verify() { + if (failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyMaskTypeLike(*this, getMask().getType(), "mask type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + if (!getBlockStride().getType().isSignlessInteger(16)) + return emitOpError("requires block_stride to be i16"); + if (!getRepeatStride().getType().isSignlessInteger(16)) + return emitOpError("requires repeat_stride to be i16"); + return success(); +} + +void VstasOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstasOp::verify() { + if (failed(verifyStoreAlignChain(getValue(), *this, "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + +void VstarOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VstarOp::verify() { + if (failed(verifyStoreAlignChain(getValue(), *this, "value type"))) + return failure(); + if (!isBufferLike(getDestination().getType())) + return emitOpError("requires a pointer-like destination"); + if (classifyMemoryRole(getDestination().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed destination"); + return success(); +} + +void PstuOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult PstuOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyMaskTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType()) || !isBufferLike(getBaseOut().getType())) + return emitOpError("requires pointer-like base and base_out"); + if (getBase().getType() != getBaseOut().getType()) + return emitOpError("requires base and base_out to have identical types"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + auto baseType = cast(getBase().getType()); + auto maskType = cast(getValue().getType()); + auto elemType = dyn_cast(baseType.getElementType()); + if (!elemType || elemType.isSigned() || (elemType.getWidth() != 16 && elemType.getWidth() != 32)) + return emitOpError("requires ui16/ui32 UB base type"); + if (maskType.isB16() && elemType.getWidth() != 16) + return emitOpError("requires !pto.mask to pair with !pto.ptr"); + if (maskType.isB32() && elemType.getWidth() != 32) + return emitOpError("requires !pto.mask to pair with !pto.ptr"); + return success(); +} + +void VstusOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult VstusOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType())) + return emitOpError("requires a pointer-like base"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + return success(); +} + +void VsturOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getAlignInMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getValueMutable()); + effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable()); +} + +LogicalResult VsturOp::verify() { + if (failed(verifyStoreAlignChain(getAlignIn(), *this, "align_in type")) || + failed(verifyVRegTypeLike(*this, getValue().getType(), "value type")) || + failed(verifyAlignTypeLike(*this, getAlignOut().getType(), "align_out type"))) + return failure(); + if (!isBufferLike(getBase().getType())) + return emitOpError("requires a pointer-like base"); + if (classifyMemoryRole(getBase().getType()) == MemoryRole::GM) + return emitOpError("requires a UB-backed base"); + if (!isSupportedPostMode(getMode())) + return emitOpError("requires mode to be POST_UPDATE or NO_POST_UPDATE"); + return success(); +} + +void CopyUbufToGmOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult CopyUbufToGmOp::verify() { + return verifyCopyUbufToGmOp(*this, false); +} + +void MteUbGmOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, pto::DmaLoopConfig nburst, + llvm::ArrayRef loops) { + state.addOperands({source, destination, lenBurst, nburst.count, + nburst.srcStride, nburst.dstStride}); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.count); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.srcStride); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.dstStride); + + state.addAttribute( + getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, 1, 1, 1, + static_cast(loops.size()), + static_cast(loops.size()), + static_cast(loops.size())})); +} + +void MteUbGmOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2) { + SmallVector loops; + if (loop1) + loops.push_back(*loop1); + if (loop2) + loops.push_back(*loop2); + build(builder, state, source, destination, lenBurst, nburst, loops); +} + +ParseResult MteUbGmOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, lenBurst; + SmallVector nburstOperands; + SmallVector loopCountOperands; + SmallVector loopSrcStrideOperands; + SmallVector loopDstStrideOperands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands)) + return failure(); + while (true) { + StringRef parsedKeyword; + SmallVector loopGroupOperands; + if (parseOptionalDmaTripleGroupAlias(parser, {"loop", "loop1", "loop2"}, + parsedKeyword, loopGroupOperands)) + return failure(); + if (parsedKeyword.empty()) + break; + loopCountOperands.push_back(loopGroupOperands[0]); + loopSrcStrideOperands.push_back(loopGroupOperands[1]); + loopDstStrideOperands.push_back(loopGroupOperands[2]); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, lenBurstType; + SmallVector nburstTypes, loopCountTypes, loopSrcStrideTypes, + loopDstStrideTypes; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + while (succeeded(parser.parseOptionalComma())) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (isDmaLoopKeyword(keyword)) { + SmallVector loopGroupTypes; + if (parseDmaTripleTypes(parser, loopGroupTypes)) + return failure(); + loopCountTypes.push_back(loopGroupTypes[0]); + loopSrcStrideTypes.push_back(loopGroupTypes[1]); + loopDstStrideTypes.push_back(loopGroupTypes[2]); + continue; + } + return parser.emitError(parser.getCurrentLocation(), + "expected 'loop'"); + } + + int32_t loopGroupCount = static_cast(loopCountOperands.size()); + if (loopCountOperands.size() != loopSrcStrideOperands.size() || + loopCountOperands.size() != loopDstStrideOperands.size() || + loopCountTypes.size() != loopSrcStrideTypes.size() || + loopCountTypes.size() != loopDstStrideTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires each loop group to provide count, src stride, and dst stride"); + if (loopCountOperands.size() != loopCountTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires loop operand and type groups to match"); + + auto &segments = + result.getOrAddProperties().operandSegmentSizes; + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, + loopGroupCount, loopGroupCount, loopGroupCount}, + segments.begin()); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(loopCountOperands, loopCountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopSrcStrideOperands, loopSrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopDstStrideOperands, loopDstStrideTypes, + parser.getCurrentLocation(), + result.operands)) + return failure(); + return success(); +} + +void MteUbGmOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " + << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), + getNburstDstStride()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleGroup(printer, "loop", count, srcStride, dstStride); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getLenBurst().getType() << ", " << getNBurst().getType() + << ", " << getNburstSrcStride().getType() + << ", " + << getNburstDstStride().getType(); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleTypes(printer, "loop", count.getType(), srcStride.getType(), + dstStride.getType()); +} + +void MteUbGmOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult MteUbGmOp::verify() { + if (!isBufferLike(getSource().getType()) || + !isBufferLike(getDestination().getType())) + return emitOpError( + "requires typed !pto.ptr or memref source and destination"); + if (classifyMemoryRole(getSource().getType()) != MemoryRole::UB || + classifyMemoryRole(getDestination().getType()) != MemoryRole::GM) + return emitOpError("requires UB source and GM destination"); + int64_t sourceElemBytes = getBufferElementByteSize(getSource().getType()); + int64_t destinationElemBytes = + getBufferElementByteSize(getDestination().getType()); + if (sourceElemBytes <= 0 || destinationElemBytes <= 0) + return emitOpError( + "requires copy source and destination element types with known byte width"); + if (sourceElemBytes != destinationElemBytes) + return emitOpError( + "requires source and destination element byte widths to match"); + return verifyDmaLoadStoreLoopGroups( + getOperation(), getLoopCounts(), getLoopSrcStrides(), + getLoopDstStrides()); +} + +void MteGmL1Op::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, + pto::DmaLoopConfig nburst, + llvm::ArrayRef loops) { + state.addOperands( + {source, destination, lenBurst, nburst.count, nburst.srcStride, + nburst.dstStride}); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.count); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.srcStride); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.dstStride); + + state.addAttribute( + getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, 1, 1, 1, + static_cast(loops.size()), + static_cast(loops.size()), + static_cast(loops.size())})); +} + +void MteGmL1Op::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, + pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2) { + SmallVector loops; + if (loop1) + loops.push_back(*loop1); + if (loop2) + loops.push_back(*loop2); + build(builder, state, source, destination, lenBurst, nburst, loops); +} + +void MteL1UbOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, + pto::DmaLoopConfig nburst, + llvm::ArrayRef loops) { + state.addOperands( + {source, destination, lenBurst, nburst.count, nburst.srcStride, + nburst.dstStride}); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.count); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.srcStride); + for (const pto::DmaLoopConfig &loop : loops) + state.addOperands(loop.dstStride); + + state.addAttribute( + getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, 1, 1, 1, + static_cast(loops.size()), + static_cast(loops.size()), + static_cast(loops.size())})); +} + +void MteL1UbOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, + pto::DmaLoopConfig nburst, + std::optional loop1, + std::optional loop2) { + SmallVector loops; + if (loop1) + loops.push_back(*loop1); + if (loop2) + loops.push_back(*loop2); + build(builder, state, source, destination, lenBurst, nburst, loops); +} + +void MteGmL1FracOp::build(OpBuilder &builder, OperationState &state, + Value source, Value destination, + pto::CubeLoadFracMode mode, + pto::CubeLoadFracShapeConfig shape, + pto::CubeLoadFracSrcLayoutConfig srcLayout, + pto::CubeLoadFracDstGroupConfig dstGroup, + pto::CubeLoadFracCtrlConfig ctrl) { + state.addOperands({source, destination, shape.nValue, shape.dValue, + srcLayout.srcInnerStride}); + state.addOperands({dstGroup.groupCount, dstGroup.dstLoop2Stride, + dstGroup.dstLoop3Stride, dstGroup.dstLoop4Stride, + ctrl.l2CacheCtrl, ctrl.smallc0En}); + bool hasSrcOuterStride = srcLayout.srcOuterStride.has_value(); + if (hasSrcOuterStride) + state.addOperands(*srcLayout.srcOuterStride); + + state.addAttribute(getModeAttrName(state.name), + CubeLoadFracModeAttr::get(builder.getContext(), mode)); +} + +ParseResult MteGmL1Op::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, lenBurst; + SmallVector nburstOperands; + SmallVector loopCountOperands; + SmallVector loopSrcStrideOperands; + SmallVector loopDstStrideOperands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands)) + return failure(); + while (true) { + StringRef parsedKeyword; + SmallVector loopGroupOperands; + if (parseOptionalDmaTripleGroupAlias(parser, {"loop", "loop1", "loop2"}, + parsedKeyword, loopGroupOperands)) + return failure(); + if (parsedKeyword.empty()) + break; + loopCountOperands.push_back(loopGroupOperands[0]); + loopSrcStrideOperands.push_back(loopGroupOperands[1]); + loopDstStrideOperands.push_back(loopGroupOperands[2]); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, lenBurstType; + SmallVector nburstTypes, loopCountTypes, loopSrcStrideTypes, + loopDstStrideTypes; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + while (succeeded(parser.parseOptionalComma())) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (!isDmaLoopKeyword(keyword)) + return parser.emitError(parser.getCurrentLocation(), "expected 'loop'"); + SmallVector loopGroupTypes; + if (parseDmaTripleTypes(parser, loopGroupTypes)) + return failure(); + loopCountTypes.push_back(loopGroupTypes[0]); + loopSrcStrideTypes.push_back(loopGroupTypes[1]); + loopDstStrideTypes.push_back(loopGroupTypes[2]); + } + + int32_t loopGroupCount = static_cast(loopCountOperands.size()); + if (loopCountOperands.size() != loopSrcStrideOperands.size() || + loopCountOperands.size() != loopDstStrideOperands.size() || + loopCountTypes.size() != loopSrcStrideTypes.size() || + loopCountTypes.size() != loopDstStrideTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires each loop group to provide count, src stride, and dst stride"); + if (loopCountOperands.size() != loopCountTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires loop operand and type groups to match"); + + auto &segments = + result.getOrAddProperties().operandSegmentSizes; + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, + loopGroupCount, loopGroupCount, loopGroupCount}, + segments.begin()); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopCountOperands, loopCountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopSrcStrideOperands, loopSrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopDstStrideOperands, loopDstStrideTypes, + parser.getCurrentLocation(), result.operands)) + return failure(); + return success(); +} + +ParseResult MteL1UbOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, lenBurst; + SmallVector nburstOperands; + SmallVector loopCountOperands; + SmallVector loopSrcStrideOperands; + SmallVector loopDstStrideOperands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands)) + return failure(); + while (true) { + StringRef parsedKeyword; + SmallVector loopGroupOperands; + if (parseOptionalDmaTripleGroupAlias(parser, {"loop", "loop1", "loop2"}, + parsedKeyword, loopGroupOperands)) + return failure(); + if (parsedKeyword.empty()) + break; + loopCountOperands.push_back(loopGroupOperands[0]); + loopSrcStrideOperands.push_back(loopGroupOperands[1]); + loopDstStrideOperands.push_back(loopGroupOperands[2]); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, lenBurstType; + SmallVector nburstTypes, loopCountTypes, loopSrcStrideTypes, + loopDstStrideTypes; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + while (succeeded(parser.parseOptionalComma())) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (!isDmaLoopKeyword(keyword)) + return parser.emitError(parser.getCurrentLocation(), "expected 'loop'"); + SmallVector loopGroupTypes; + if (parseDmaTripleTypes(parser, loopGroupTypes)) + return failure(); + loopCountTypes.push_back(loopGroupTypes[0]); + loopSrcStrideTypes.push_back(loopGroupTypes[1]); + loopDstStrideTypes.push_back(loopGroupTypes[2]); + } + + int32_t loopGroupCount = static_cast(loopCountOperands.size()); + if (loopCountOperands.size() != loopSrcStrideOperands.size() || + loopCountOperands.size() != loopDstStrideOperands.size() || + loopCountTypes.size() != loopSrcStrideTypes.size() || + loopCountTypes.size() != loopDstStrideTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires each loop group to provide count, src stride, and dst stride"); + if (loopCountOperands.size() != loopCountTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "requires loop operand and type groups to match"); + + auto &segments = + result.getOrAddProperties().operandSegmentSizes; + llvm::copy(ArrayRef{1, 1, 1, 1, 1, 1, + loopGroupCount, loopGroupCount, loopGroupCount}, + segments.begin()); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopCountOperands, loopCountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopSrcStrideOperands, loopSrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(loopDstStrideOperands, loopDstStrideTypes, + parser.getCurrentLocation(), result.operands)) + return failure(); + return success(); +} + +ParseResult MteGmL1FracOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination; + StringRef modeKeyword; + SmallVector shapeOperands; + SmallVector srcLayoutOperands; + SmallVector dstGroupOperands; + SmallVector ctrlOperands; + + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parser.parseKeyword(&modeKeyword) || + failed(parseCubeLoadFracModeKeyword(modeKeyword)) || parser.parseComma() || + parseFixedKeywordOperandGroup(parser, "shape", 2, shapeOperands) || + parser.parseComma() || + parseCubeLoadFracSrcLayoutGroup(parser, srcLayoutOperands) || + parser.parseComma() || + parseFixedKeywordOperandGroup(parser, "dst_group", 4, dstGroupOperands) || + parser.parseComma() || + parseFixedKeywordOperandGroup(parser, "ctrl", 2, ctrlOperands)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType; + SmallVector shapeTypes; + SmallVector srcLayoutTypes; + SmallVector dstGroupTypes; + SmallVector ctrlTypes; + + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseKeyword(modeKeyword) || parser.parseComma() || + parseFixedKeywordTypes(parser, "shape", 2, shapeTypes) || + parser.parseComma() || + parseCubeLoadFracSrcLayoutTypes(parser, srcLayoutTypes) || + parser.parseComma() || + parseFixedKeywordTypes(parser, "dst_group", 4, dstGroupTypes) || + parser.parseComma() || + parseFixedKeywordTypes(parser, "ctrl", 2, ctrlTypes)) + return failure(); + + auto modeOr = parseCubeLoadFracModeKeyword(modeKeyword); + if (failed(modeOr)) + return parser.emitError(parser.getCurrentLocation(), + "expected one of 'nd2nz' or 'dn2nz'"); + if (shapeOperands.size() != 2 || shapeTypes.size() != 2) + return parser.emitError(parser.getCurrentLocation(), + "shape requires exactly two operands and types"); + if (srcLayoutOperands.empty() || srcLayoutOperands.size() > 2 || + srcLayoutTypes.empty() || srcLayoutTypes.size() > 2) + return parser.emitError(parser.getCurrentLocation(), + "src_layout requires one or two operands and types"); + if (dstGroupOperands.size() != 4 || dstGroupTypes.size() != 4) + return parser.emitError(parser.getCurrentLocation(), + "dst_group requires exactly four operands and types"); + if (ctrlOperands.size() != 2 || ctrlTypes.size() != 2) + return parser.emitError(parser.getCurrentLocation(), + "ctrl requires exactly two operands and types"); + if (srcLayoutOperands.size() != srcLayoutTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "src_layout operand and type groups must match"); + + bool hasSrcOuterStride = srcLayoutOperands.size() == 2; + result.addAttribute(getModeAttrName(result.name), + CubeLoadFracModeAttr::get(parser.getContext(), *modeOr)); + + SmallVector flatTypes; + SmallVector flatOperands; + flatOperands.append({shapeOperands[0], shapeOperands[1], srcLayoutOperands[0]}); + flatTypes.append({shapeTypes[0], shapeTypes[1], srcLayoutTypes[0]}); + flatOperands.append(dstGroupOperands.begin(), dstGroupOperands.end()); + flatTypes.append(dstGroupTypes.begin(), dstGroupTypes.end()); + flatOperands.append(ctrlOperands.begin(), ctrlOperands.end()); + flatTypes.append(ctrlTypes.begin(), ctrlTypes.end()); + if (hasSrcOuterStride) { + flatOperands.push_back(srcLayoutOperands[1]); + flatTypes.push_back(srcLayoutTypes[1]); + } + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperands(flatOperands, flatTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + return success(); +} + +void MteGmL1Op::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " + << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), + getNburstDstStride()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleGroup(printer, "loop", count, srcStride, dstStride); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getLenBurst().getType() << ", " << getNBurst().getType() + << ", " << getNburstSrcStride().getType() << ", " + << getNburstDstStride().getType(); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleTypes(printer, "loop", count.getType(), srcStride.getType(), + dstStride.getType()); +} + +void MteL1UbOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " + << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcStride(), + getNburstDstStride()); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleGroup(printer, "loop", count, srcStride, dstStride); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getLenBurst().getType() << ", " << getNBurst().getType() + << ", " << getNburstSrcStride().getType() << ", " + << getNburstDstStride().getType(); + for (auto [count, srcStride, dstStride] : + llvm::zip(getLoopCounts(), getLoopSrcStrides(), getLoopDstStrides())) + printDmaTripleTypes(printer, "loop", count.getType(), srcStride.getType(), + dstStride.getType()); +} + +void MteL1BtOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, + pto::DmaLoopConfig nburst) { + state.addOperands({source, destination, lenBurst, nburst.count, + nburst.srcStride, nburst.dstStride}); +} + +void MteL1FbOp::build(OpBuilder &builder, OperationState &state, Value source, + Value destination, Value lenBurst, + pto::DmaLoopConfig nburst) { + state.addOperands({source, destination, lenBurst, nburst.count, + nburst.srcStride, nburst.dstStride}); +} + +ParseResult MteL1BtOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, lenBurst; + SmallVector nburstOperands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, lenBurstType; + SmallVector nburstTypes; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, + parser.getCurrentLocation(), result.operands)) + return failure(); + return success(); +} + +ParseResult MteL1FbOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand source, destination, lenBurst; + SmallVector nburstOperands; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parser.parseOperand(lenBurst) || + parseDmaTripleGroup(parser, "nburst", nburstOperands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, lenBurstType; + SmallVector nburstTypes; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(lenBurstType) || parser.parseComma() || + parseDmaTripleTypes(parser, nburstTypes)) + return failure(); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(lenBurst, lenBurstType, result.operands) || + parser.resolveOperands(nburstOperands, nburstTypes, + parser.getCurrentLocation(), result.operands)) + return failure(); + return success(); +} + +void MteL1BtOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " + << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcGap(), + getNburstDstGap()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getLenBurst().getType() << ", " << getNBurst().getType() + << ", " << getNburstSrcGap().getType() << ", " + << getNburstDstGap().getType(); +} + +void MteL1FbOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " + << getLenBurst(); + printDmaTripleGroup(printer, "nburst", getNBurst(), getNburstSrcGap(), + getNburstDstGap()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getLenBurst().getType() << ", " << getNBurst().getType() + << ", " << getNburstSrcGap().getType() << ", " + << getNburstDstGap().getType(); +} + +void MteGmL1FracOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " + << pto::stringifyCubeLoadFracMode(getMode()); + printer << ", shape(" << getNValue() << ", " << getDValue() << ")"; + printCubeLoadFracSrcLayoutGroup(printer, getSrcInnerStride(), + getSrcOuterStride()); + printer << ", dst_group(" << getGroupCount() << ", " << getDstLoop2Stride() + << ", " << getDstLoop3Stride() << ", " << getDstLoop4Stride() + << ")"; + printer << ", ctrl(" << getL2CacheCtrl() << ", " << getSmallc0En() << ")"; + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", + "mode"}); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << pto::stringifyCubeLoadFracMode(getMode()) + << ", shape " << getNValue().getType() << ", " << getDValue().getType(); + printCubeLoadFracSrcLayoutTypes( + printer, getSrcInnerStride().getType(), + getSrcOuterStride() ? getSrcOuterStride().getType() : Type()); + printer << ", dst_group " << getGroupCount().getType() << ", " + << getDstLoop2Stride().getType() << ", " + << getDstLoop3Stride().getType() << ", " + << getDstLoop4Stride().getType() << ", ctrl " + << getL2CacheCtrl().getType() << ", " << getSmallc0En().getType(); +} + +LogicalResult MteGmL1Op::verify() { + if (failed(verifyCopyGmToUbufOp(*this, true))) + return failure(); + return verifyDmaLoadStoreLoopGroups( + getOperation(), getLoopCounts(), getLoopSrcStrides(), + getLoopDstStrides()); +} + +LogicalResult MteL1UbOp::verify() { + if (failed(verifyCopyCbufToUbufLikeOp(*this))) + return failure(); + return verifyDmaLoadStoreLoopGroups( + getOperation(), getLoopCounts(), getLoopSrcStrides(), + getLoopDstStrides()); +} + +LogicalResult MteL1BtOp::verify() { + auto getBufferElementType = [](Type type) -> Type { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; + }; + + if (!isBufferLike(getSource().getType()) || + !isBufferLike(getDestination().getType())) + return emitOpError("requires buffer-like source and destination"); + if (getBufferAddressSpace(getSource().getType()) != pto::AddressSpace::MAT) + return emitOpError("requires MAT source"); + if (getBufferAddressSpace(getDestination().getType()) != pto::AddressSpace::BIAS) + return emitOpError("requires BIAS destination"); + + Type srcElem = getBufferElementType(getSource().getType()); + Type dstElem = getBufferElementType(getDestination().getType()); + const bool isF32 = srcElem.isF32() && dstElem.isF32(); + const bool isI32 = isa(srcElem) && isa(dstElem) && + cast(srcElem).getWidth() == 32 && + cast(dstElem).getWidth() == 32; + const bool isF16ToF32 = srcElem.isF16() && dstElem.isF32(); + const bool isBF16ToF32 = srcElem.isBF16() && dstElem.isF32(); + if (!isF32 && !isI32 && !isF16ToF32 && !isBF16ToF32) { + return emitOpError( + "expects one of f32->f32, i32->i32, f16->f32, or bf16->f32"); + } + return success(); +} + +LogicalResult MteL1FbOp::verify() { + if (!isBufferLike(getSource().getType()) || !isBufferLike(getDestination().getType())) + return emitOpError( + "requires typed !pto.ptr or memref source and destination"); + + auto getAddressSpace = [](Type type) -> std::optional { + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace().getAddressSpace(); + if (auto memrefType = dyn_cast(type)) { + Attribute memorySpace = memrefType.getMemorySpace(); + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace.getAddressSpace(); + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return static_cast(intAttr.getInt()); + } + return std::nullopt; + }; + + std::optional sourceAS = getAddressSpace(getSource().getType()); + std::optional destinationAS = + getAddressSpace(getDestination().getType()); + if (!sourceAS || !destinationAS) + return emitOpError("requires source and destination with PTO address spaces"); + if (*sourceAS != pto::AddressSpace::MAT) + return emitOpError("requires source in mat address space"); + if (*destinationAS != pto::AddressSpace::SCALING) + return emitOpError("requires destination in scaling address space"); + return success(); +} + +LogicalResult MteGmL1FracOp::verify() { + if (failed(verifyCopyGmToUbufOp(*this, true))) + return failure(); + + auto checkNonNegativeConst = [&](Value value, StringRef name) -> LogicalResult { + APInt intValue; + if (matchPattern(value, m_ConstantInt(&intValue)) && intValue.isNegative()) + return emitOpError() << name << " must be non-negative"; + return success(); + }; + if (failed(checkNonNegativeConst(getGroupCount(), "group_count")) || + failed(checkNonNegativeConst(getSrcInnerStride(), "src_inner_stride")) || + failed(checkNonNegativeConst(getDstLoop2Stride(), "dst_loop2_stride")) || + failed(checkNonNegativeConst(getDstLoop3Stride(), "dst_loop3_stride")) || + failed(checkNonNegativeConst(getDstLoop4Stride(), "dst_loop4_stride")) || + (getSrcOuterStride() && + failed(checkNonNegativeConst(getSrcOuterStride(), "src_outer_stride")))) + return failure(); + + APInt groupCount; + if (matchPattern(getGroupCount(), m_ConstantInt(&groupCount)) && + groupCount.isZero()) + return emitOpError("group_count must be greater than zero"); + + APInt smallc0En; + APInt dValue; + if (matchPattern(getSmallc0En(), m_ConstantInt(&smallc0En)) && + smallc0En.getBoolValue() && matchPattern(getDValue(), m_ConstantInt(&dValue)) && + dValue.ugt(4)) + return emitOpError("smallc0_en requires d_value <= 4"); + + return success(); +} + +void MteGmL1Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteL1UbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteL1BtOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteL1FbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteGmL1FracOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +ParseResult MteL0cL1Op::parse(OpAsmParser &parser, OperationState &result) { + Builder builder(parser.getContext()); + StructuredAccStoreAsmState state; + OpAsmParser::UnresolvedOperand source, destination, m, n, srcStride, + dstStride; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parseRequiredOperandWithComma(parser, m) || + parseRequiredOperandWithComma(parser, n) || + parseRequiredOperandWithComma(parser, srcStride) || + parseRequiredOperandWithComma(parser, dstStride) || + parseStructuredAccStoreClauses(parser, state) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, mType, nType, srcStrideType, dstStrideType; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(mType) || parser.parseComma() || parser.parseType(nType) || + parser.parseComma() || parser.parseType(srcStrideType) || + parser.parseComma() || parser.parseType(dstStrideType) || + parseStructuredAccStoreTailTypes(parser, state)) + return failure(); + + setStructuredAccStoreSegmentSizes( + result, {1, 1, 1, 1, 1, 1, !state.preQuantOperands.empty() ? 1 : 0, + !state.preReluOperands.empty() ? 1 : 0, + !state.clipValueOperands.empty() ? 1 : 0, + !state.splitOperands.empty() ? 1 : 0, + !state.loop0SrcStrideOperands.empty() ? 1 : 0, + !state.loop3CountOperands.empty() ? 1 : 0, + !state.loop3SrcStrideOperands.empty() ? 1 : 0, + !state.loop3DstStrideOperands.empty() ? 1 : 0}); + if (state.atomicType || state.atomicOp) { + return parser.emitError(parser.getCurrentLocation(), + "atomic is only supported for mte_l0c_gm"); + } + addStructuredAccStoreAttrs(result, builder, state); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(m, mType, result.operands) || + parser.resolveOperand(n, nType, result.operands) || + parser.resolveOperand(srcStride, srcStrideType, result.operands) || + parser.resolveOperand(dstStride, dstStrideType, result.operands) || + parser.resolveOperands(state.preQuantOperands, state.preQuantTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.preReluOperands, state.preReluTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.clipValueOperands, state.clipValueTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.splitOperands, state.splitTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop0SrcStrideOperands, + state.loop0SrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3CountOperands, state.loop3CountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3SrcStrideOperands, + state.loop3SrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3DstStrideOperands, + state.loop3DstStrideTypes, + parser.getCurrentLocation(), result.operands)) + return failure(); + return success(); +} + +void MteL0cL1Op::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " << getM() + << ", " << getN() << ", " << getSrcStride() << ", " << getDstStride(); + printStructuredAccStoreClauses(printer, getUnitFlag(), getPreQuant(), + getPreQuantMode(), getPreRelu(), + getPreReluMode(), getClipValue(), getMode(), + getSplit(), getLoop0SrcStride(), + getLoop3Count(), getLoop3SrcStride(), + getLoop3DstStride(), getSatMode(), + getAtomicType(), getAtomicOp()); + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", + "mode", + "unit_flag", + "pre_quant_mode", + "pre_relu_mode", + "atomic_type", + "atomic_op", + "sat_mode"}); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getM().getType() << ", " << getN().getType() << ", " + << getSrcStride().getType() << ", " << getDstStride().getType(); + printStructuredAccStoreOptionalTypes( + printer, getPreQuant(), getPreRelu(), getClipValue(), getSplit(), + getLoop0SrcStride(), getLoop3Count(), getLoop3SrcStride(), + getLoop3DstStride()); +} + +LogicalResult MteL0cL1Op::verify() { + if (!isBufferLike(getSource().getType()) || + !isBufferLike(getDestination().getType())) + return emitOpError("requires buffer-like source and destination"); + std::optional sourceSpace = + getBufferAddressSpace(getSource().getType()); + std::optional destinationSpace = + getBufferAddressSpace(getDestination().getType()); + if (sourceSpace != AddressSpace::ACC || destinationSpace != AddressSpace::MAT) { + return emitOpError("requires ACC source and MAT destination"); + } + return verifyStructuredAccStoreLike( + *this, getSource().getType(), getDestination().getType(), getPreQuant(), getPreRelu(), + getClipValue(), getSplit(), getLoop0SrcStride(), getLoop3Count(), + getLoop3SrcStride(), getLoop3DstStride(), getUnitFlag(), + getPreQuantMode(), getPreReluMode(), getMode(), std::nullopt, + std::nullopt, /*allowAtomic=*/false); +} + +LogicalResult MteL1L0aOp::verify() { + return verifyCubeBridgeLoadLikeOp(*this, AddressSpace::LEFT, "LEFT"); +} + +LogicalResult MteL1L0bOp::verify() { + return verifyCubeBridgeLoadLikeOp(*this, AddressSpace::RIGHT, "RIGHT"); +} + +LogicalResult MteL1L0aMxOp::verify() { + return verifyCubeBridgeLoadLikeOp(*this, AddressSpace::LEFT, "LEFT"); +} + +LogicalResult MteL1L0bMxOp::verify() { + return verifyCubeBridgeLoadLikeOp(*this, AddressSpace::RIGHT, "RIGHT"); +} + +void MteL1L0aOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteL1L0bOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteL1L0aMxOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteL1L0bMxOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +void MteL0cL1Op::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +ParseResult MteL0cGmOp::parse(OpAsmParser &parser, OperationState &result) { + Builder builder(parser.getContext()); + StructuredAccStoreAsmState state; + OpAsmParser::UnresolvedOperand source, destination, m, n, srcStride, + dstStride, sid, l2CacheCtrl; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parseRequiredOperandWithComma(parser, m) || + parseRequiredOperandWithComma(parser, n) || + parseRequiredOperandWithComma(parser, srcStride) || + parseRequiredOperandWithComma(parser, dstStride) || + parseRequiredOperandWithComma(parser, sid) || + parseRequiredOperandWithComma(parser, l2CacheCtrl) || + parseStructuredAccStoreClauses(parser, state) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, mType, nType, srcStrideType, dstStrideType, + sidType, l2CacheCtrlType; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(mType) || parser.parseComma() || parser.parseType(nType) || + parser.parseComma() || parser.parseType(srcStrideType) || + parser.parseComma() || parser.parseType(dstStrideType) || + parser.parseComma() || parser.parseType(sidType) || + parser.parseComma() || parser.parseType(l2CacheCtrlType) || + parseStructuredAccStoreTailTypes(parser, state)) + return failure(); + + setStructuredAccStoreSegmentSizes( + result, {1, 1, 1, 1, 1, 1, !state.preQuantOperands.empty() ? 1 : 0, + !state.preReluOperands.empty() ? 1 : 0, + !state.clipValueOperands.empty() ? 1 : 0, 1, 1, + !state.splitOperands.empty() ? 1 : 0, + !state.loop0SrcStrideOperands.empty() ? 1 : 0, + !state.loop3CountOperands.empty() ? 1 : 0, + !state.loop3SrcStrideOperands.empty() ? 1 : 0, + !state.loop3DstStrideOperands.empty() ? 1 : 0}); + addStructuredAccStoreAttrs(result, builder, state); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(m, mType, result.operands) || + parser.resolveOperand(n, nType, result.operands) || + parser.resolveOperand(srcStride, srcStrideType, result.operands) || + parser.resolveOperand(dstStride, dstStrideType, result.operands) || + parser.resolveOperands(state.preQuantOperands, state.preQuantTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.preReluOperands, state.preReluTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.clipValueOperands, state.clipValueTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperand(sid, sidType, result.operands) || + parser.resolveOperand(l2CacheCtrl, l2CacheCtrlType, result.operands) || + parser.resolveOperands(state.splitOperands, state.splitTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop0SrcStrideOperands, + state.loop0SrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3CountOperands, state.loop3CountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3SrcStrideOperands, + state.loop3SrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3DstStrideOperands, + state.loop3DstStrideTypes, + parser.getCurrentLocation(), result.operands)) + return failure(); + return success(); +} + +void MteL0cGmOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " << getM() + << ", " << getN() << ", " << getSrcStride() << ", " + << getDstStride() << ", " << getSid() << ", " << getL2CacheCtrl(); + printStructuredAccStoreClauses(printer, getUnitFlag(), getPreQuant(), + getPreQuantMode(), getPreRelu(), + getPreReluMode(), getClipValue(), getMode(), + getSplit(), getLoop0SrcStride(), + getLoop3Count(), getLoop3SrcStride(), + getLoop3DstStride(), getSatMode(), + getAtomicType(), getAtomicOp()); + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", + "mode", + "unit_flag", + "pre_quant_mode", + "pre_relu_mode", + "atomic_type", + "atomic_op", + "sat_mode"}); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getM().getType() << ", " << getN().getType() << ", " + << getSrcStride().getType() << ", " << getDstStride().getType() + << ", " << getSid().getType() << ", " << getL2CacheCtrl().getType(); + printStructuredAccStoreOptionalTypes( + printer, getPreQuant(), getPreRelu(), getClipValue(), getSplit(), + getLoop0SrcStride(), getLoop3Count(), getLoop3SrcStride(), + getLoop3DstStride()); +} + +LogicalResult MteL0cGmOp::verify() { + if (!isBufferLike(getSource().getType()) || + !isBufferLike(getDestination().getType())) + return emitOpError("requires buffer-like source and destination"); + std::optional sourceSpace = + getBufferAddressSpace(getSource().getType()); + std::optional destinationSpace = + getBufferAddressSpace(getDestination().getType()); + if (sourceSpace != AddressSpace::ACC || destinationSpace != AddressSpace::GM) { + return emitOpError("requires ACC source and GM destination"); + } + return verifyStructuredAccStoreLike( + *this, getSource().getType(), getDestination().getType(), getPreQuant(), getPreRelu(), + getClipValue(), getSplit(), getLoop0SrcStride(), getLoop3Count(), + getLoop3SrcStride(), getLoop3DstStride(), getUnitFlag(), + getPreQuantMode(), getPreReluMode(), getMode(), getAtomicType(), + getAtomicOp(), /*allowAtomic=*/true); +} + +void MteL0cGmOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +ParseResult MteL0cUbOp::parse(OpAsmParser &parser, OperationState &result) { + Builder builder(parser.getContext()); + StructuredAccStoreAsmState state; + OpAsmParser::UnresolvedOperand source, destination, m, n, srcStride, + dstStride, subBlockId; + bool hasSubBlockId = false; + AccStoreUbDstMode dstMode = AccStoreUbDstMode::Single; + if (parseRequiredOperandWithComma(parser, source) || + parseRequiredOperandWithComma(parser, destination) || + parseRequiredOperandWithComma(parser, m) || + parseRequiredOperandWithComma(parser, n) || + parseRequiredOperandWithComma(parser, srcStride) || + parseRequiredOperandWithComma(parser, dstStride)) + return failure(); + if (parser.parseKeyword("dst_mode") || parser.parseLParen()) + return failure(); + OptionalParseResult subBlockIdParse = + parser.parseOptionalOperand(subBlockId); + if (subBlockIdParse.has_value()) { + if (failed(*subBlockIdParse)) + return failure(); + hasSubBlockId = true; + } else { + StringRef dstModeKeyword; + if (parser.parseKeyword(&dstModeKeyword)) + return failure(); + if (dstModeKeyword == "split_m") { + dstMode = AccStoreUbDstMode::SplitM; + } else if (dstModeKeyword == "split_n") { + dstMode = AccStoreUbDstMode::SplitN; + } else { + return parser.emitError( + parser.getCurrentLocation(), + "expected dst_mode(%sub_blockid), dst_mode(split_m), or " + "dst_mode(split_n)"); + } + } + if (parser.parseRParen()) + return failure(); + if (succeeded(parser.parseOptionalComma()) && + parseStructuredAccStoreClauses(parser, state)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + Type sourceType, destinationType, mType, nType, srcStrideType, dstStrideType, + subBlockIdType; + if (parser.parseType(sourceType) || parser.parseComma() || + parser.parseType(destinationType) || parser.parseComma() || + parser.parseType(mType) || parser.parseComma() || parser.parseType(nType) || + parser.parseComma() || parser.parseType(srcStrideType) || + parser.parseComma() || parser.parseType(dstStrideType)) + return failure(); + if (hasSubBlockId && + (parser.parseComma() || parser.parseType(subBlockIdType))) + return failure(); + if (parseStructuredAccStoreTailTypes(parser, state)) + return failure(); + + setStructuredAccStoreSegmentSizes( + result, {1, 1, 1, 1, 1, 1, !state.preQuantOperands.empty() ? 1 : 0, + !state.preReluOperands.empty() ? 1 : 0, + !state.clipValueOperands.empty() ? 1 : 0, + hasSubBlockId ? 1 : 0, + !state.splitOperands.empty() ? 1 : 0, + !state.loop0SrcStrideOperands.empty() ? 1 : 0, + !state.loop3CountOperands.empty() ? 1 : 0, + !state.loop3SrcStrideOperands.empty() ? 1 : 0, + !state.loop3DstStrideOperands.empty() ? 1 : 0}); + if (state.atomicType || state.atomicOp) { + return parser.emitError(parser.getCurrentLocation(), + "atomic is only supported for mte_l0c_gm"); + } + addStructuredAccStoreAttrs(result, builder, state); + result.addAttribute("dst_mode", + AccStoreUbDstModeAttr::get(builder.getContext(), dstMode)); + + if (parser.resolveOperand(source, sourceType, result.operands) || + parser.resolveOperand(destination, destinationType, result.operands) || + parser.resolveOperand(m, mType, result.operands) || + parser.resolveOperand(n, nType, result.operands) || + parser.resolveOperand(srcStride, srcStrideType, result.operands) || + parser.resolveOperand(dstStride, dstStrideType, result.operands) || + parser.resolveOperands(state.preQuantOperands, state.preQuantTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.preReluOperands, state.preReluTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.clipValueOperands, state.clipValueTypes, + parser.getCurrentLocation(), result.operands) || + (hasSubBlockId && + parser.resolveOperand(subBlockId, subBlockIdType, result.operands)) || + parser.resolveOperands(state.splitOperands, state.splitTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop0SrcStrideOperands, + state.loop0SrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3CountOperands, state.loop3CountTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3SrcStrideOperands, + state.loop3SrcStrideTypes, + parser.getCurrentLocation(), result.operands) || + parser.resolveOperands(state.loop3DstStrideOperands, + state.loop3DstStrideTypes, + parser.getCurrentLocation(), result.operands)) + return failure(); + return success(); +} + +void MteL0cUbOp::print(OpAsmPrinter &printer) { + printer << " " << getSource() << ", " << getDestination() << ", " << getM() + << ", " << getN() << ", " << getSrcStride() << ", " + << getDstStride() << ", dst_mode("; + switch (getDstMode()) { + case AccStoreUbDstMode::Single: + printer << getSubBlockid(); + break; + case AccStoreUbDstMode::SplitM: + printer << "split_m"; + break; + case AccStoreUbDstMode::SplitN: + printer << "split_n"; + break; + } + printer << ")"; + printStructuredAccStoreClauses(printer, getUnitFlag(), getPreQuant(), + getPreQuantMode(), getPreRelu(), + getPreReluMode(), getClipValue(), getMode(), + getSplit(), getLoop0SrcStride(), + getLoop3Count(), getLoop3SrcStride(), + getLoop3DstStride(), getSatMode(), + std::nullopt, std::nullopt); + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"operandSegmentSizes", + "mode", + "unit_flag", + "pre_quant_mode", + "pre_relu_mode", + "dst_mode", + "sat_mode"}); + printer << " : " << getSource().getType() << ", " << getDestination().getType() + << ", " << getM().getType() << ", " << getN().getType() << ", " + << getSrcStride().getType() << ", " << getDstStride().getType(); + if (getSubBlockid()) + printer << ", " << getSubBlockid().getType(); + printStructuredAccStoreOptionalTypes( + printer, getPreQuant(), getPreRelu(), getClipValue(), getSplit(), + getLoop0SrcStride(), getLoop3Count(), getLoop3SrcStride(), + getLoop3DstStride()); +} + +LogicalResult MteL0cUbOp::verify() { + if (!isBufferLike(getSource().getType()) || + !isBufferLike(getDestination().getType())) + return emitOpError("requires buffer-like source and destination"); + std::optional sourceSpace = + getBufferAddressSpace(getSource().getType()); + std::optional destinationSpace = + getBufferAddressSpace(getDestination().getType()); + if (sourceSpace != AddressSpace::ACC || destinationSpace != AddressSpace::VEC) { + return emitOpError("requires ACC source and UB destination"); + } + if (failed(verifyStructuredAccStoreLike( + *this, getSource().getType(), getDestination().getType(), getPreQuant(), getPreRelu(), + getClipValue(), getSplit(), getLoop0SrcStride(), getLoop3Count(), + getLoop3SrcStride(), getLoop3DstStride(), getUnitFlag(), + getPreQuantMode(), getPreReluMode(), getMode(), std::nullopt, + std::nullopt, /*allowAtomic=*/false))) + return failure(); + + if (getDstMode() == AccStoreUbDstMode::Single) { + if (!getSubBlockid()) + return emitOpError("dst_mode(%sub_blockid) requires a sub_blockid operand"); + APInt subBlockId; + if (matchPattern(getSubBlockid(), m_ConstantInt(&subBlockId)) && + subBlockId.ugt(1)) + return emitOpError("sub_blockid must be 0 or 1"); + return success(); + } + if (getSubBlockid()) + return emitOpError("split destination modes do not accept sub_blockid"); + + if (getPreQuant() || getPreRelu() || getClipValue() || getPreQuantMode() || + getPreReluMode() || getSplit() || getLoop0SrcStride() || + getLoop3Count() || getLoop3SrcStride() || getLoop3DstStride()) { + return emitOpError("dual destination mode cannot be combined with " + "pre_quant, pre_relu, clip, nz2dn, nz2nz, or loop3"); + } + if (getMode() && *getMode() != AccStoreMode::Nz2nd) + return emitOpError("dual destination mode requires normal or nz2nd layout"); + + APInt mValue; + APInt nValue; + if (getDstMode() == AccStoreUbDstMode::SplitM && + matchPattern(getM(), m_ConstantInt(&mValue)) && + mValue.getZExtValue() % 2 != 0) + return emitOpError("split-M dual destination requires m to be even"); + if (getDstMode() == AccStoreUbDstMode::SplitN && + matchPattern(getN(), m_ConstantInt(&nValue)) && + nValue.getZExtValue() % 32 != 0) + return emitOpError("split-N dual destination requires n to be a multiple of 32"); + return success(); +} + +void MteL0cUbOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index b8194d0a4..32d7cf16d 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -12,11 +12,25 @@ # See LICENSE in the root of the software repository for the full text of the License. add_mlir_dialect_library(PTOTransforms + VPTOLLVMEmitter.cpp + VPTOLLVMEmitterHelper.cpp + VPTOPtrNormalize.cpp + VPTOPtrCastCleanup.cpp + VPTOExpandWrapperOps.cpp + PTOVPTOPtrBoundary.cpp + VPTOBufferMaterialization.cpp + PTOValidateVPTOIR.cpp + PTOInferVPTOVecScope.cpp + InsertSync/PTOInsertSync.cpp PTOInjectBarrierAllSync.cpp InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp PTOValidateIntToPtrUses.cpp + ExpandTileOp.cpp + FoldTileBufIntrinsics.cpp + PTOLowerToOpLibCalls.cpp + PTOInstantiateAndInlineOpLib.cpp PTOToEmitC.cpp Utils.cpp OptMemPlanForPipeline.cpp @@ -34,6 +48,8 @@ add_mlir_dialect_library(PTOTransforms PTOInferValidatePipeInitPass.cpp PTOResolveReservedBuffersPass.cpp PTOWrapFunctionsInSectionsPass.cpp + VPTONormalizeContainer.cpp + VPTOSplitCVModule.cpp InsertSync/PTOIRTranslator.cpp InsertSync/SyncCommon.cpp InsertSync/InsertSyncAnalysis.cpp @@ -61,6 +77,9 @@ add_mlir_dialect_library(PTOTransforms PTOPassesIncGen PTOOpsIncGen + LINK_COMPONENTS + Analysis + LINK_LIBS PUBLIC PTOIR MLIRIR @@ -74,8 +93,15 @@ add_mlir_dialect_library(PTOTransforms MLIRTransformUtils MLIRTransforms MLIRTensorDialect + MLIRSCFDialect + MLIRVectorDialect + MLIRParser MLIRSCFToEmitC MLIRSCFDialect + MLIRSCFToControlFlow + MLIRConvertToLLVMPass + MLIRTargetLLVMIRExport + MLIRToLLVMIRTranslationRegistration ) install(TARGETS PTOTransforms diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp new file mode 100644 index 000000000..046a3faf9 --- /dev/null +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -0,0 +1,997 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- ExpandTileOp.cpp ---------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// Expand tile-level ops (pto.tadd, pto.tsub, ...) by invoking the TileLang +// Python DSL to instantiate template libraries. +// +// The generated template functions use tile_buf parameters. After this pass, +// the Inline pass inlines the template body, and FoldTileBufIntrinsics +// resolves tile_buf_addr / tile_valid_rows / tile_valid_cols. +// +// Workflow per tile op: +// 1. Extract SpecKey from ALL operands' tile_buf types. +// 2. Invoke Python DSL helper to generate a specialized MLIR function +// (with tile_buf parameters). +// 3. Parse the generated MLIR and clone the function into the module. +// 4. Replace the original tile op with func.call, passing tile_buf +// operands directly (no type bridging needed). +// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Parser/Parser.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +extern "C" { +extern char **environ; +} + +using namespace mlir; + +namespace mlir { +namespace pto { + namespace func = ::mlir::func; + + #define GEN_PASS_DEF_EXPANDTILEOP + #include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +namespace { + +// ============================================================================ +// OperandTypeInfo: describes one operand for template specialization. +// +// Four kinds of operands: +// Tile — from TileBufType. dtype + shape + memorySpace + config +// all participate in the specialization key (SpecKey). +// View — from MemRefType (lowered PartitionTensorViewType). Only dtype +// participates in SpecKey — the template is fully dynamic so +// shape/strides/memorySpace don't affect code generation. They are +// carried here solely for JSON serialization to the Python DSL for +// constraint checking. +// Vector — from builtin VectorType. The element dtype and vector shape +// participate in SpecKey so helper-side schema filtering can +// distinguish auxiliary vector operands such as tmrgsort's +// `excuted : vector<4xi16>`. +// Scalar — from a scalar element type. Only dtype participates in SpecKey. +// ============================================================================ +enum class OperandKind { Tile, View, Vector, Scalar }; + +struct OperandTypeInfo { + OperandKind kind = OperandKind::Tile; + std::string dtype; // all kinds: element type string (e.g. "f32") + + // --- Tile-only (TileBufType) --- + SmallVector tileShape; + SmallVector tileValidShape; + std::string tileMemorySpace; // e.g. "ub", "gm", "mat", "left", "right", "acc", "bias" + int32_t blayout = 0; + int32_t slayout = 0; + int32_t fractal = 0; + uint64_t pad = 0; + + // --- View-only (MemRefType) — for JSON / constraint checking only --- + SmallVector viewShape; + SmallVector viewStrides; + std::string viewMemorySpace; // "gm" or "ub" + + // --- Vector-only (builtin VectorType) --- + SmallVector vectorShape; + + /// Equality for SpecKey caching — only compares fields relevant to each kind. + bool operator==(const OperandTypeInfo &rhs) const { + if (kind != rhs.kind || dtype != rhs.dtype) + return false; + if (kind == OperandKind::Tile) + return tileShape == rhs.tileShape && + tileValidShape == rhs.tileValidShape && + tileMemorySpace == rhs.tileMemorySpace && + blayout == rhs.blayout && slayout == rhs.slayout && + fractal == rhs.fractal && pad == rhs.pad; + if (kind == OperandKind::Vector) + return vectorShape == rhs.vectorShape; + // View and Scalar: dtype alone is sufficient for template caching. + return true; + } +}; + +// ============================================================================ +// SpecKey: identifies a specialized template instance using ALL operands. +// ============================================================================ +struct SpecKey { + std::string opName; + std::string targetArch; + SmallVector operands; + SmallVector, 4> contextAttrs; + + bool operator==(const SpecKey &rhs) const { + return opName == rhs.opName && targetArch == rhs.targetArch && + operands == rhs.operands && contextAttrs == rhs.contextAttrs; + } +}; + +struct SpecKeyInfo : public llvm::DenseMapInfo { + static inline SpecKey getEmptyKey() { return {"", "", {}}; } + static inline SpecKey getTombstoneKey() { return {"__tombstone__", "", {}}; } + static unsigned getHashValue(const SpecKey &key) { + unsigned h = llvm::hash_combine(key.opName, key.targetArch); + for (const auto &op : key.operands) { + h = llvm::hash_combine(h, static_cast(op.kind), op.dtype); + if (op.kind == OperandKind::Tile) { + h = llvm::hash_combine(h, op.tileMemorySpace, op.blayout, + op.slayout, op.fractal, op.pad); + for (int64_t d : op.tileShape) + h = llvm::hash_combine(h, d); + for (int64_t d : op.tileValidShape) + h = llvm::hash_combine(h, d); + } else if (op.kind == OperandKind::Vector) { + for (int64_t d : op.vectorShape) + h = llvm::hash_combine(h, d); + } + // View/Vector/Scalar: only kind + dtype contribute to hash. + } + for (const auto &[attrName, attrValue] : key.contextAttrs) + h = llvm::hash_combine(h, attrName, attrValue); + return h; + } + static bool isEqual(const SpecKey &lhs, const SpecKey &rhs) { + return lhs == rhs; + } +}; + +// ============================================================================ +// Helpers +// ============================================================================ +static std::string getDtypeString(Type elemTy) { + if (elemTy.isIndex()) return "i32"; + if (elemTy.isInteger(1)) return "i1"; + if (elemTy.isF32()) return "f32"; + if (elemTy.isF16()) return "f16"; + if (elemTy.isBF16()) return "bf16"; + if (elemTy.isUnsignedInteger(64)) return "ui64"; + if (elemTy.isUnsignedInteger(32)) return "ui32"; + if (elemTy.isUnsignedInteger(16)) return "ui16"; + if (elemTy.isUnsignedInteger(8)) return "ui8"; + if (elemTy.isSignedInteger(64)) return "si64"; + if (elemTy.isSignedInteger(32)) return "si32"; + if (elemTy.isSignedInteger(16)) return "si16"; + if (elemTy.isSignedInteger(8)) return "si8"; + if (elemTy.isSignlessInteger(64)) return "i64"; + if (elemTy.isSignlessInteger(32)) return "i32"; + if (elemTy.isSignlessInteger(16)) return "i16"; + if (elemTy.isSignlessInteger(8)) return "i8"; + return ""; +} + +// Cast `operand` to `dstTy`, preferring semantically precise ops over the +// generic unrealized cast so later lowering passes don't get stuck. +static Value bridgeOperandToType(OpBuilder &builder, Location loc, + Value operand, Type dstTy) { + Type srcTy = operand.getType(); + if (srcTy == dstTy) + return operand; + if (srcTy.isIndex() && isa(dstTy)) + return builder.create(loc, dstTy, operand); + return builder.create(loc, dstTy, operand) + .getResult(0); +} + +static StringRef getTileOpName(Operation *op) { + return op->getName().stripDialect(); +} + +static std::string getTargetArchString(ModuleOp mod) { + if (!mod) + return ""; + auto targetAttr = mod->getAttrOfType("pto.target_arch"); + if (!targetAttr) + return ""; + return targetAttr.getValue().str(); +} + +static std::string stringifyMemorySpace(pto::AddressSpace space) { + switch (space) { + case pto::AddressSpace::GM: + return "gm"; + case pto::AddressSpace::MAT: + return "mat"; + case pto::AddressSpace::LEFT: + return "left"; + case pto::AddressSpace::RIGHT: + return "right"; + case pto::AddressSpace::ACC: + return "acc"; + case pto::AddressSpace::BIAS: + return "bias"; + case pto::AddressSpace::VEC: + case pto::AddressSpace::SCALING: + case pto::AddressSpace::Zero: + return "ub"; + } + return "ub"; +} + +static std::string getMemorySpaceString(pto::TileBufType tbTy) { + auto msAttr = dyn_cast_or_null(tbTy.getMemorySpace()); + return msAttr ? stringifyMemorySpace(msAttr.getAddressSpace()) : "ub"; +} + +static std::string getMemorySpaceString(MemRefType mrTy) { + auto msAttr = dyn_cast_or_null(mrTy.getMemorySpace()); + return msAttr ? stringifyMemorySpace(msAttr.getAddressSpace()) : "gm"; +} + +static std::string getBLayoutString(int32_t blayout) { + if (blayout == static_cast(pto::BLayout::ColMajor)) + return "col_major"; + return "row_major"; +} + +static std::string getSLayoutString(int32_t slayout) { + if (slayout == static_cast(pto::SLayout::RowMajor)) + return "row_major"; + if (slayout == static_cast(pto::SLayout::ColMajor)) + return "col_major"; + return "none_box"; +} + +static std::optional getTCvtRoundModeString(pto::TCvtOp op) { + switch (op.getRmode()) { + case pto::RoundMode::NONE: + case pto::RoundMode::RINT: + case pto::RoundMode::CAST_RINT: + return "RINT"; + case pto::RoundMode::ROUND: + return "ROUND"; + case pto::RoundMode::FLOOR: + return "FLOOR"; + case pto::RoundMode::CEIL: + return "CEIL"; + case pto::RoundMode::TRUNC: + return "TRUNC"; + case pto::RoundMode::ODD: + return "ODD"; + } + return std::nullopt; +} + +static StringRef getPrecisionModeString(pto::PrecisionMode mode) { + switch (mode) { + case pto::PrecisionMode::DEFAULT: + return "DEFAULT"; + case pto::PrecisionMode::HIGH_PRECISION: + return "HIGH_PRECISION"; + } + llvm_unreachable("unknown PrecisionMode"); +} + +// MUST stay in sync with template behavior. Adding an op here without a real +// HIGH_PRECISION code path would silence the warning while preserving DEFAULT +// behavior. +static const llvm::StringSet<> &highPrecisionImplementedOps() { + static const llvm::StringSet<> kImplementedOps{ + "pto.tlog", + "pto.tdiv", + "pto.tdivs", + "pto.trecip", + "pto.trowexpanddiv", + "pto.tcolexpanddiv", + }; + return kImplementedOps; +} + +template +static bool tryAppendPrecisionMode( + Operation *op, + SmallVectorImpl> &attrs) { + auto typed = dyn_cast(op); + if (!typed) + return false; + + pto::PrecisionMode mode = typed.getPrecisionMode(); + attrs.emplace_back("precision_mode", getPrecisionModeString(mode).str()); + + if (mode == pto::PrecisionMode::HIGH_PRECISION && + !highPrecisionImplementedOps().contains(op->getName().getStringRef())) { + StringRef opName = op->getName().getStringRef(); + llvm::errs() << "warning: '" << opName << "' op " << opName + << ": precision_mode = HIGH_PRECISION requested but not yet " + "implemented; falling back to DEFAULT behavior\n"; + } + return true; +} + +static std::string getTRandomRoundsString(pto::TRandomOp op) { + return std::to_string(op.getRounds()); +} + +static void appendOpContextAttrs( + Operation *op, + SmallVectorImpl> &attrs) { + if (auto tcvt = dyn_cast(op)) { + std::optional roundMode = getTCvtRoundModeString(tcvt); + if (roundMode) + attrs.emplace_back("round_mode", *roundMode); + } + if (auto trandom = dyn_cast(op)) + attrs.emplace_back("rounds", getTRandomRoundsString(trandom)); + if (auto tcmp = dyn_cast(op)) { + if (auto cmpModeAttr = tcmp.getCmpModeAttr()) { + attrs.emplace_back("cmp_mode", + stringifyCmpMode(cmpModeAttr.getValue()).str()); + } + } + if (auto tcmps = dyn_cast(op)) { + if (auto cmpModeAttr = tcmps.getCmpModeAttr()) { + attrs.emplace_back("cmp_mode", + stringifyCmpMode(cmpModeAttr.getValue()).str()); + } + } + (void)(tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs) || + tryAppendPrecisionMode(op, attrs)); +} + +static bool getStaticIntFromValue(Value value, int64_t &out) { + if (auto cOp = value.getDefiningOp()) { + out = cOp.value(); + return true; + } + if (auto cInt = value.getDefiningOp()) { + out = cInt.value(); + return true; + } + return false; +} + +static int64_t getStaticIntOrDynamic(OpFoldResult ofr) { + if (auto attr = ofr.dyn_cast()) { + if (auto intAttr = dyn_cast(attr)) + return intAttr.getInt(); + return ShapedType::kDynamic; + } + auto value = llvm::cast(ofr); + int64_t result = ShapedType::kDynamic; + if (getStaticIntFromValue(value, result)) + return result; + return ShapedType::kDynamic; +} + +static void recordStaticSizes(ArrayRef inputs, + SmallVectorImpl &out) { + out.clear(); + out.reserve(inputs.size()); + for (OpFoldResult ofr : inputs) + out.push_back(getStaticIntOrDynamic(ofr)); +} + +static SmallVector combineSubviewStrides(ArrayRef baseStrides, + ArrayRef steps) { + SmallVector result; + result.reserve(baseStrides.size()); + for (auto [baseStride, step] : llvm::zip(baseStrides, steps)) { + int64_t stepValue = getStaticIntOrDynamic(step); + if (baseStride == ShapedType::kDynamic || + stepValue == ShapedType::kDynamic) { + result.push_back(ShapedType::kDynamic); + continue; + } + result.push_back(baseStride * stepValue); + } + return result; +} + +static void populateViewShapeAndStrides(Value value, + SmallVectorImpl &shape, + SmallVectorImpl &strides) { + if (!value) + return; + + if (auto subview = value.getDefiningOp()) { + populateViewShapeAndStrides(subview.getSource(), shape, strides); + SmallVector subviewShape; + recordStaticSizes(subview.getMixedSizes(), subviewShape); + if (!subviewShape.empty()) + shape = subviewShape; + if (!strides.empty()) + strides = combineSubviewStrides(strides, subview.getMixedStrides()); + return; + } + + if (auto reinterpret = value.getDefiningOp()) { + if (shape.empty()) { + SmallVector reinterpretShape; + recordStaticSizes(reinterpret.getMixedSizes(), reinterpretShape); + if (!reinterpretShape.empty()) + shape = reinterpretShape; + } + if (strides.empty()) + recordStaticSizes(reinterpret.getMixedStrides(), strides); + return; + } + + if (auto cast = value.getDefiningOp()) { + populateViewShapeAndStrides(cast.getSource(), shape, strides); + return; + } + + if (auto memrefTy = dyn_cast(value.getType())) { + if (shape.empty()) + shape.assign(memrefTy.getShape().begin(), memrefTy.getShape().end()); + if (strides.empty()) { + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(memrefTy, strides, offset))) { + // strides populated — dynamic dims remain ShapedType::kDynamic. + } + } + } +} + +static std::optional buildOperandTypeInfo(Value value) { + Type ty = value.getType(); + // Tile operand — from TileBufType. + if (auto tbTy = dyn_cast(ty)) { + OperandTypeInfo info; + info.kind = OperandKind::Tile; + info.dtype = getDtypeString(tbTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.tileShape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); + auto validShape = tbTy.getValidShape(); + if (validShape.empty()) + info.tileValidShape.assign(tbTy.getShape().begin(), tbTy.getShape().end()); + else + info.tileValidShape.assign(validShape.begin(), validShape.end()); + info.tileMemorySpace = getMemorySpaceString(tbTy); + if (auto config = tbTy.getConfigAttr()) { + info.blayout = static_cast(config.getBLayout().getValue()); + info.slayout = static_cast(config.getSLayout().getValue()); + info.fractal = config.getSFractalSize() + ? static_cast(config.getSFractalSize().getInt()) + : 0; + info.pad = static_cast(config.getPad().getValue()); + } + return info; + } + + // View operand — from MemRefType (lowered PartitionTensorViewType). + if (auto mrTy = dyn_cast(ty)) { + OperandTypeInfo info; + info.kind = OperandKind::View; + info.dtype = getDtypeString(mrTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.viewMemorySpace = getMemorySpaceString(mrTy); + populateViewShapeAndStrides(value, info.viewShape, info.viewStrides); + if (info.viewShape.empty()) + info.viewShape.assign(mrTy.getShape().begin(), mrTy.getShape().end()); + if (info.viewStrides.empty()) { + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(mrTy, info.viewStrides, offset))) { + // strides populated — dynamic dims remain ShapedType::kDynamic. + } + } + return info; + } + + // Auxiliary vector operand — from builtin VectorType (e.g. vector<4xi16>). + if (auto vecTy = dyn_cast(ty)) { + OperandTypeInfo info; + info.kind = OperandKind::Vector; + info.dtype = getDtypeString(vecTy.getElementType()); + if (info.dtype.empty()) + return std::nullopt; + info.vectorShape.assign(vecTy.getShape().begin(), vecTy.getShape().end()); + return info; + } + + // Scalar operand — from a scalar element type. + OperandTypeInfo info; + info.kind = OperandKind::Scalar; + info.dtype = getDtypeString(ty); + if (info.dtype.empty()) + return std::nullopt; + return info; +} + +static std::optional buildSpecKey(Operation *op) { + SpecKey key; + key.opName = getTileOpName(op).str(); + key.targetArch = getTargetArchString(op->getParentOfType()); + + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto info = buildOperandTypeInfo(op->getOperand(i)); + if (!info) + return std::nullopt; + key.operands.push_back(*info); + } + if (key.operands.empty()) + return std::nullopt; + + appendOpContextAttrs(op, key.contextAttrs); + return key; +} + +// ============================================================================ +// ExpandState: runtime state for a single pass invocation. +// ============================================================================ +struct ExpandState { + std::vector> parsedModules; + llvm::DenseMap specCache; + + std::string tilelangPath; + std::string tilelangPkgPath; + std::string pythonExe; + + func::FuncOp invokeTilelangDSL(const SpecKey &key, Operation *tileOp, + ModuleOp mod, MLIRContext *ctx); + + LogicalResult expandTileOpsInFunction(func::FuncOp func, ModuleOp mod, + MLIRContext *ctx); +}; + +// ============================================================================ +// The Pass +// ============================================================================ +struct ExpandTileOpPass + : public mlir::pto::impl::ExpandTileOpBase { + using ExpandTileOpBase::ExpandTileOpBase; + + void runOnOperation() override; +}; + +/// Serialize a JSON array of integers. +static void appendJsonIntArray(std::string &json, ArrayRef arr) { + json += "["; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) + json += ","; + json += std::to_string(arr[i]); + } + json += "]"; +} + +/// Serialize a JSON array where dynamic dimensions become `null`. +static void appendJsonDimArray(std::string &json, ArrayRef arr, + bool negativeIsDynamic = false) { + json += "["; + for (size_t i = 0; i < arr.size(); ++i) { + if (i > 0) + json += ","; + int64_t dim = arr[i]; + if (ShapedType::isDynamic(dim) || (negativeIsDynamic && dim < 0)) { + json += "null"; + continue; + } + json += std::to_string(dim); + } + json += "]"; +} + +static std::string buildOperandSpecsJson(const SpecKey &key) { + std::string json = "["; + for (size_t i = 0; i < key.operands.size(); ++i) { + const auto &op = key.operands[i]; + if (i > 0) + json += ","; + + if (op.kind == OperandKind::Tile) { + json += "{\"kind\":\"tile\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; + appendJsonIntArray(json, op.tileShape); + json += ",\"valid_shape\":"; + appendJsonDimArray(json, op.tileValidShape, /*negativeIsDynamic=*/true); + json += ",\"memory_space\":\""; + json += op.tileMemorySpace; + json += "\",\"config\":{"; + json += "\"b_layout\":\""; + json += getBLayoutString(op.blayout); + json += "\",\"s_layout\":\""; + json += getSLayoutString(op.slayout); + json += "\",\"s_fractal_size\":"; + json += std::to_string(op.fractal); + json += ",\"pad_value\":\"0x"; + json += llvm::utohexstr(op.pad, /*LowerCase=*/false); + json += "\"}}"; + continue; + } + + if (op.kind == OperandKind::View) { + json += "{\"kind\":\"view\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; + appendJsonDimArray(json, op.viewShape); + if (!op.viewStrides.empty()) { + json += ",\"strides\":["; + for (size_t dim = 0; dim < op.viewStrides.size(); ++dim) { + if (dim > 0) + json += ","; + if (ShapedType::isDynamic(op.viewStrides[dim])) + json += "null"; + else + json += std::to_string(op.viewStrides[dim]); + } + json += "]"; + } + json += ",\"memory_space\":\"" + op.viewMemorySpace + "\"}"; + continue; + } + + if (op.kind == OperandKind::Vector) { + json += "{\"kind\":\"vector\",\"dtype\":\"" + op.dtype + "\",\"shape\":"; + appendJsonIntArray(json, op.vectorShape); + json += "}"; + continue; + } + + // Scalar + json += "{\"kind\":\"scalar\",\"dtype\":\"" + op.dtype + "\"}"; + } + json += "]"; + return json; +} + +static std::string buildContextAttrsJson(const SpecKey &key) { + std::string json = "{"; + for (size_t i = 0; i < key.contextAttrs.size(); ++i) { + const auto &[attrName, attrValue] = key.contextAttrs[i]; + if (i > 0) + json += ","; + json += "\""; + json += attrName; + json += "\":\""; + json += attrValue; + json += "\""; + } + json += "}"; + return json; +} + +// ============================================================================ +// Invoke Python DSL helper to generate a specialized template function. +// ============================================================================ +func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, + Operation *tileOp, + ModuleOp mod, MLIRContext *ctx) { + // Check cache first. + auto cacheIt = specCache.find(key); + if (cacheIt != specCache.end()) + return cacheIt->second; + + // 1. Locate the Python executable. + auto pythonPath = llvm::sys::findProgramByName(pythonExe); + if (!pythonPath) { + llvm::errs() << "ExpandTileOp: cannot find '" << pythonExe << "'\n"; + return nullptr; + } + + // 2. Build operand schema JSON for mixed tile/scalar specialization. + std::string operandSpecsJson = buildOperandSpecsJson(key); + std::string contextAttrsJson = buildContextAttrsJson(key); + if (key.targetArch.empty()) { + llvm::errs() << "ExpandTileOp: missing pto.target_arch module attribute\n"; + return nullptr; + } + + // 3. Create temp file for stdout redirect. + SmallString<128> tmpPath; + int tmpFD; + if (auto ec = llvm::sys::fs::createTemporaryFile("tilelang_expand", "mlir", + tmpFD, tmpPath)) { + llvm::errs() << "ExpandTileOp: cannot create temp file: " + << ec.message() << "\n"; + return nullptr; + } + ::close(tmpFD); + + // 4. Build command args. + std::string opName = "pto." + key.opName; + SmallVector args = { + *pythonPath, "-m", "tilelang_dsl.expand_helper", + "--template-dir", tilelangPath, + "--target", key.targetArch, + "--op", opName, + "--operand-specs", operandSpecsJson, + }; + if (!key.contextAttrs.empty()) { + args.push_back("--context-attrs"); + args.push_back(contextAttrsJson); + } + + // 5. Set up environment with PYTHONPATH. + std::optional redirects[] = {std::nullopt, StringRef(tmpPath), + std::nullopt}; + + SmallVector envp; + std::string pythonPathEnv; + std::vector envStorage; + bool hasPythonPath = !tilelangPkgPath.empty(); + if (hasPythonPath) { + const char *existingPath = ::getenv("PYTHONPATH"); + pythonPathEnv = "PYTHONPATH=" + tilelangPkgPath; + if (existingPath && existingPath[0] != '\0') { + pythonPathEnv += ":"; + pythonPathEnv += existingPath; + } + for (char **e = environ; *e; ++e) { + StringRef entry(*e); + if (entry.starts_with("PYTHONPATH=")) + continue; + envStorage.push_back(std::string(entry)); + } + envStorage.push_back(pythonPathEnv); + for (auto &s : envStorage) + envp.push_back(s); + } + + // 6. Execute. + std::string errMsg; + int rc = llvm::sys::ExecuteAndWait( + *pythonPath, args, + hasPythonPath ? std::optional>(envp) : std::nullopt, + redirects, /*secondsToWait=*/30, /*memoryLimit=*/0, &errMsg); + + if (rc != 0) { + std::string cmd; + llvm::raw_string_ostream os(cmd); + bool first = true; + auto appendToken = [&](StringRef token) { + if (!first) + os << ' '; + first = false; + llvm::sys::printArg(os, token, /*Quote=*/true); + }; + if (hasPythonPath) { + appendToken("env"); + appendToken(pythonPathEnv); + } + for (StringRef arg : args) + appendToken(arg); + os.flush(); + + llvm::errs() << "ExpandTileOp: tilelang DSL helper failed (rc=" << rc + << "): " << errMsg << "\n"; + llvm::errs() << "ExpandTileOp: run: " << cmd << "\n"; + llvm::sys::fs::remove(tmpPath); + return nullptr; + } + + // 7. Read the generated MLIR. + auto bufOrErr = llvm::MemoryBuffer::getFile(tmpPath); + llvm::sys::fs::remove(tmpPath); + if (!bufOrErr) { + llvm::errs() << "ExpandTileOp: cannot read DSL output\n"; + return nullptr; + } + StringRef mlirText = (*bufOrErr)->getBuffer(); + if (mlirText.empty()) { + llvm::errs() << "ExpandTileOp: empty DSL output\n"; + return nullptr; + } + + // 8. Parse the MLIR text. + auto parsedMod = parseSourceString(mlirText, ctx); + if (!parsedMod) { + llvm::errs() << "ExpandTileOp: failed to parse DSL output\n"; + return nullptr; + } + + // 9. Clone the generated function set into the target module. The TileLang + // output may include private inline helper funcs referenced by the entry. + SmallVector parsedFuncs; + for (auto fn : parsedMod->getOps()) + parsedFuncs.push_back(fn); + if (parsedFuncs.empty()) { + llvm::errs() << "ExpandTileOp: no func.func in DSL output\n"; + return nullptr; + } + OpBuilder builder(ctx); + builder.setInsertionPointToEnd(mod.getBody()); + SmallVector clonedFuncs; + llvm::StringMap renamedSymbols; + + // Build a unique name from the spec-key-relevant operand fields. + std::string uniqueName = "__pto_tilelang_" + key.targetArch + "_" + key.opName; + for (const auto &op : key.operands) { + uniqueName += op.kind == OperandKind::Tile ? "_tile" + : op.kind == OperandKind::View ? "_view" + : op.kind == OperandKind::Vector ? "_vector" + : "_scalar"; + uniqueName += "_" + op.dtype; + if (op.kind == OperandKind::Tile) { + for (int64_t d : op.tileShape) + uniqueName += "_" + std::to_string(d); + for (int64_t d : op.tileValidShape) + uniqueName += "_v" + std::to_string(d); + uniqueName += "_bl" + std::to_string(op.blayout); + uniqueName += "_sl" + std::to_string(op.slayout); + uniqueName += "_fr" + std::to_string(op.fractal); + uniqueName += "_pd" + llvm::utohexstr(op.pad, /*LowerCase=*/false); + } else if (op.kind == OperandKind::Vector) { + for (int64_t d : op.vectorShape) + uniqueName += "_" + std::to_string(d); + } + } + for (const auto &[attrName, attrValue] : key.contextAttrs) + uniqueName += "_ctx_" + attrName + "_" + attrValue; + + for (auto [index, fn] : llvm::enumerate(parsedFuncs)) { + IRMapping mapping; + auto cloned = cast(builder.clone(*fn, mapping)); + std::string newName; + if (index == 0) { + newName = uniqueName; + cloned.setVisibility(SymbolTable::Visibility::Private); + } else { + newName = uniqueName + "__" + std::string(fn.getSymName()); + } + renamedSymbols[fn.getSymName()] = newName; + cloned.setName(newName); + clonedFuncs.push_back(cloned); + } + + for (func::FuncOp fn : clonedFuncs) { + fn.walk([&](func::CallOp call) { + StringRef callee = call.getCallee(); + if (callee.empty()) + return; + auto renameIt = renamedSymbols.find(callee); + if (renameIt == renamedSymbols.end()) + return; + call.setCallee(renameIt->second); + }); + } + + auto cloned = clonedFuncs.front(); + // The pto.tilelang.instance attribute should already be set by the + // TileLang DSL frontend in the generated MLIR. Verify it exists. + if (!cloned->hasAttr("pto.tilelang.instance")) { + llvm::errs() << "ExpandTileOp: warning: DSL output function @" + << cloned.getSymName() + << " missing pto.tilelang.instance attribute\n"; + } + + // Keep the parsed module alive. + parsedModules.push_back(std::move(parsedMod)); + + specCache[key] = cloned; + return cloned; +} + +// ============================================================================ +// Expand tile ops in a single function. +// ============================================================================ +LogicalResult ExpandState::expandTileOpsInFunction(func::FuncOp func, + ModuleOp mod, + MLIRContext *ctx) { + OpBuilder builder(ctx); + + // Collect tile ops first (avoid modifying while iterating). + SmallVector tileOps; + func.walk([&](Operation *op) { + if (isa(op)) + tileOps.push_back(op); + }); + + for (auto *op : tileOps) { + auto specKeyOpt = buildSpecKey(op); + if (!specKeyOpt) { + op->emitError( + "ExpandTileOp: cannot build specialization key for this operand schema"); + return failure(); + } + + // Invoke tilelang DSL (with caching). + func::FuncOp dslFn = invokeTilelangDSL(*specKeyOpt, op, mod, ctx); + if (!dslFn) { + StringRef opName = getTileOpName(op); + op->emitError("ExpandTileOp: failed to instantiate tilelang template for " + + opName); + return failure(); + } + + // Replace tile op with func.call. For view operands whose caller type + // (memref) differs from the template parameter type (tensor_view / + // partition_tensor_view), insert an unrealized_conversion_cast bridge. + // FoldTileBufIntrinsics will later resolve these casts. + builder.setInsertionPoint(op); + SmallVector operands; + auto fnArgTypes = dslFn.getArgumentTypes(); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + if (i < fnArgTypes.size() && operand.getType() != fnArgTypes[i]) { + operand = bridgeOperandToType(builder, op->getLoc(), operand, + fnArgTypes[i]); + } + operands.push_back(operand); + } + builder.create(op->getLoc(), dslFn, operands); + op->erase(); + } + + return success(); +} + +// ============================================================================ +// Main entry point. +// ============================================================================ +void ExpandTileOpPass::runOnOperation() { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + + if (tilelangPath.empty()) { + mod.emitError( + "ExpandTileOp requires a non-empty tilelang-path on the VPTO backend"); + signalPassFailure(); + return; + } + + ExpandState state; + state.tilelangPath = std::string(tilelangPath); + state.tilelangPkgPath = std::string(tilelangPkgPath); + state.pythonExe = std::string(pythonExe); + + for (auto func : mod.getOps()) { + if (func.isExternal()) + continue; + if (failed(state.expandTileOpsInFunction(func, mod, ctx))) + return signalPassFailure(); + } +} + +} // namespace + +namespace mlir { +namespace pto { + +std::unique_ptr createExpandTileOpPass() { + return std::make_unique(); +} + +std::unique_ptr +createExpandTileOpPass(const ExpandTileOpOptions &options) { + return std::make_unique(options); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp new file mode 100644 index 000000000..398b13df8 --- /dev/null +++ b/lib/PTO/Transforms/FoldTileBufIntrinsics.cpp @@ -0,0 +1,611 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- FoldTileBufIntrinsics.cpp ------------------------------------------===// +// +// After TileLang DSL template functions are inlined, the IR contains +// structured-view intrinsics that reference template parameters: +// +// tile_buf family: +// - pto.tile_buf_addr → extract memref address from tile_buf +// - pto.tile_valid_rows → extract valid row count +// - pto.tile_valid_cols → extract valid column count +// +// tensor_view family: +// - pto.tensor_view_addr → extract memref/ptr from tensor_view +// - pto.get_tensor_view_dim → extract dimension size +// - pto.get_tensor_view_stride → extract dimension stride +// +// This pass resolves them against the concrete values at the call site. +// For tile_buf intrinsics, the active VPTO path folds against materialized tile +// handles produced by the shared tile-handle bridge (`pto.alloc_tile` or +// `pto.materialize_tile`). +// For tensor_view intrinsics, the pass traces through the full +// unrealized_conversion_cast → memref.subview → memref.reinterpret_cast +// chain to fold directly to constants or SSA operands from the +// reinterpret_cast, without generating intermediate memref.dim / +// memref.extract_strided_metadata ops. +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace mlir { +namespace pto { + #define GEN_PASS_DEF_FOLDTILEBUFINTRINSICS + #include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +namespace { + +static void eraseDeadAllocTileOps(func::FuncOp func) { + SmallVector deadAllocs; + func.walk([&](pto::AllocTileOp alloc) { + if (alloc.getResult().use_empty()) + deadAllocs.push_back(alloc); + }); + + for (pto::AllocTileOp alloc : llvm::reverse(deadAllocs)) + alloc.erase(); +} + +struct TileHandleInfo { + Value sourceMemref; + Value addr; + Value validRow; + Value validCol; + pto::TileBufConfigAttr config; +}; + +static std::optional resolveTileHandle(Value tileBuf, + Operation *user) { + if (auto alloc = tileBuf.getDefiningOp()) { + auto tileTy = dyn_cast(alloc.getResult().getType()); + if (!tileTy) { + user->emitError( + "FoldTileBufIntrinsics: pto.alloc_tile must produce !pto.tile_buf"); + return std::nullopt; + } + return TileHandleInfo{Value(), alloc.getAddr(), alloc.getValidRow(), + alloc.getValidCol(), tileTy.getConfigAttr()}; + } + + if (auto materialize = tileBuf.getDefiningOp()) { + return TileHandleInfo{materialize.getSource(), Value(), + materialize.getValidRow(), materialize.getValidCol(), + materialize.getConfig()}; + } + + user->emitError("FoldTileBufIntrinsics: expected tile_buf to be defined by " + "the active materialized tile-handle bridge " + "(pto.alloc_tile or pto.materialize_tile)"); + return std::nullopt; +} + +static MemRefType getCanonicalMemRefTypeForTileBuf(pto::TileBufType tileTy) { + return MemRefType::get(tileTy.getShape(), tileTy.getElementType(), + AffineMap(), tileTy.getMemorySpace()); +} + +struct ViewChain { + UnrealizedConversionCastOp cast; + memref::SubViewOp subview; + memref::ReinterpretCastOp reinterpretCast; + Value baseMemref; +}; + +static std::optional traceViewChain(Value tensorView, + Operation *user) { + Value memrefVal; + UnrealizedConversionCastOp castOp; + + if (isa(tensorView.getType())) { + memrefVal = tensorView; + } else { + castOp = tensorView.getDefiningOp(); + if (!castOp || castOp.getNumOperands() != 1) { + user->emitError( + "FoldTileBufIntrinsics: expected tensor_view to be defined by a " + "single-operand builtin.unrealized_conversion_cast"); + return std::nullopt; + } + memrefVal = castOp.getOperand(0); + if (!isa(memrefVal.getType())) { + user->emitError( + "FoldTileBufIntrinsics: expected cast operand to be a memref, got ") + << memrefVal.getType(); + return std::nullopt; + } + } + + auto subviewOp = memrefVal.getDefiningOp(); + if (!subviewOp) { + user->emitError("FoldTileBufIntrinsics: expected memref to be defined by " + "memref.subview, got ") + << (memrefVal.getDefiningOp() + ? memrefVal.getDefiningOp()->getName().getStringRef() + : StringRef("block argument")); + return std::nullopt; + } + + auto rcOp = subviewOp.getSource().getDefiningOp(); + if (!rcOp) { + user->emitError( + "FoldTileBufIntrinsics: expected subview source to be defined by " + "memref.reinterpret_cast, got ") + << (subviewOp.getSource().getDefiningOp() + ? subviewOp.getSource().getDefiningOp()->getName().getStringRef() + : StringRef("block argument")); + return std::nullopt; + } + + return ViewChain{castOp, subviewOp, rcOp, rcOp.getSource()}; +} + +static bool getConstIndexValue(Value v, int64_t &out) { + if (auto cOp = v.getDefiningOp()) { + out = cOp.value(); + return true; + } + if (auto cInt = v.getDefiningOp()) { + out = cInt.value(); + return true; + } + if (auto cOp = v.getDefiningOp()) { + if (auto ia = dyn_cast(cOp.getValue())) { + out = ia.getInt(); + return true; + } + } + if (auto castOp = v.getDefiningOp()) + return getConstIndexValue(castOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndexValue(extOp.getIn(), out); + if (auto extOp = v.getDefiningOp()) + return getConstIndexValue(extOp.getIn(), out); + if (auto truncOp = v.getDefiningOp()) + return getConstIndexValue(truncOp.getIn(), out); + return false; +} + +static Value getValueOrCreateConstant(OpBuilder &builder, Location loc, + OpFoldResult ofr) { + if (auto val = dyn_cast(ofr)) + return val; + auto intAttr = dyn_cast(cast(ofr)); + assert(intAttr && "expected integer attribute in OpFoldResult"); + return builder.create(loc, intAttr.getInt()); +} + +static bool isAllStaticZero(ArrayRef ofrs) { + for (OpFoldResult ofr : ofrs) { + auto attr = dyn_cast(ofr); + if (!attr) + return false; + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() != 0) + return false; + } + return true; +} + +static Value computeResultStride(OpBuilder &builder, Location loc, + OpFoldResult rcStride, + OpFoldResult svStride) { + if (auto attr = dyn_cast(svStride)) { + auto intAttr = dyn_cast(attr); + if (intAttr && intAttr.getInt() == 1) + return getValueOrCreateConstant(builder, loc, rcStride); + } + + Value lhs = getValueOrCreateConstant(builder, loc, rcStride); + Value rhs = getValueOrCreateConstant(builder, loc, svStride); + return builder.create(loc, lhs, rhs); +} + +static Value computeLinearOffset(OpBuilder &builder, Location loc, + ArrayRef rcOffsets, + ArrayRef svOffsets, + ArrayRef rcStrides) { + bool rcAllZero = isAllStaticZero(rcOffsets); + bool svAllZero = isAllStaticZero(svOffsets); + + if (rcAllZero && svAllZero) + return Value(); + + Value svPart; + if (!svAllZero) { + for (auto [svOffset, rcStride] : llvm::zip(svOffsets, rcStrides)) { + if (auto attr = dyn_cast(svOffset)) { + auto intAttr = dyn_cast(attr); + if (intAttr && intAttr.getInt() == 0) + continue; + } + + Value off = getValueOrCreateConstant(builder, loc, svOffset); + Value stride = getValueOrCreateConstant(builder, loc, rcStride); + Value term = builder.create(loc, off, stride); + svPart = svPart ? builder.create(loc, svPart, term) : term; + } + } + + Value rcPart; + if (!rcAllZero) { + if (rcOffsets.empty()) + return Value(); + rcPart = getValueOrCreateConstant(builder, loc, rcOffsets.front()); + } + + if (rcPart && svPart) + return builder.create(loc, rcPart, svPart); + return rcPart ? rcPart : svPart; +} + +struct FoldTileBufIntrinsicsPass + : public pto::impl::FoldTileBufIntrinsicsBase { + using FoldTileBufIntrinsicsBase::FoldTileBufIntrinsicsBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + OpBuilder builder(ctx); + + // Leftover TileLang template instances (private, uncalled after + // PTOInlineLibCall) still contain pto.tile_buf_addr / tile_valid_* + // ops on tile_buf function arguments — they have no materialized tile + // handle anchor to fold against and will be removed by later DCE. Skip + // them. + if (func->hasAttr("pto.tilelang.instance")) + return; + + SmallVector addrOps; + SmallVector rowsOps; + SmallVector colsOps; + SmallVector tvAddrOps; + SmallVector tvDimOps; + SmallVector tvStrideOps; + + func.walk([&](Operation *op) { + if (auto addr = dyn_cast(op)) + addrOps.push_back(addr); + else if (auto rows = dyn_cast(op)) + rowsOps.push_back(rows); + else if (auto cols = dyn_cast(op)) + colsOps.push_back(cols); + else if (auto tvAddr = dyn_cast(op)) + tvAddrOps.push_back(tvAddr); + else if (auto tvDim = dyn_cast(op)) + tvDimOps.push_back(tvDim); + else if (auto tvStride = dyn_cast(op)) + tvStrideOps.push_back(tvStride); + }); + + // Fold pto.tile_buf_addr by recovering the active materialized tile + // handle contract: + // - pto.materialize_tile → use the source memref directly + // - pto.alloc_tile → rebuild a memref from the explicit addr + // When the requested result type is already !pto.ptr<...>, cast from the + // recovered memref instead of leaving tile_buf_addr in the IR. + for (auto addrOp : addrOps) { + auto handleInfo = resolveTileHandle(addrOp.getSrc(), addrOp); + if (!handleInfo) + return signalPassFailure(); + + auto tileTy = dyn_cast(addrOp.getSrc().getType()); + if (!tileTy) { + addrOp.emitError("FoldTileBufIntrinsics: tile_buf_addr source must be " + "!pto.tile_buf"); + return signalPassFailure(); + } + + if (auto resultMemrefType = dyn_cast(addrOp.getDst().getType())) { + if (handleInfo->sourceMemref) { + Value srcMemref = handleInfo->sourceMemref; + if (!isa(srcMemref.getType())) { + addrOp.emitError( + "FoldTileBufIntrinsics: pto.materialize_tile source is not a memref"); + return signalPassFailure(); + } + + // The declared tile_buf_addr result type may differ from the actual + // materialized source layout (e.g. plain shape vs. strided layout). + if (srcMemref.getType() != resultMemrefType) + addrOp.getDst().setType(cast(srcMemref.getType())); + addrOp.getDst().replaceAllUsesWith(srcMemref); + addrOp.erase(); + continue; + } + + if (!handleInfo->addr) { + addrOp.emitError("FoldTileBufIntrinsics: pto.alloc_tile used by " + "tile_buf_addr must carry an addr operand on the " + "VPTO path"); + return signalPassFailure(); + } + + builder.setInsertionPoint(addrOp); + Value replacement = builder.create( + addrOp.getLoc(), resultMemrefType, ValueRange{handleInfo->addr}, + handleInfo->validRow ? handleInfo->validRow : Value(), + handleInfo->validCol ? handleInfo->validCol : Value(), + handleInfo->config); + addrOp.getDst().replaceAllUsesWith(replacement); + addrOp.erase(); + continue; + } + + auto resultPtrType = dyn_cast(addrOp.getDst().getType()); + if (!resultPtrType) { + addrOp.emitError( + "FoldTileBufIntrinsics: tile_buf_addr result must be memref or !pto.ptr"); + return signalPassFailure(); + } + + Value memrefValue; + if (handleInfo->sourceMemref) { + memrefValue = handleInfo->sourceMemref; + if (!isa(memrefValue.getType())) { + addrOp.emitError( + "FoldTileBufIntrinsics: pto.materialize_tile source is not a memref"); + return signalPassFailure(); + } + } else { + if (!handleInfo->addr) { + addrOp.emitError("FoldTileBufIntrinsics: pto.alloc_tile used by " + "tile_buf_addr must carry an addr operand on the " + "VPTO path"); + return signalPassFailure(); + } + + builder.setInsertionPoint(addrOp); + auto canonicalMemrefType = getCanonicalMemRefTypeForTileBuf(tileTy); + memrefValue = builder.create( + addrOp.getLoc(), canonicalMemrefType, ValueRange{handleInfo->addr}, + handleInfo->validRow ? handleInfo->validRow : Value(), + handleInfo->validCol ? handleInfo->validCol : Value(), + handleInfo->config); + } + + builder.setInsertionPoint(addrOp); + Value replacement = + builder.create(addrOp.getLoc(), resultPtrType, + memrefValue); + addrOp.getDst().replaceAllUsesWith(replacement); + addrOp.erase(); + } + + // Fold pto.tile_valid_rows → arith.constant (static) or the dynamic + // valid_row operand carried by the new tile handle bridge. + for (auto rowsOp : rowsOps) { + builder.setInsertionPoint(rowsOp); + auto tbTy = dyn_cast(rowsOp.getSrc().getType()); + if (!tbTy || tbTy.getValidShape().empty()) { + rowsOp.emitError("tile_valid_rows: invalid tile_buf type"); + return signalPassFailure(); + } + + int64_t vRow = tbTy.getValidShape()[0]; + Value replacement; + if (vRow != ShapedType::kDynamic) { + replacement = + builder.create(rowsOp.getLoc(), vRow); + } else { + auto handleInfo = resolveTileHandle(rowsOp.getSrc(), rowsOp); + if (!handleInfo) + return signalPassFailure(); + replacement = handleInfo->validRow; + if (!replacement) { + rowsOp.emitError( + "tile_valid_rows: dynamic v_row but the materialized tile " + "handle has no valid_row operand"); + return signalPassFailure(); + } + assert(replacement.getType() == rowsOp.getResult().getType() && + "tile_valid_rows fold: type mismatch with handle valid_row"); + } + rowsOp.getResult().replaceAllUsesWith(replacement); + rowsOp.erase(); + } + + // Fold pto.tile_valid_cols → arith.constant (static) or the dynamic + // valid_col operand carried by the new tile handle bridge. + for (auto colsOp : colsOps) { + builder.setInsertionPoint(colsOp); + auto tbTy = dyn_cast(colsOp.getSrc().getType()); + if (!tbTy || tbTy.getValidShape().size() < 2) { + colsOp.emitError("tile_valid_cols: invalid tile_buf type"); + return signalPassFailure(); + } + + int64_t vCol = tbTy.getValidShape()[1]; + Value replacement; + if (vCol != ShapedType::kDynamic) { + replacement = + builder.create(colsOp.getLoc(), vCol); + } else { + auto handleInfo = resolveTileHandle(colsOp.getSrc(), colsOp); + if (!handleInfo) + return signalPassFailure(); + replacement = handleInfo->validCol; + if (!replacement) { + colsOp.emitError( + "tile_valid_cols: dynamic v_col but the materialized tile " + "handle has no valid_col operand"); + return signalPassFailure(); + } + assert(replacement.getType() == colsOp.getResult().getType() && + "tile_valid_cols fold: type mismatch with handle valid_col"); + } + colsOp.getResult().replaceAllUsesWith(replacement); + colsOp.erase(); + } + + for (auto addrOp : tvAddrOps) { + auto chain = traceViewChain(addrOp.getSrc(), addrOp); + if (!chain) + return signalPassFailure(); + + builder.setInsertionPoint(addrOp); + + auto resultPtrType = dyn_cast(addrOp.getDst().getType()); + if (!resultPtrType) { + if (auto resultMemrefType = + dyn_cast(addrOp.getDst().getType())) { + Value base = chain->baseMemref; + if (base.getType() != resultMemrefType) + addrOp.getDst().setType(cast(base.getType())); + addrOp.getDst().replaceAllUsesWith(base); + addrOp.erase(); + continue; + } + addrOp.emitError( + "FoldTileBufIntrinsics: tensor_view_addr result must be memref or " + "!pto.ptr"); + return signalPassFailure(); + } + + Value linearOffset = + computeLinearOffset(builder, addrOp.getLoc(), + chain->reinterpretCast.getMixedOffsets(), + chain->subview.getMixedOffsets(), + chain->reinterpretCast.getMixedStrides()); + + Value basePtr = builder.create( + addrOp.getLoc(), resultPtrType, chain->baseMemref); + Value replacement = + linearOffset + ? builder.create(addrOp.getLoc(), resultPtrType, + basePtr, linearOffset) + : basePtr; + + addrOp.getDst().replaceAllUsesWith(replacement); + addrOp.erase(); + } + + for (auto dimOp : tvDimOps) { + auto chain = traceViewChain(dimOp.getTensorView(), dimOp); + if (!chain) + return signalPassFailure(); + + int64_t dimIdx = 0; + if (!getConstIndexValue(dimOp.getDimIndex(), dimIdx)) { + dimOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_dim requires a constant " + "dim index"); + return signalPassFailure(); + } + + auto svTy = cast(chain->subview.getType()); + if (dimIdx < 0 || dimIdx >= svTy.getRank()) { + dimOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_dim dim index out of " + "bounds"); + return signalPassFailure(); + } + + builder.setInsertionPoint(dimOp); + Value replacement; + if (!svTy.isDynamicDim(dimIdx)) { + replacement = + builder.create(dimOp.getLoc(), + svTy.getDimSize(dimIdx)); + } else { + replacement = getValueOrCreateConstant( + builder, dimOp.getLoc(), chain->subview.getMixedSizes()[dimIdx]); + } + + dimOp.getResult().replaceAllUsesWith(replacement); + dimOp.erase(); + } + + for (auto strideOp : tvStrideOps) { + auto chain = traceViewChain(strideOp.getTensorView(), strideOp); + if (!chain) + return signalPassFailure(); + + int64_t dimIdx = 0; + if (!getConstIndexValue(strideOp.getDimIndex(), dimIdx)) { + strideOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_stride requires a " + "constant dim index"); + return signalPassFailure(); + } + + auto svTy = cast(chain->subview.getType()); + if (dimIdx < 0 || dimIdx >= svTy.getRank()) { + strideOp.emitError( + "FoldTileBufIntrinsics: get_tensor_view_stride dim index out of " + "bounds"); + return signalPassFailure(); + } + + builder.setInsertionPoint(strideOp); + Value replacement = computeResultStride( + builder, strideOp.getLoc(), + chain->reinterpretCast.getMixedStrides()[dimIdx], + chain->subview.getMixedStrides()[dimIdx]); + + strideOp.getResult().replaceAllUsesWith(replacement); + strideOp.erase(); + } + + // Clean up dead unrealized_conversion_cast ops that bridged + // memref -> partition_tensor_view / tile_buf and are now unused + // after folding. + SmallVector deadCasts; + func.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp.use_empty() && castOp.getNumOperands() == 1 && + isa(castOp.getOperand(0).getType()) && + isa( + castOp.getResult(0).getType())) + deadCasts.push_back(castOp); + }); + for (auto castOp : llvm::reverse(deadCasts)) + castOp.erase(); + + while (true) { + SmallVector deadMemrefOps; + func.walk([&](Operation *op) { + if ((isa(op) || + isa(op)) && + op->use_empty()) + deadMemrefOps.push_back(op); + }); + if (deadMemrefOps.empty()) + break; + for (auto *op : llvm::reverse(deadMemrefOps)) + op->erase(); + } + + eraseDeadAllocTileOps(func); + } +}; + +} // namespace + +namespace mlir { +namespace pto { + +std::unique_ptr createFoldTileBufIntrinsicsPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/InferPTOLayout.cpp b/lib/PTO/Transforms/InferPTOLayout.cpp index a0fb6d7c3..d508befe4 100644 --- a/lib/PTO/Transforms/InferPTOLayout.cpp +++ b/lib/PTO/Transforms/InferPTOLayout.cpp @@ -540,6 +540,12 @@ struct InferPTOLayoutPass : public mlir::pto::impl::InferPTOLayoutBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InferPTOLayoutPass) + StringRef getArgument() const final { return "pto-infer-layout"; } + + StringRef getDescription() const final { + return "Infer GlobalTensor layout (ND/DN/NZ) for make_tensor_view"; + } + void runOnOperation() override { func::FuncOp func = getOperation(); // ------------------------------------------------------------------ diff --git a/lib/PTO/Transforms/PTOInferVPTOVecScope.cpp b/lib/PTO/Transforms/PTOInferVPTOVecScope.cpp new file mode 100644 index 000000000..77e36fe25 --- /dev/null +++ b/lib/PTO/Transforms/PTOInferVPTOVecScope.cpp @@ -0,0 +1,374 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOInferVPTOVecScope.cpp ------------------------------------------===// +// +// VPTO automatic vecscope inference. +// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/SmallPtrSet.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOINFERVPTOVECSCOPE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +enum class VPTOInferenceOpClass { + Vector, + SafeScalar, + Boundary, +}; + +struct NestedRegionSummary { + bool hasVectorOperation = false; + bool hasBoundaryOperation = false; +}; + +struct EscapingVectorScopeValue { + Value value; + Operation *producer = nullptr; + Operation *user = nullptr; +}; + +static VPTOInferenceOpClass classifyOperationForInference(Operation *op); +static LogicalResult inferVecScopesInRegion(Region ®ion, + MLIRContext *context); + +static bool isVecScopeType(Type type) { + return isa(type); +} + +static bool isPTOOperation(Operation *op) { + return op && op->getName().getStringRef().starts_with("pto."); +} + +static bool isExplicitVectorScopeCarrier(Operation *op) { + return isa(op); +} + +static bool isForbiddenInsideInferredVectorScope(Operation *op) { + return isa(op); +} + +static bool isVectorScopeBoundaryOperation(Operation *op) { + return isa(op); +} + +static bool hasVecScopeTypedOperandOrResult(Operation *op) { + for (Type type : op->getOperandTypes()) { + if (isVecScopeType(type)) + return true; + } + for (Type type : op->getResultTypes()) { + if (isVecScopeType(type)) + return true; + } + return false; +} + +static bool requiresVectorScope(Operation *op) { + if (!isPTOOperation(op)) + return false; + + return hasVecScopeTypedOperandOrResult(op) || + isa(op); +} + +static bool isAtomicControlFlowCandidate(Operation *op) { + return isa(op); +} + +static bool isSafeScalarOperation(Operation *op) { + if (op->getNumRegions() != 0) + return false; + if (op->hasTrait()) + return false; + if (isa(op)) + return false; + if (isPTOOperation(op) && !isMemoryEffectFree(op)) + return false; + return isMemoryEffectFree(op); +} + +static void summarizeNestedRegionForAtomicCluster( + Region ®ion, NestedRegionSummary &summary) { + for (Block &block : region) { + for (Operation &op : block) { + if (op.hasTrait()) + continue; + + switch (classifyOperationForInference(&op)) { + case VPTOInferenceOpClass::Vector: + summary.hasVectorOperation = true; + break; + case VPTOInferenceOpClass::SafeScalar: + break; + case VPTOInferenceOpClass::Boundary: + summary.hasBoundaryOperation = true; + return; + } + } + } +} + +static bool canTreatAsAtomicControlFlow(Operation *op) { + if (!isAtomicControlFlowCandidate(op)) + return false; + + NestedRegionSummary summary; + for (Region ®ion : op->getRegions()) { + summarizeNestedRegionForAtomicCluster(region, summary); + if (summary.hasBoundaryOperation) + return false; + } + return summary.hasVectorOperation; +} + +static VPTOInferenceOpClass classifyOperationForInference(Operation *op) { + if (!op) + return VPTOInferenceOpClass::Boundary; + + if (isExplicitVectorScopeCarrier(op)) + return VPTOInferenceOpClass::Boundary; + if (op->hasTrait()) + return VPTOInferenceOpClass::Boundary; + if (isa(op)) + return VPTOInferenceOpClass::Boundary; + if (isVectorScopeBoundaryOperation(op)) + return VPTOInferenceOpClass::Boundary; + if (isForbiddenInsideInferredVectorScope(op)) + return VPTOInferenceOpClass::Boundary; + + if (requiresVectorScope(op)) + return VPTOInferenceOpClass::Vector; + + if (canTreatAsAtomicControlFlow(op)) + return VPTOInferenceOpClass::Vector; + + if (isSafeScalarOperation(op)) + return VPTOInferenceOpClass::SafeScalar; + + return VPTOInferenceOpClass::Boundary; +} + +static bool hasVectorOperation(ArrayRef ops) { + return llvm::any_of(ops, [](Operation *op) { + return classifyOperationForInference(op) == VPTOInferenceOpClass::Vector; + }); +} + +static bool isUserInsideCluster(Operation *user, + const llvm::SmallPtrSetImpl &ops) { + for (Operation *cur = user; cur; cur = cur->getParentOp()) { + if (ops.contains(cur)) + return true; + } + return false; +} + +static bool canMoveIntoResultlessScope(ArrayRef ops) { + llvm::SmallPtrSet opSet; + for (Operation *op : ops) + opSet.insert(op); + + for (Operation *op : ops) { + for (Value result : op->getResults()) { + for (Operation *user : result.getUsers()) { + if (!isUserInsideCluster(user, opSet)) + return false; + } + } + } + return true; +} + +static bool +findEscapingVectorScopeResult(ArrayRef ops, + EscapingVectorScopeValue &escapingValue) { + llvm::SmallPtrSet opSet; + for (Operation *op : ops) + opSet.insert(op); + + for (Operation *op : ops) { + for (Value result : op->getResults()) { + if (!isVecScopeType(result.getType())) + continue; + + for (Operation *user : result.getUsers()) { + if (isUserInsideCluster(user, opSet)) + continue; + + escapingValue.value = result; + escapingValue.producer = op; + escapingValue.user = user; + return true; + } + } + } + return false; +} + +static LogicalResult +emitEscapingVectorScopeValueError(const EscapingVectorScopeValue &escapingValue) { + Operation *producer = escapingValue.producer; + if (!producer) + return failure(); + + InFlightDiagnostic diag = producer->emitOpError() + << "cannot infer resultless pto.vecscope because " + "VPTO vector-scope data cannot have external " + "users"; + if (escapingValue.value) + diag << "; escaping value type is " << escapingValue.value.getType(); + if (escapingValue.user) + diag.attachNote(escapingValue.user->getLoc()) + << "external user is here"; + return failure(); +} + +static void wrapCluster(ArrayRef ops, MLIRContext *context) { + if (ops.empty() || !hasVectorOperation(ops)) + return; + + Operation *first = ops.front(); + Operation *last = ops.back(); + Block *parentBlock = first->getBlock(); + + IRRewriter rewriter(context); + rewriter.setInsertionPoint(first); + auto scope = rewriter.create(first->getLoc()); + scope.getBody().push_back(new Block()); + + Block &scopeBody = scope.getBody().front(); + scopeBody.getOperations().splice(scopeBody.end(), parentBlock->getOperations(), + Block::iterator(first), + std::next(Block::iterator(last))); +} + +static LogicalResult wrapGreedySubclusters(ArrayRef ops, + MLIRContext *context) { + for (size_t begin = 0; begin < ops.size();) { + size_t bestEnd = begin; + EscapingVectorScopeValue escapingValue; + bool sawEscapingVectorScopeResult = false; + + for (size_t end = ops.size(); end > begin; --end) { + ArrayRef candidate = ops.slice(begin, end - begin); + if (!hasVectorOperation(candidate)) + continue; + + // Prefer the largest suffix-preserving candidate that actually needs a + // vecscope and can be moved into today's resultless pto.vecscope form. + if (canMoveIntoResultlessScope(candidate)) { + bestEnd = end; + break; + } + + if (!sawEscapingVectorScopeResult && + findEscapingVectorScopeResult(candidate, escapingValue)) + sawEscapingVectorScopeResult = true; + } + + if (bestEnd == begin) { + if (classifyOperationForInference(ops[begin]) == + VPTOInferenceOpClass::Vector && + sawEscapingVectorScopeResult) + return emitEscapingVectorScopeValueError(escapingValue); + ++begin; + continue; + } + + wrapCluster(ops.slice(begin, bestEnd - begin), context); + begin = bestEnd; + } + return success(); +} + +static LogicalResult inferVecScopesInBlock(Block &block, MLIRContext *context) { + SmallVector pending; + + auto flush = [&]() -> LogicalResult { + if (failed(wrapGreedySubclusters(pending, context))) + return failure(); + pending.clear(); + return success(); + }; + + SmallVector ops; + for (Operation &op : block) + ops.push_back(&op); + + for (Operation *op : ops) { + switch (classifyOperationForInference(op)) { + case VPTOInferenceOpClass::Vector: + case VPTOInferenceOpClass::SafeScalar: + pending.push_back(op); + continue; + case VPTOInferenceOpClass::Boundary: + if (failed(flush())) + return failure(); + continue; + } + } + if (failed(flush())) + return failure(); + + SmallVector remainingOps; + for (Operation &op : block) + remainingOps.push_back(&op); + + for (Operation *op : remainingOps) { + if (isExplicitVectorScopeCarrier(op)) + continue; + for (Region &nested : op->getRegions()) { + if (failed(inferVecScopesInRegion(nested, context))) + return failure(); + } + } + return success(); +} + +static LogicalResult inferVecScopesInRegion(Region ®ion, + MLIRContext *context) { + for (Block &block : region) { + if (failed(inferVecScopesInBlock(block, context))) + return failure(); + } + return success(); +} + +struct PTOInferVPTOVecScopePass + : public pto::impl::PTOInferVPTOVecScopeBase< + PTOInferVPTOVecScopePass> { + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (failed(inferVecScopesInRegion(func.getBody(), &getContext()))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOInferVPTOVecScopePass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp new file mode 100644 index 000000000..1d77061c5 --- /dev/null +++ b/lib/PTO/Transforms/PTOInstantiateAndInlineOpLib.cpp @@ -0,0 +1,313 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "PTOLowerToOpLibCalls.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOINLINELIBCALL +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static constexpr llvm::StringLiteral kOpLibAttrInstVariantId = + "pto.oplib.instance.variant_id"; +static constexpr llvm::StringLiteral kOpLibAttrInstOp = "pto.oplib.instance.op"; +static constexpr llvm::StringLiteral kOpLibAttrInstDType = + "pto.oplib.instance.dtype"; +static constexpr llvm::StringLiteral kErrInstanceBodyMissing = + "E_OPLIB_INSTANCE_BODY_MISSING"; + +static bool isInstanceFunc(func::FuncOp fn) { + return fn->hasAttr(kOpLibAttrInstVariantId); +} + +static bool isTilelangInlineProcFunc(func::FuncOp fn) { + return fn->hasAttr("pto.tilelang.inline_proc"); +} + +static bool isTilelangTemplateFunc(func::FuncOp fn) { + return fn->hasAttr("pto.tilelang.instance") && fn.isPrivate(); +} + +static bool isInlineableLibFunc(func::FuncOp fn) { + // Keep OP-Lib behavior unchanged while force-inlining TileLang helpers + // (inline_proc + private template helper). + if (isInstanceFunc(fn) || isTilelangInlineProcFunc(fn)) + return true; + return isTilelangTemplateFunc(fn); +} + +static Value maybeUnwrapCastToExpected(Value operand, Type expectedType) { + if (operand.getType() == expectedType) + return operand; + + auto cast = operand.getDefiningOp(); + if (!cast || cast->getNumOperands() != 1 || cast->getNumResults() != 1) + return operand; + + if (cast.getOperand(0).getType() == expectedType) + return cast.getOperand(0); + return operand; +} + +static Operation *cloneOpForInlineWithFix(OpBuilder &builder, Operation &op, + IRMapping &mapping) { + if (auto alloc = dyn_cast(&op)) { + auto mapOperand = [&](Value operand, Type expectedType) -> Value { + if (!operand) + return Value(); + Value mapped = mapping.lookupOrNull(operand); + if (!mapped) + mapped = operand; + return maybeUnwrapCastToExpected(mapped, expectedType); + }; + + Value mappedAddr = mapOperand( + alloc.getAddr(), alloc.getAddr() ? alloc.getAddr().getType() : Type()); + Value mappedValidRow = mapOperand( + alloc.getValidRow(), + alloc.getValidRow() ? alloc.getValidRow().getType() : Type()); + Value mappedValidCol = mapOperand( + alloc.getValidCol(), + alloc.getValidCol() ? alloc.getValidCol().getType() : Type()); + + auto cloned = builder.create( + alloc.getLoc(), alloc.getType(), mappedAddr, mappedValidRow, + mappedValidCol); + cloned->setAttrs(alloc->getAttrs()); + return cloned.getOperation(); + } + + return builder.clone(op, mapping); +} + +static void eraseDeadBridgeCasts(func::FuncOp func) { + bool changed = true; + while (changed) { + changed = false; + + SmallVector deadUnrealized; + func.walk([&](UnrealizedConversionCastOp cast) { + if (cast->use_empty()) + deadUnrealized.push_back(cast); + }); + + SmallVector deadMemrefCasts; + func.walk([&](memref::CastOp cast) { + if (cast->use_empty()) + deadMemrefCasts.push_back(cast); + }); + + if (deadUnrealized.empty() && deadMemrefCasts.empty()) + break; + + for (UnrealizedConversionCastOp cast : llvm::reverse(deadUnrealized)) + cast.erase(); + for (memref::CastOp cast : llvm::reverse(deadMemrefCasts)) + cast.erase(); + changed = true; + } +} + +static LogicalResult inlineCall(func::CallOp call, func::FuncOp callee) { + if (callee.isExternal()) + return call.emitOpError("callee must have a body before inlining"); + + Block &entry = callee.getBody().front(); + if (entry.getNumArguments() != call.getNumOperands()) + return call.emitOpError("callee argument count mismatch during inlining"); + auto returnOp = dyn_cast(entry.getTerminator()); + if (!returnOp) + return call.emitOpError("callee must terminate with func.return"); + if (returnOp.getNumOperands() != call.getNumResults()) + return call.emitOpError("callee return/result arity mismatch during inlining"); + + OpBuilder builder(call); + IRMapping mapping; + for (auto [arg, operand] : + llvm::zip(entry.getArguments(), call.getOperands())) + mapping.map(arg, operand); + + for (Operation &op : entry.without_terminator()) { + FailureOr handledOr = + pto::tryCloneOpLibInlineBridgeOp(builder, op, mapping); + if (failed(handledOr)) + return call.emitOpError("failed to remap OP-Lib inline bridge op"); + if (*handledOr) + continue; + + Operation *newOp = cloneOpForInlineWithFix(builder, op, mapping); + for (auto [oldRes, newRes] : + llvm::zip(op.getResults(), newOp->getResults())) + mapping.map(oldRes, newRes); + } + + for (auto [callResult, returnOperand] : + llvm::zip(call.getResults(), returnOp.getOperands())) { + Value mapped = mapping.lookupOrNull(returnOperand); + if (!mapped) + mapped = returnOperand; + callResult.replaceAllUsesWith(mapped); + } + + call.erase(); + return success(); +} + +struct PTOInlineLibCallPass + : public pto::impl::PTOInlineLibCallBase { + using pto::impl::PTOInlineLibCallBase< + PTOInlineLibCallPass>::PTOInlineLibCallBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + + int inlinedCalls = 0; + int touchedFuncs = 0; + + for (func::FuncOp func : module.getOps()) { + if (func.isExternal()) + continue; + if (isInstanceFunc(func)) + continue; + if (func.empty()) + continue; + + bool changedThisFunc = false; + bool madeProgress = true; + while (madeProgress) { + madeProgress = false; + + SmallVector calls; + func.walk([&](func::CallOp call) { calls.push_back(call); }); + + for (func::CallOp oldCall : calls) { + if (!oldCall || !oldCall->getBlock()) + continue; + + auto calleeAttr = oldCall.getCalleeAttr(); + if (!calleeAttr) + continue; + + func::FuncOp callee = + module.lookupSymbol(calleeAttr.getValue()); + if (!callee || !isInlineableLibFunc(callee)) + continue; + + if (callee.isExternal()) { + oldCall.emitError() << kErrInstanceBodyMissing + << ": OP-Lib instance body is missing for @" + << callee.getSymName(); + if (auto variant = + callee->getAttrOfType(kOpLibAttrInstVariantId)) { + oldCall.emitRemark() << "variant_id=" << variant.getValue(); + } + if (auto op = callee->getAttrOfType(kOpLibAttrInstOp)) { + oldCall.emitRemark() << "op=" << op.getValue(); + } + if (auto dtype = + callee->getAttrOfType(kOpLibAttrInstDType)) { + oldCall.emitRemark() << "dtype=" << dtype.getValue(); + } + signalPassFailure(); + return; + } + + func::CallOp call = oldCall; + SmallVector concreteOperands; + concreteOperands.reserve(call.getNumOperands()); + for (auto [operand, expectedTy] : llvm::zip( + call.getOperands(), callee.getFunctionType().getInputs())) { + concreteOperands.push_back( + maybeUnwrapCastToExpected(operand, expectedTy)); + } + + OpBuilder builder(call); + auto newCall = builder.create(call.getLoc(), callee, + concreteOperands); + if (call.getNumResults() != newCall.getNumResults()) { + call.emitOpError("call result arity mismatch during inline staging"); + signalPassFailure(); + return; + } + for (auto [oldResult, newResult] : + llvm::zip(call.getResults(), newCall.getResults())) + oldResult.replaceAllUsesWith(newResult); + call.erase(); + + if (failed(inlineCall(newCall, callee))) { + signalPassFailure(); + return; + } + + ++inlinedCalls; + changedThisFunc = true; + madeProgress = true; + if (debug) { + llvm::errs() << "[op-fusion] inline-libcall: inlined @" + << callee.getSymName() << " into @" << func.getSymName() + << "\n"; + } + } + } + + if (changedThisFunc) { + eraseDeadBridgeCasts(func); + ++touchedFuncs; + } + } + + if (debug) { + llvm::errs() << "[op-fusion] inline-libcall touched " << touchedFuncs + << " function(s), inlined " << inlinedCalls << " call(s)\n"; + } + + // Drop now-dead inline-able callees (private + uncalled) so downstream + // backends never see leftover template/instance bodies. This is needed + // for TileLang templates whose tile_buf-typed parameters cannot be + // legalized once their callers have been inlined. + SymbolTable symbolTable(module); + SmallVector deadFuncs; + for (func::FuncOp func : module.getOps()) { + if (!isInlineableLibFunc(func)) + continue; + if (func.isPublic()) + continue; + auto uses = symbolTable.getSymbolUses(func, module); + if (uses && uses->empty()) + deadFuncs.push_back(func); + } + for (func::FuncOp func : deadFuncs) + func.erase(); + } +}; + +} // namespace + +std::unique_ptr +mlir::pto::createPTOInlineLibCallPass(const PTOInlineLibCallOptions &options) { + return std::make_unique(options); +} diff --git a/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp b/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp new file mode 100644 index 000000000..42f630902 --- /dev/null +++ b/lib/PTO/Transforms/PTOLowerToOpLibCalls.cpp @@ -0,0 +1,225 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" + +#include "PTOLowerToOpLibCalls.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; + +namespace { + +static int64_t getElemBytes(Type elemTy) { + if (auto intTy = dyn_cast(elemTy)) + return (intTy.getWidth() + 7) / 8; + if (auto floatTy = dyn_cast(elemTy)) + return (floatTy.getWidth() + 7) / 8; + return -1; +} + +static bool readBLayoutI32(Attribute attr, int32_t &out) { + if (auto intAttr = dyn_cast(attr)) { + out = static_cast(intAttr.getInt()); + return true; + } + return false; +} + +static bool readSLayoutI32(Attribute attr, int32_t &out) { + if (auto intAttr = dyn_cast(attr)) { + out = static_cast(intAttr.getInt()); + return true; + } + return false; +} + +static FailureOr inferSimdBridgeMemRefType(pto::TileBufType tileTy, + MLIRContext *ctx) { + if (tileTy.getRank() != 2) + return failure(); + + ArrayRef physicalShape = tileTy.getShape(); + if (physicalShape.size() != 2) + return failure(); + if (physicalShape[0] == ShapedType::kDynamic || + physicalShape[1] == ShapedType::kDynamic) + return failure(); + + SmallVector memShape(physicalShape.begin(), physicalShape.end()); + ArrayRef validShape = tileTy.getValidShape(); + if (validShape.size() == memShape.size()) { + for (unsigned i = 0; i < validShape.size(); ++i) + memShape[i] = validShape[i] < 0 ? physicalShape[i] : validShape[i]; + } + + auto cfg = tileTy.getConfigAttr(); + if (!cfg) + cfg = pto::TileBufConfigAttr::getDefault(ctx); + + int32_t bl = 0; + int32_t sl = 0; + int32_t fr = 512; + (void)readBLayoutI32(cfg.getBLayout(), bl); + (void)readSLayoutI32(cfg.getSLayout(), sl); + if (auto attr = dyn_cast(cfg.getSFractalSize())) + fr = static_cast(attr.getInt()); + + int64_t innerRows = 1; + int64_t innerCols = 1; + if (sl != 0) { + int64_t elemBytes = getElemBytes(tileTy.getElementType()); + if (elemBytes <= 0) + return failure(); + if (fr == 1024) { + innerRows = 16; + innerCols = 16; + } else if (fr == 32) { + innerRows = 16; + innerCols = 2; + } else if (fr == 512) { + if (sl == 1) { + innerRows = 16; + innerCols = 32 / elemBytes; + } else if (sl == 2) { + innerRows = 32 / elemBytes; + innerCols = 16; + } else { + return failure(); + } + } else { + return failure(); + } + } + + SmallVector strides; + if (sl == 0) { + if (bl == 1) { + strides.push_back(1); + strides.push_back(physicalShape[0]); + } else { + strides.push_back(physicalShape[1]); + strides.push_back(1); + } + } else if (bl == 1) { + if (sl != 1) + return failure(); + strides.push_back(innerCols); + strides.push_back(physicalShape[0]); + } else { + strides.push_back(physicalShape[1]); + strides.push_back(innerRows); + } + + auto layout = StridedLayoutAttr::get(ctx, /*offset=*/0, strides); + return MemRefType::get(memShape, tileTy.getElementType(), layout, + tileTy.getMemorySpace()); +} + +static bool areIntegerCarrierTypesCompatible(Type lhs, Type rhs) { + auto lhsInt = dyn_cast(lhs); + auto rhsInt = dyn_cast(rhs); + if (!lhsInt || !rhsInt) + return false; + return lhsInt.getWidth() == rhsInt.getWidth(); +} + +static bool canRemapSimdBridgeViaCarrierCast(MemRefType actualTy, + MemRefType templateTy) { + if (actualTy.getRank() != templateTy.getRank()) + return false; + if (actualTy.getMemorySpace() != templateTy.getMemorySpace()) + return false; + return areIntegerCarrierTypesCompatible(actualTy.getElementType(), + templateTy.getElementType()); +} + +static MemRefType remapMemRefToTemplateCarrier(MemRefType actualTy, + MemRefType templateTy) { + return MemRefType::get(actualTy.getShape(), templateTy.getElementType(), + actualTy.getLayout(), actualTy.getMemorySpace()); +} + +} // namespace + +FailureOr mlir::pto::tryCloneOpLibInlineBridgeOp(OpBuilder &builder, + Operation &op, + IRMapping &mapping) { + if (auto bridge = dyn_cast(&op)) { + Value mappedSrc = mapping.lookupOrNull(bridge.getSrc()); + if (!mappedSrc) + return failure(); + + auto templateMemTy = dyn_cast(bridge.getDst().getType()); + if (auto mappedTileTy = dyn_cast(mappedSrc.getType())) { + FailureOr inferredTyOr = + inferSimdBridgeMemRefType(mappedTileTy, builder.getContext()); + if (failed(inferredTyOr)) + return failure(); + + auto inferredTy = *inferredTyOr; + auto newBridge = builder.create( + bridge.getLoc(), inferredTy, mappedSrc); + if (templateMemTy && inferredTy != templateMemTy && + canRemapSimdBridgeViaCarrierCast(inferredTy, templateMemTy)) { + auto carrierTy = remapMemRefToTemplateCarrier(inferredTy, templateMemTy); + auto cast = builder.create( + bridge.getLoc(), TypeRange{carrierTy}, ValueRange{newBridge.getDst()}); + mapping.map(bridge.getDst(), cast.getResult(0)); + } else { + mapping.map(bridge.getDst(), newBridge.getDst()); + } + return true; + } + + auto mappedMemTy = dyn_cast(mappedSrc.getType()); + auto dstMemTy = templateMemTy; + if (!mappedMemTy || !dstMemTy) + return failure(); + if (mappedMemTy.getRank() != dstMemTy.getRank()) + return failure(); + + auto newBridge = builder.create( + bridge.getLoc(), mappedMemTy, mappedSrc); + if (mappedMemTy.getElementType() == dstMemTy.getElementType()) { + mapping.map(bridge.getDst(), newBridge.getDst()); + return true; + } + if (!canRemapSimdBridgeViaCarrierCast(mappedMemTy, dstMemTy)) + return failure(); + auto carrierTy = remapMemRefToTemplateCarrier(mappedMemTy, dstMemTy); + auto cast = builder.create( + bridge.getLoc(), TypeRange{carrierTy}, ValueRange{newBridge.getDst()}); + mapping.map(bridge.getDst(), cast.getResult(0)); + return true; + } + + if (auto cast = dyn_cast(&op)) { + if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) + return failure(); + + Value mappedSrc = mapping.lookupOrNull(cast.getOperand(0)); + if (!mappedSrc) + return failure(); + + Type dstTy = cast.getResult(0).getType(); + if (mappedSrc.getType() == dstTy) { + mapping.map(cast.getResult(0), mappedSrc); + return true; + } + + auto clonedCast = builder.create( + cast.getLoc(), TypeRange{dstTy}, ValueRange{mappedSrc}); + mapping.map(cast.getResult(0), clonedCast.getResult(0)); + return true; + } + + return false; +} diff --git a/lib/PTO/Transforms/PTOLowerToOpLibCalls.h b/lib/PTO/Transforms/PTOLowerToOpLibCalls.h new file mode 100644 index 000000000..34bf7e2b9 --- /dev/null +++ b/lib/PTO/Transforms/PTOLowerToOpLibCalls.h @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef PTO_TRANSFORMS_PTOLOWERTOOPLIBCALLS_H +#define PTO_TRANSFORMS_PTOLOWERTOOPLIBCALLS_H + +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace pto { + +FailureOr tryCloneOpLibInlineBridgeOp(OpBuilder &builder, Operation &op, + IRMapping &mapping); + +} // namespace pto +} // namespace mlir + +#endif // PTO_TRANSFORMS_PTOLOWERTOOPLIBCALLS_H diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 906e679f3..0d25b39d1 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -694,6 +694,9 @@ SetVector MemLivenessAnalysis::Union(SetVector set1, } SetVector MemLivenessAnalysis::GetAliasBuffers(Value aliasBuffer) { + if (!aliasBuffer) + return {}; + auto trueVar = buffer2AliasVec.find(aliasBuffer); if (trueVar != buffer2AliasVec.end()) { return trueVar->second; diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 7d26239da..2fd87d7d7 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -4389,6 +4389,9 @@ struct PTOTStoreToTSTORE : public OpConversionPattern { switch (reluPreMode) { case pto::ReluPreMode::NoRelu: return "ReluPreMode::NoRelu"; case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; + case pto::ReluPreMode::ScalarRelu: return "ReluPreMode::ScalarRelu"; + case pto::ReluPreMode::VectorRelu: return "ReluPreMode::VectorRelu"; + case pto::ReluPreMode::Pwl: return "ReluPreMode::Pwl"; } return "ReluPreMode::NoRelu"; } @@ -8735,6 +8738,12 @@ struct PTOMovToEmitC : public OpConversionPattern { return "ReluPreMode::NoRelu"; case pto::ReluPreMode::NormalRelu: return "ReluPreMode::NormalRelu"; + case pto::ReluPreMode::ScalarRelu: + return "ReluPreMode::ScalarRelu"; + case pto::ReluPreMode::VectorRelu: + return "ReluPreMode::VectorRelu"; + case pto::ReluPreMode::Pwl: + return "ReluPreMode::Pwl"; } llvm_unreachable("unknown ReluPreMode"); }; diff --git a/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp new file mode 100644 index 000000000..f1e52424f --- /dev/null +++ b/lib/PTO/Transforms/PTOVPTOPtrBoundary.cpp @@ -0,0 +1,345 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLowering.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVPTOPTRBOUNDARY +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static Type convertVPTOBoundaryMemRefType(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return type; + auto memorySpace = + dyn_cast_or_null(memrefType.getMemorySpace()); + if (!memorySpace) + return {}; + return pto::PtrType::get(type.getContext(), memrefType.getElementType(), + memorySpace); +} + +static bool isTrivialVPTOBoundaryCastPtr(pto::CastPtrOp castOp) { + return castOp.getInput().getType() == castOp.getResult().getType(); +} + +static LogicalResult eraseDeadVPTOMemRefScaffold(ModuleOp module) { + bool erasedAny = true; + while (erasedAny) { + erasedAny = false; + SmallVector trivialCasts; + SmallVector deadOps; + module.walk([&](Operation *op) { + if (auto castOp = dyn_cast(op)) { + if (isTrivialVPTOBoundaryCastPtr(castOp)) { + trivialCasts.push_back(castOp); + return; + } + if (castOp->use_empty()) + deadOps.push_back(op); + return; + } + + if (!op->use_empty()) + return; + if (isa(op)) + deadOps.push_back(op); + }); + + for (pto::CastPtrOp castOp : trivialCasts) { + if (!castOp->getBlock()) + continue; + castOp.getResult().replaceAllUsesWith(castOp.getInput()); + castOp.erase(); + erasedAny = true; + } + + for (Operation *op : deadOps) { + if (!op->getBlock()) + continue; + op->erase(); + erasedAny = true; + } + } + return success(); +} + +static Type getVPTOBufferElementType(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + return {}; +} + +static Attribute getVPTOBufferMemorySpace(Value value) { + Type type = value.getType(); + if (auto tileType = dyn_cast(type)) + return tileType.getMemorySpace(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getMemorySpace(); + if (auto ptrType = dyn_cast(type)) + return ptrType.getMemorySpace(); + return {}; +} + +static bool needsPtrCanonicalization(Value value) { + return isa(value.getType()); +} + +static bool isSupportedVPTOBufferLikeBoundaryOp(Operation *op) { + return isa(op); +} + +static LogicalResult canonicalizeBoundaryCastPtrOps(ModuleOp module, + llvm::raw_ostream *diagOS) { + SmallVector castsToRewrite; + module.walk([&](pto::CastPtrOp castOp) { + if (!isa(castOp.getInput().getType())) + return; + if (!isa(castOp.getResult().getType())) + return; + castsToRewrite.push_back(castOp); + }); + + PatternRewriter rewriter(module.getContext()); + for (pto::CastPtrOp castOp : castsToRewrite) { + if (!castOp->getBlock()) + continue; + + auto resultType = dyn_cast(castOp.getResult().getType()); + if (!resultType) + continue; + + rewriter.setInsertionPoint(castOp); + Value ptrValue = pto::materializeBufferPointer( + castOp.getInput(), resultType.getElementType(), + resultType.getMemorySpace(), rewriter, castOp.getLoc()); + if (!ptrValue) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "canonicalize pto.castptr input for "; + castOp->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + castOp.getResult().replaceAllUsesWith(ptrValue); + rewriter.eraseOp(castOp); + } + + return success(); +} + +static LogicalResult canonicalizeSupportedVPTOBufferLikeOps( + ModuleOp module, llvm::raw_ostream *diagOS) { + SmallVector opsToRewrite; + module.walk([&](Operation *op) { + if (isSupportedVPTOBufferLikeBoundaryOp(op)) + opsToRewrite.push_back(op); + }); + + PatternRewriter rewriter(module.getContext()); + for (Operation *op : opsToRewrite) { + rewriter.setInsertionPoint(op); + + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + bool changed = false; + + for (Value operand : op->getOperands()) { + if (!needsPtrCanonicalization(operand)) { + newOperands.push_back(operand); + continue; + } + + Type elementType = getVPTOBufferElementType(operand); + Attribute memorySpace = getVPTOBufferMemorySpace(operand); + if (!elementType || !memorySpace) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "derive element type or memory space for operand of "; + op->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + Value ptrValue = pto::materializeBufferPointer(operand, elementType, + memorySpace, rewriter, + op->getLoc()); + if (!ptrValue) { + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: could not " + "materialize pointer operand for "; + op->print(*diagOS); + *diagOS << "\n"; + } + return failure(); + } + + changed = changed || (ptrValue != operand); + newOperands.push_back(ptrValue); + } + + if (!changed) + continue; + + OperationState state(op->getLoc(), op->getName().getStringRef()); + state.addOperands(newOperands); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + } + + return success(); +} + +struct PTOVPTOPtrBoundaryPass + : public pto::impl::PTOVPTOPtrBoundaryBase { + using pto::impl::PTOVPTOPtrBoundaryBase< + PTOVPTOPtrBoundaryPass>::PTOVPTOPtrBoundaryBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(pto::convertVPTOEmissionBoundaryToPtr(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult mlir::pto::convertVPTOEmissionBoundaryToPtr( + ModuleOp module, llvm::raw_ostream *diagOS) { + // VPTO kernels use ptr-only entry semantics at the emission boundary: the + // function ABI keeps only the same-space base pointer, while shape/stride + // state remains in SSA. Body-level op canonicalization is added on top of + // this entry rewrite in follow-up tasks. + if (failed(eraseDeadVPTOMemRefScaffold(module))) + return failure(); + + bool sawFailure = false; + for (func::FuncOp func : module.getOps()) { + if (func.isExternal()) + continue; + + FunctionType functionType = func.getFunctionType(); + SmallVector newInputs(functionType.getInputs().begin(), + functionType.getInputs().end()); + bool changed = false; + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + auto memrefType = dyn_cast(inputType); + if (!memrefType) + continue; + + Type newType = convertVPTOBoundaryMemRefType(inputType); + if (!newType) { + if (diagOS) + *diagOS << "VPTO emission-boundary ptr rewrite failed: unsupported " + "memref argument type in " + << func.getName() << ": " << inputType << "\n"; + sawFailure = true; + continue; + } + + BlockArgument arg = func.getArgument(idx); + SmallVector users(arg.getUsers().begin(), arg.getUsers().end()); + arg.setType(newType); + newInputs[idx] = newType; + changed = true; + + for (Operation *user : users) { + if (auto cast = dyn_cast(user)) { + if (cast.getInput() != arg) + continue; + if (cast.getResult().getType() == newType) { + cast.getResult().replaceAllUsesWith(arg); + cast.erase(); + } + continue; + } + + if (isa(user) && + user->use_empty()) { + user->erase(); + continue; + } + + if (isSupportedVPTOBufferLikeBoundaryOp(user)) + continue; + + if (diagOS) { + *diagOS << "VPTO emission-boundary ptr rewrite failed: argument " + << idx << " of " << func.getName() + << " still feeds a memref-dependent user after ptr rewrite:\n"; + user->print(*diagOS); + *diagOS << "\n"; + } + sawFailure = true; + } + } + + for (Type resultType : functionType.getResults()) { + if (!isa(resultType)) + continue; + if (diagOS) + *diagOS << "VPTO emission-boundary ptr rewrite failed: memref result " + "is unsupported for " + << func.getName() << ": " << resultType << "\n"; + sawFailure = true; + } + + if (changed) { + func.setFunctionType( + FunctionType::get(module.getContext(), newInputs, functionType.getResults())); + } + } + + if (sawFailure) + return failure(); + + if (failed(canonicalizeBoundaryCastPtrOps(module, diagOS))) + return failure(); + + if (failed(canonicalizeSupportedVPTOBufferLikeOps(module, diagOS))) + return failure(); + + return eraseDeadVPTOMemRefScaffold(module); +} + +std::unique_ptr mlir::pto::createPTOVPTOPtrBoundaryPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOValidateVPTOIR.cpp b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp new file mode 100644 index 000000000..0082f6d34 --- /dev/null +++ b/lib/PTO/Transforms/PTOValidateVPTOIR.cpp @@ -0,0 +1,926 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOValidateVPTOIR.cpp - Shared VPTO legality helpers --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file owns the shared helper layer for the dual-stage VPTO legality +// verifier. Follow-up tasks add the public validation entrypoints and pass +// wrappers on top of this utility layer. +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace mlir { +namespace pto { + +LogicalResult validateVPTOAuthoringIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +LogicalResult validateVPTOEmissionIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); + +namespace detail { + +constexpr llvm::StringLiteral kAIVectorScopeAttrName = + "llvm.loop.aivector_scope"; + +enum class VPTOMaskGranularity { + B8, + B16, + B32, +}; + +enum class VPTOBufferAddressFamily { + None, + Copy, + BufferLike, + PtrOnly, +}; + +enum class VPTOLegalityStage { + Authoring, + Emission, +}; + +class VPTOLegalityHelper { +public: + explicit VPTOLegalityHelper(ModuleOp module) : module(module) {} + + ModuleOp getModule() const { return module; } + + SmallVector getFunctions() { + SmallVector funcs; + for (func::FuncOp func : module.getOps()) + funcs.push_back(func); + return funcs; + } + + static bool isLegalityTypedValue(Type type) { + return isa(type); + } + + static bool isBufferLikeValue(Type type) { + return isa(type); + } + + static bool requiresVecScope(Operation *op) { + if (!isPTOp(op)) + return false; + + return llvm::any_of(op->getOperandTypes(), isLegalityTypedValue) || + llvm::any_of(op->getResultTypes(), isLegalityTypedValue); + } + + static bool isAIVectorScopeCarrier(scf::ForOp loop) { + return loop && loop->hasAttr(kAIVectorScopeAttrName); + } + + static bool isDedicatedVecScopeCarrier(Operation *op) { + return isa_and_nonnull(op); + } + + static bool isAnyVectorScopeCarrier(Operation *op) { + if (auto loop = dyn_cast_or_null(op)) + return isAIVectorScopeCarrier(loop); + return isDedicatedVecScopeCarrier(op); + } + + static Operation *getEnclosingVectorScopeCarrier(Operation *op) { + for (Operation *parent = op ? op->getParentOp() : nullptr; parent; + parent = parent->getParentOp()) { + if (isAnyVectorScopeCarrier(parent)) + return parent; + } + return nullptr; + } + + static std::optional getMaskGranularity(Type type) { + auto maskType = dyn_cast(type); + if (!maskType) + return std::nullopt; + return getMaskGranularity(maskType); + } + + static std::optional getMaskGranularity(MaskType type) { + if (type.isB8()) + return VPTOMaskGranularity::B8; + if (type.isB16()) + return VPTOMaskGranularity::B16; + if (type.isB32()) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + + static StringRef stringifyMaskGranularity(VPTOMaskGranularity granularity) { + switch (granularity) { + case VPTOMaskGranularity::B8: + return "b8"; + case VPTOMaskGranularity::B16: + return "b16"; + case VPTOMaskGranularity::B32: + return "b32"; + } + llvm_unreachable("unsupported VPTO mask granularity"); + } + + static std::optional + inferMaskGranularityFromType(Type type) { + if (auto vregType = dyn_cast(type)) + type = vregType.getElementType(); + + if (type.isF32()) + return VPTOMaskGranularity::B32; + if (type.isF16() || type.isBF16()) + return VPTOMaskGranularity::B16; + + auto intType = dyn_cast(type); + if (!intType) + return std::nullopt; + + switch (intType.getWidth()) { + case 8: + return VPTOMaskGranularity::B8; + case 16: + return VPTOMaskGranularity::B16; + case 32: + return VPTOMaskGranularity::B32; + default: + return std::nullopt; + } + } + + static std::optional + inferMaskGranularityFromFamily(Operation *op) { + StringRef mnemonic = getPTOpMnemonic(op); + if (mnemonic.empty()) + return std::nullopt; + + if (mnemonic.ends_with("_b8")) + return VPTOMaskGranularity::B8; + if (mnemonic.ends_with("_b16")) + return VPTOMaskGranularity::B16; + if (mnemonic.ends_with("_b32")) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + + static VPTOBufferAddressFamily classifyBufferAddressFamily(Operation *op) { + if (!op) + return VPTOBufferAddressFamily::None; + + if (isa(op)) + return VPTOBufferAddressFamily::Copy; + + if (isa(op)) + return VPTOBufferAddressFamily::PtrOnly; + + if (isa(op)) + return VPTOBufferAddressFamily::BufferLike; + + return VPTOBufferAddressFamily::None; + } + + static bool isSupportedEmissionBufferLikeOp(Operation *op) { + return classifyBufferAddressFamily(op) == + VPTOBufferAddressFamily::BufferLike; + } + + static bool isResidualEmissionScaffold(Operation *op) { + return isa(op) || + isTrivialEmissionCastPtr(op); + } + + static SmallVector collectBufferOperands(Operation *op) { + SmallVector bufferOperands; + for (OpOperand &operand : op->getOpOperands()) { + if (isBufferLikeValue(operand.get().getType())) + bufferOperands.push_back(&operand); + } + return bufferOperands; + } + +private: + static bool isPTOp(Operation *op) { + return op && op->getName().getStringRef().starts_with("pto."); + } + + static StringRef getPTOpMnemonic(Operation *op) { + if (!isPTOp(op)) + return {}; + StringRef mnemonic = op->getName().getStringRef(); + (void)mnemonic.consume_front("pto."); + return mnemonic; + } + + static bool isTrivialEmissionCastPtr(Operation *op) { + auto castOp = dyn_cast_or_null(op); + return castOp && + castOp.getInput().getType() == castOp.getResult().getType(); + } + + ModuleOp module; +}; + +class VPTOLegalityValidator { +public: + VPTOLegalityValidator(ModuleOp module, VPTOLegalityStage stage, + llvm::raw_ostream *diagOS) + : helper(module), stage(stage), diagOS(diagOS) {} + + LogicalResult validate() { + if (!helper.getModule()) { + writeDiagnostic("VPTO legality validation requires a valid module\n"); + return failure(); + } + + if (failed(validateAuthoringRules())) + return failure(); + + if (stage == VPTOLegalityStage::Emission && + failed(validateEmissionRules())) + return failure(); + + return success(); + } + +private: + LogicalResult validateAuthoringRules() { + if (failed(validateAuthoringFunctionSurface())) + return failure(); + if (failed(validateAuthoringOperationSurface())) + return failure(); + return success(); + } + + LogicalResult validateEmissionRules() { + if (failed(validateEmissionFunctionSurface())) + return failure(); + if (failed(validateEmissionOperationSurface())) + return failure(); + return success(); + } + + static std::string formatExpectedMaskType(VPTOMaskGranularity granularity) { + std::string storage; + llvm::raw_string_ostream os(storage); + os << "!pto.mask<" + << VPTOLegalityHelper::stringifyMaskGranularity(granularity) << ">"; + return storage; + } + + static LogicalResult validateMaskMatchesVectorFamily(Operation *op, + Type maskType, + StringRef maskRole, + Type vectorType, + StringRef vectorRole) { + auto actual = VPTOLegalityHelper::getMaskGranularity(maskType); + auto expected = VPTOLegalityHelper::inferMaskGranularityFromType(vectorType); + if (!actual || !expected || *actual == *expected) + return success(); + + return op->emitOpError() + << maskRole << " " << maskType << " does not match " << vectorRole + << " " << vectorType << "; expected " + << formatExpectedMaskType(*expected); + } + + static std::optional + inferVstsMaskGranularityOverride(Operation *op) { + Value value; + if (auto vsts = dyn_cast(op)) + value = vsts.getValue(); + else if (auto vstsPost = dyn_cast(op)) + value = vstsPost.getValue(); + else + return std::nullopt; + + auto valueType = dyn_cast(value.getType()); + if (!valueType) + return std::nullopt; + + auto distAttr = op->getAttrOfType("dist"); + if (!distAttr) + return std::nullopt; + + StringRef dist = distAttr.getValue(); + auto elementType = valueType.getElementType(); + unsigned width = 0; + if (auto elementIntType = dyn_cast(elementType)) { + width = elementIntType.getWidth(); + } else if (elementType.isF16() || elementType.isBF16()) { + width = 16; + } else if (elementType.isF32()) { + width = 32; + } else if (elementType.isF64()) { + width = 64; + } else { + return std::nullopt; + } + + if (dist == "PK_B16") { + if (width == 8) + return VPTOMaskGranularity::B16; + return std::nullopt; + } + if (dist == "PK_B32") { + if (width == 16) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + if (dist == "MRG4CHN_B8") { + if (width == 8) + return VPTOMaskGranularity::B32; + return std::nullopt; + } + if (dist == "MRG2CHN_B8") { + if (width == 8) + return VPTOMaskGranularity::B16; + return std::nullopt; + } + if (dist == "MRG2CHN_B16") { + if (width == 16) + return VPTOMaskGranularity::B32; + } + return std::nullopt; + } + + static LogicalResult validateSameMaskGranularity(Operation *op, Type lhsType, + StringRef lhsRole, + Type rhsType, + StringRef rhsRole) { + auto lhs = VPTOLegalityHelper::getMaskGranularity(lhsType); + auto rhs = VPTOLegalityHelper::getMaskGranularity(rhsType); + if (!lhs || !rhs || *lhs == *rhs) + return success(); + + return op->emitOpError() << lhsRole << " " << lhsType << " does not match " + << rhsRole << " " << rhsType; + } + + static bool isAdjacentMaskGranularityWidening(VPTOMaskGranularity input, + VPTOMaskGranularity result) { + return (input == VPTOMaskGranularity::B8 && + result == VPTOMaskGranularity::B16) || + (input == VPTOMaskGranularity::B16 && + result == VPTOMaskGranularity::B32); + } + + static LogicalResult validatePunpackMaskGranularity(PunpackOp op) { + auto input = VPTOLegalityHelper::getMaskGranularity(op.getInput().getType()); + auto result = VPTOLegalityHelper::getMaskGranularity(op.getResult().getType()); + if (!input || !result || *input == *result || + isAdjacentMaskGranularityWidening(*input, *result)) + return success(); + + return op.emitOpError() + << "input mask type " << op.getInput().getType() + << " does not match result mask type " << op.getResult().getType() + << " for pto.punpack"; + } + + template + static LogicalResult validateInputMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getInput().getType(), + "input vector type"); + } + + template + static LogicalResult validateBinaryMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", op.getLhs().getType(), + "lhs vector type"); + } + + template + static LogicalResult validateValueMaskVectorConsumer(OpTy op) { + if constexpr (std::is_same_v || + std::is_same_v) { + if (std::optional expected = + inferVstsMaskGranularityOverride(op.getOperation())) { + auto actual = + VPTOLegalityHelper::getMaskGranularity(op.getMask().getType()); + if (!actual || *actual == *expected) + return success(); + return op.emitOpError() + << "mask type " << op.getMask().getType() + << " does not match value vector type " + << op.getValue().getType() << "; expected " + << formatExpectedMaskType(*expected); + } + } + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", op.getValue().getType(), + "value vector type"); + } + + void emitHardwareSupportWarnings(Operation *op) const { + auto emitForStore = [&](auto storeOp) { + Operation *store = storeOp.getOperation(); + auto distAttr = store->getAttrOfType("dist"); + if (!distAttr) + return; + + StringRef dist = distAttr.getValue(); + if (dist == "MRG4CHN_B8" || dist == "MRG2CHN_B8" || dist == "MRG2CHN_B16") + writeDiagnostic((Twine("warning: ") + store->getName().getStringRef() + + " dist " + dist + + " is not supported on the current hardware\n") + .str()); + }; + + if (auto vsts = dyn_cast(op)) { + emitForStore(vsts); + return; + } + if (auto vstsPost = dyn_cast(op)) { + emitForStore(vstsPost); + return; + } + } + + template + static LogicalResult validateResultMaskVectorConsumer(OpTy op) { + return validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getResult().getType(), + "result vector type"); + } + + template + static LogicalResult validateCarryFamilyContract(CarryOp op) { + if (failed(validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "mask type", + op.getLhs().getType(), + "lhs vector type")) || + failed(validateSameMaskGranularity(op, op.getMask().getType(), + "mask type", + op.getCarry().getType(), + "carry type"))) + return failure(); + + if constexpr (std::is_same_v || + std::is_same_v) { + if (failed(validateSameMaskGranularity(op, op.getCarryIn().getType(), + "carry_in type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getCarryIn().getType(), + "carry_in type", + op.getCarry().getType(), + "carry type"))) + return failure(); + } + + return success(); + } + + template + static LogicalResult validateCompareFamilyContract(CompareOp op, Type vecType) { + if (failed(validateMaskMatchesVectorFamily(op, op.getMask().getType(), + "seed mask type", vecType, + "input vector type")) || + failed(validateMaskMatchesVectorFamily(op, op.getResult().getType(), + "result mask type", vecType, + "input vector type")) || + failed(validateSameMaskGranularity(op, op.getMask().getType(), + "seed mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + template + static LogicalResult validateMaskOnlyUnaryContract(MaskUnaryOp op) { + return validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getResult().getType(), + "result mask type"); + } + + static LogicalResult validateMaskOnlyPnotContract(PnotOp op) { + if (failed(validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getInput().getType(), + "input mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + static LogicalResult validateMaskOnlyPselContract(PselOp op) { + if (failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getSrc1().getType(), + "src1 mask type")) || + failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getMask().getType(), + "mask type")) || + failed(validateSameMaskGranularity(op, op.getSrc0().getType(), + "src0 mask type", + op.getResult().getType(), + "result mask type"))) + return failure(); + return success(); + } + + template + static LogicalResult validatePredicateMovementContract( + PredicateMovementOp op) { + auto expected = VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + if (!expected) + return success(); + + if (failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getRhs().getType(), + "rhs mask type")) || + failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getLow().getType(), + "low mask type")) || + failed(validateSameMaskGranularity(op, op.getLhs().getType(), + "lhs mask type", + op.getHigh().getType(), + "high mask type"))) + return failure(); + + auto lhs = VPTOLegalityHelper::getMaskGranularity(op.getLhs().getType()); + if (!lhs || *lhs == *expected) + return success(); + + return op.emitOpError() + << "predicate movement family requires " + << formatExpectedMaskType(*expected) + << " but got lhs mask type " << op.getLhs().getType(); + } + + static LogicalResult validateFamilySuffixMaskResult(Operation *op, + Type resultType, + StringRef resultRole) { + auto expected = VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + auto actual = VPTOLegalityHelper::getMaskGranularity(resultType); + if (!expected || !actual || *expected == *actual) + return success(); + + return op->emitOpError() + << "family suffix requires " << resultRole << " to be " + << formatExpectedMaskType(*expected) << ", but got " << resultType; + } + + static LogicalResult validateFamilySuffixMaskContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](auto concreteOp) { + return validateFamilySuffixMaskResult( + concreteOp, concreteOp.getResult().getType(), "result type"); + }) + .Case([](auto concreteOp) { + return validateFamilySuffixMaskResult(concreteOp, + concreteOp.getMask().getType(), + "mask result type"); + }) + .Default([](Operation *) { return success(); }); + } + + static LogicalResult validateUnaryElementTypeContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case([](VreluOp concreteOp) { + auto vecType = dyn_cast(concreteOp.getInput().getType()); + if (!vecType) + return success(); + + Type elemType = vecType.getElementType(); + if (auto intType = dyn_cast(elemType)) { + if (intType.getWidth() == 32 && !intType.isUnsigned()) + return success(); + } else if (elemType.isF16() || elemType.isF32()) { + return success(); + } + + concreteOp.emitOpError("requires si32/i32/f16/f32 vector element type"); + return failure(); + }) + .Default([](Operation *) { return success(); }); + } + + static LogicalResult validateMaskGranularityContracts(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](auto concreteOp) { + return validateInputMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateBinaryMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateCarryFamilyContract(concreteOp); + }) + .Case([](VcmpOp concreteOp) { + return validateCompareFamilyContract(concreteOp, + concreteOp.getSrc0().getType()); + }) + .Case([](VcmpsOp concreteOp) { + return validateCompareFamilyContract(concreteOp, + concreteOp.getSrc().getType()); + }) + .Case([](PpackOp concreteOp) { + return validateMaskOnlyUnaryContract(concreteOp); + }) + .Case([](PunpackOp concreteOp) { + return validatePunpackMaskGranularity(concreteOp); + }) + .Case( + [](PnotOp concreteOp) { return validateMaskOnlyPnotContract(concreteOp); }) + .Case( + [](PselOp concreteOp) { return validateMaskOnlyPselContract(concreteOp); }) + .Case([](auto concreteOp) { + return validatePredicateMovementContract(concreteOp); + }) + .Case([](VselOp concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getSrc0().getType(), + "src0 vector type"); + }) + .Case([](auto concreteOp) { + return validateResultMaskVectorConsumer(concreteOp); + }) + .Case([](auto concreteOp) { + return validateValueMaskVectorConsumer(concreteOp); + }) + .Case([](Vstsx2Op concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getLow().getType(), + "low vector type"); + }) + .Case([](auto concreteOp) { + return validateMaskMatchesVectorFamily(concreteOp, + concreteOp.getMask().getType(), + "mask type", + concreteOp.getLhs().getType(), + "lhs vector type"); + }) + .Default([](Operation *) { return success(); }); + } + + LogicalResult validateAuthoringFunctionSurface() { + for (func::FuncOp func : helper.getFunctions()) { + if (!func->hasAttr(pto::kPTOSimtEntryAttrName)) + continue; + + WalkResult walkResult = func.walk([&](StoreVfSimtInfoOp op) { + op.emitOpError() + << "must not appear inside a function marked with '" + << pto::kPTOSimtEntryAttrName + << "'; configure SIMT launch info in the outer non-simt caller " + "instead"; + return WalkResult::interrupt(); + }); + if (walkResult.wasInterrupted()) + return failure(); + } + return success(); + } + + LogicalResult validateAuthoringOperationSurface() { + WalkResult loopWalkResult = helper.getModule().walk([&](scf::ForOp loop) { + if (!VPTOLegalityHelper::isAIVectorScopeCarrier(loop)) + return WalkResult::advance(); + + Operation *parentScope = + VPTOLegalityHelper::getEnclosingVectorScopeCarrier(loop); + if (!parentScope) + return WalkResult::advance(); + + if (isa(parentScope)) { + loop.emitOpError() << "does not allow nested scf.for with '" + << kAIVectorScopeAttrName << "'"; + return WalkResult::interrupt(); + } + + loop.emitOpError() + << "does not allow legacy scf.for carrier nested inside dedicated " + "pto.vecscope/pto.strict_vecscope"; + return WalkResult::interrupt(); + }); + if (loopWalkResult.wasInterrupted()) + return failure(); + + WalkResult vecScopeWalkResult = helper.getModule().walk([&](Operation *op) { + if (!VPTOLegalityHelper::isDedicatedVecScopeCarrier(op)) + return WalkResult::advance(); + + if (!VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) + return WalkResult::advance(); + + op->emitOpError() + << "does not allow nested dedicated pto.vecscope/pto.strict_vecscope"; + return WalkResult::interrupt(); + }); + if (vecScopeWalkResult.wasInterrupted()) + return failure(); + + WalkResult opWalkResult = helper.getModule().walk([&](Operation *op) { + (void)VPTOLegalityHelper::inferMaskGranularityFromFamily(op); + (void)VPTOLegalityHelper::classifyBufferAddressFamily(op); + + if (!VPTOLegalityHelper::requiresVecScope(op)) + return WalkResult::advance(); + + if (VPTOLegalityHelper::getEnclosingVectorScopeCarrier(op)) { + if (failed(validateFamilySuffixMaskContracts(op)) || + failed(validateUnaryElementTypeContracts(op)) || + failed(validateMaskGranularityContracts(op))) + return WalkResult::interrupt(); + emitHardwareSupportWarnings(op); + return WalkResult::advance(); + } + + op->emitOpError() + << "requires enclosing scf.for with '" + << kAIVectorScopeAttrName + << "' or dedicated pto.vecscope/pto.strict_vecscope" + << "' because it consumes or produces !pto.vreg/!pto.mask/!pto.align"; + return WalkResult::interrupt(); + }); + return opWalkResult.wasInterrupted() ? failure() : success(); + } + + LogicalResult validateEmissionFunctionSurface() { + for (func::FuncOp func : helper.getFunctions()) { + FunctionType functionType = func.getFunctionType(); + + for (auto [idx, inputType] : llvm::enumerate(functionType.getInputs())) { + if (!isa(inputType)) + continue; + return func.emitError() + << "emission-stage VPTO legality rejects memref argument #" + << idx << ": " << inputType; + } + + for (auto [idx, resultType] : llvm::enumerate(functionType.getResults())) { + if (!isa(resultType)) + continue; + return func.emitError() + << "emission-stage VPTO legality rejects memref result #" + << idx << ": " << resultType; + } + } + return success(); + } + + LogicalResult validateEmissionOperationSurface() { + WalkResult walkResult = helper.getModule().walk([&](Operation *op) { + VPTOBufferAddressFamily family = + VPTOLegalityHelper::classifyBufferAddressFamily(op); + + if (family == VPTOBufferAddressFamily::BufferLike) { + for (OpOperand *operand : VPTOLegalityHelper::collectBufferOperands(op)) { + Type operandType = operand->get().getType(); + if (!isa(operandType)) + continue; + + op->emitOpError() + << "emission-stage VPTO legality rejects memref-form buffer " + "operand #" + << operand->getOperandNumber() << " of type " << operandType + << " for buffer-like family op"; + return WalkResult::interrupt(); + } + } + + if (VPTOLegalityHelper::isResidualEmissionScaffold(op)) { + op->emitOpError() + << "must be eliminated before emission-stage VPTO validation"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return walkResult.wasInterrupted() ? failure() : success(); + } + + void writeDiagnostic(StringRef message) const { + if (diagOS) + *diagOS << message; + } + + VPTOLegalityHelper helper; + VPTOLegalityStage stage; + llvm::raw_ostream *diagOS; +}; + +} // namespace detail + +namespace { + +struct PTOValidateVPTOIRPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVPTOIRPass) + + StringRef getArgument() const final { return "pto-validate-vpto-ir"; } + + StringRef getDescription() const final { + return "Validate authoring-stage VPTO legality before emission-boundary canonicalization"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(validateVPTOAuthoringIR(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +struct PTOValidateVPTOEmissionIRPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVPTOEmissionIRPass) + + StringRef getArgument() const final { + return "pto-validate-vpto-emission-ir"; + } + + StringRef getDescription() const final { + return "Validate emission-stage VPTO legality after ptr-boundary canonicalization"; + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(validateVPTOEmissionIR(module, &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult validateVPTOAuthoringIR(ModuleOp module, + llvm::raw_ostream *diagOS) { + return detail::VPTOLegalityValidator( + module, detail::VPTOLegalityStage::Authoring, diagOS) + .validate(); +} + +LogicalResult validateVPTOEmissionIR(ModuleOp module, + llvm::raw_ostream *diagOS) { + return detail::VPTOLegalityValidator( + module, detail::VPTOLegalityStage::Emission, diagOS) + .validate(); +} + +std::unique_ptr createPTOValidateVPTOIRPass() { + return std::make_unique(); +} + +std::unique_ptr createPTOValidateVPTOEmissionIRPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index c21669b81..8e2b83a42 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" namespace mlir { namespace pto { @@ -576,7 +577,7 @@ static Type convertPTOTypeToMemRef(Type t) { // 1. 处理 !pto.ptr if (auto pty = dyn_cast(t)) { return MemRefType::get({ShapedType::kDynamic}, pty.getElementType(), - MemRefLayoutAttrInterface(), Attribute()); + MemRefLayoutAttrInterface(), pty.getMemorySpace()); } // 2. 处理 !pto.tile_buf<...> @@ -1391,6 +1392,41 @@ struct PTOViewToMemrefPass return; } + // Stage 0.40 Insert pto.bind_tile for function args that were tile_buf. + // ------------------------------------------------------------------ + // Later materialization and intrinsic folding use BindTileOp as the + // anchor to recover tile metadata after the Stage-0 type rewrite. + { + IRRewriter rewriter(ctx); + // Insert after existing block args, before any existing ops. + rewriter.setInsertionPointToStart(&entry); + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + Type origTy = fnTy.getInputs()[i]; + auto tbTy = dyn_cast(origTy); + if (!tbTy) + continue; + + auto configAttr = tbTy.getConfigAttr(); + if (!configAttr) configAttr = pto::TileBufConfigAttr::getDefault(ctx); + + Value vRow, vCol; + auto vs = tbTy.getValidShape(); + if (vs.size() == 2) { + if (vs[0] != ShapedType::kDynamic) + vRow = rewriter.create(func.getLoc(), vs[0]); + if (vs[1] != ShapedType::kDynamic) + vCol = rewriter.create(func.getLoc(), vs[1]); + } + + auto bindOp = rewriter.create( + func.getLoc(), newInputs[i], entry.getArgument(i), + vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); + markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); + + entry.getArgument(i).replaceAllUsesExcept(bindOp.getResult(), bindOp); + } + } + // ------------------------------------------------------------------ // Stage 0.5: lower pto.alloc_tile -> memref.alloc + pto.bind_tile // ------------------------------------------------------------------ @@ -1665,7 +1701,50 @@ struct PTOViewToMemrefPass } // ------------------------------------------------------------------ - // Stage 1.5: Fold pto.addptr chains into load/store_scalar. + // Stage 1.5: Lower pto.get_tensor_view_stride -> strided memref metadata + // ------------------------------------------------------------------ + SmallVector tvStrides; + func.walk([&](mlir::pto::GetTensorViewStrideOp op) { tvStrides.push_back(op); }); + + for (auto op : tvStrides) { + IRRewriter rewriter(ctx); + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + + Value view = op.getTensorView(); + auto mrTy = dyn_cast(view.getType()); + if (!mrTy) + continue; // leave it to later passes if it hasn't been lowered yet + + int64_t dimIndex = 0; + if (!getConstIndexValue(op.getDimIndex(), dimIndex)) { + op.emitError("get_tensor_view_stride currently expects a constant dim index"); + signalPassFailure(); + return; + } + if (dimIndex < 0 || dimIndex >= mrTy.getRank()) { + op.emitError("get_tensor_view_stride dim index is out of bounds"); + signalPassFailure(); + return; + } + + SmallVector staticStrides; + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(mrTy, staticStrides, offset)) && + dimIndex < (int64_t)staticStrides.size() && + staticStrides[dimIndex] != ShapedType::kDynamic) { + rewriter.replaceOpWithNewOp( + op, staticStrides[dimIndex]); + continue; + } + + auto metadata = + rewriter.create(loc, view); + rewriter.replaceOp(op, metadata.getStrides()[dimIndex]); + } + + // ------------------------------------------------------------------ + // Stage 1.6: Fold pto.addptr chains into load/store_scalar. // ------------------------------------------------------------------ DefaultInlineVector loadScalars; func.walk([&](mlir::pto::LoadScalarOp op) { loadScalars.push_back(op); }); @@ -1863,8 +1942,11 @@ struct PTOViewToMemrefPass for (auto op : exp) { IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); + auto attrs = op->getAttrs(); + auto newOp = rewriter.create( + op.getLoc(), TypeRange{}, op->getOperand(0), op->getOperand(1)); + newOp->setAttrs(attrs); + rewriter.replaceOp(op, newOp->getResults()); } // --- TMulOp [Src, Scalar, Dst] --- @@ -2689,12 +2771,11 @@ struct PTOViewToMemrefPass return; } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src0, - src1, - dst); + auto attrs = op->getAttrs(); + auto newOp = rewriter.create( + op.getLoc(), TypeRange{}, src0, src1, dst); + newOp->setAttrs(attrs); + rewriter.replaceOp(op, newOp->getResults()); } DefaultInlineVector divsops; @@ -2744,12 +2825,15 @@ struct PTOViewToMemrefPass signalPassFailure(); return; } - rewriter.replaceOpWithNewOp( - op, + auto attrs = op->getAttrs(); + auto newOp = rewriter.create( + op.getLoc(), TypeRange{}, src, scale, dst); + newOp->setAttrs(attrs); + rewriter.replaceOp(op, newOp->getResults()); } DefaultInlineVector expandsops; @@ -3055,11 +3139,11 @@ struct PTOViewToMemrefPass return; } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - dst); + auto attrs = op->getAttrs(); + auto newOp = rewriter.create( + op.getLoc(), TypeRange{}, src, dst); + newOp->setAttrs(attrs); + rewriter.replaceOp(op, newOp->getResults()); } DefaultInlineVector lreluops; diff --git a/lib/PTO/Transforms/VPTOBufferMaterialization.cpp b/lib/PTO/Transforms/VPTOBufferMaterialization.cpp new file mode 100644 index 000000000..a12fc8771 --- /dev/null +++ b/lib/PTO/Transforms/VPTOBufferMaterialization.cpp @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLowering.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::pto { +namespace { + +static AddressSpaceAttr getNormalizedPtrMemorySpace(Attribute memorySpace, + MLIRContext *context) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return AddressSpaceAttr::get(context, + static_cast(intAttr.getInt())); + return AddressSpaceAttr::get(context, AddressSpace::GM); +} + +static Value materializeMemRefView(Value value, ArrayRef shape, + Type elementType, Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + auto memrefType = + MemRefType::get(shape, elementType, AffineMap(), memorySpace); + if (value.getType() == memrefType) + return value; + return rewriter + .create( + loc, TypeRange(ArrayRef{memrefType}), value) + .getResult(0); +} + +static Value materializeTileBufferView(Value value, PatternRewriter &rewriter, + Location loc) { + if (isa(value.getType())) + return value; + + auto tileType = dyn_cast(value.getType()); + if (!tileType) + return {}; + + return materializeMemRefView(value, tileType.getShape(), + tileType.getElementType(), + tileType.getMemorySpace(), rewriter, loc); +} + +} // namespace + +Value materializeBufferPointer(Value value, Type elementType, + Attribute memorySpace, + PatternRewriter &rewriter, Location loc) { + if (!value) + return {}; + + auto ptrMemorySpace = + getNormalizedPtrMemorySpace(memorySpace, rewriter.getContext()); + auto ptrType = PtrType::get(rewriter.getContext(), elementType, ptrMemorySpace); + + if (value.getType() == ptrType) + return value; + + if (auto bind = value.getDefiningOp()) + return materializeBufferPointer(bind.getSource(), elementType, memorySpace, + rewriter, loc); + + if (auto cast = value.getDefiningOp()) { + if (cast.getAddrs().empty()) + return {}; + return rewriter.create(loc, ptrType, cast.getAddrs().front()) + .getResult(); + } + + Value memrefValue = materializeTileBufferView(value, rewriter, loc); + auto memrefType = dyn_cast_or_null(memrefValue.getType()); + if (!memrefValue || !memrefType) + return {}; + return rewriter.create(loc, ptrType, memrefValue).getResult(); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp new file mode 100644 index 000000000..e4871a630 --- /dev/null +++ b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp @@ -0,0 +1,1674 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTOEXPANDWRAPPEROPS +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static pto::AddressSpaceAttr getPointerMemorySpace(Attribute memorySpace, + MLIRContext *ctx) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return pto::AddressSpaceAttr::get( + ctx, static_cast(intAttr.getInt())); + return pto::AddressSpaceAttr::get(ctx, pto::AddressSpace::GM); +} + +static Value materializeBufferPointer(Value value, PatternRewriter &rewriter, + Location loc) { + if (!value) + return {}; + + if (isa(value.getType())) + return value; + + auto memrefType = dyn_cast(value.getType()); + if (!memrefType) + return {}; + + auto ptrType = + pto::PtrType::get(rewriter.getContext(), memrefType.getElementType(), + getPointerMemorySpace(memrefType.getMemorySpace(), + rewriter.getContext())); + return rewriter.create(loc, ptrType, value).getResult(); +} + +static Type getBufferElementType(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; +} + +static Value offsetBufferPointer(Value basePtr, Type elementType, + Value elementOffset, + PatternRewriter &rewriter, Location loc) { + if (!basePtr) + return {}; + + Value offsetIndex = elementOffset; + if (!offsetIndex.getType().isIndex()) + offsetIndex = rewriter.create(loc, + rewriter.getIndexType(), + elementOffset); + return rewriter.create(loc, basePtr.getType(), basePtr, + offsetIndex); +} + +static bool isKnownOne(Value value) { + APInt intValue; + return value && matchPattern(value, m_ConstantInt(&intValue)) && + intValue.isOne(); +} + +static bool shouldRestoreDmaLoopSize(Value loop1Count, Value loop2Count) { + if (!loop1Count) + return false; + return !isKnownOne(loop1Count) || !isKnownOne(loop2Count); +} + +static SmallVector collectLoopConfigs(ValueRange counts, + ValueRange srcStrides, + ValueRange dstStrides) { + SmallVector loops; + loops.reserve(counts.size()); + for (auto [count, srcStride, dstStride] : + llvm::zip(counts, srcStrides, dstStrides)) + loops.push_back({count, srcStride, dstStride}); + return loops; +} + +static Value offsetPointerByBytes(Value basePtr, Value byteOffset, + PatternRewriter &rewriter, Location loc) { + if (!basePtr) + return {}; + + Value basePtrValue = materializeBufferPointer(basePtr, rewriter, loc); + auto ptrType = dyn_cast_or_null(basePtrValue.getType()); + if (!ptrType) + return {}; + + APInt constOffset; + if (matchPattern(byteOffset, m_ConstantInt(&constOffset)) && constOffset.isZero()) + return basePtrValue; + + auto bytePtrType = + pto::PtrType::get(rewriter.getContext(), rewriter.getI8Type(), + ptrType.getMemorySpace()); + Value bytePtr = + rewriter.create(loc, bytePtrType, basePtrValue); + Value offsetIndex = byteOffset; + if (!offsetIndex.getType().isIndex()) + offsetIndex = + rewriter.create(loc, rewriter.getIndexType(), + offsetIndex); + Value advanced = + rewriter.create(loc, bytePtrType, bytePtr, offsetIndex); + return rewriter.create(loc, ptrType, advanced); +} + +[[maybe_unused]] static Value materializeFpcValue(Value fpc, + PatternRewriter &rewriter, + Location loc) { + if (!fpc) + return {}; + if (fpc.getType().isInteger(64)) + return fpc; + if (isa(fpc.getType())) + return rewriter.create(loc, rewriter.getI64Type(), fpc); + return {}; +} + +static Value materializeI64Value(Value value, PatternRewriter &rewriter, + Location loc) { + if (!value) + return {}; + if (value.getType().isInteger(64)) + return value; + if (auto intType = dyn_cast(value.getType())) + return rewriter.create(loc, rewriter.getI64Type(), value); + if (isa(value.getType())) + return rewriter.create(loc, rewriter.getI64Type(), value); + return {}; +} + +static Value materializeAccStoreScalarPayload(Value value, + PatternRewriter &rewriter, + Location loc) { + if (!value) + return {}; + if (Value raw = materializeI64Value(value, rewriter, loc)) + return raw; + + Type type = value.getType(); + Value f32Value = value; + if (type.isF16() || type.isBF16()) { + f32Value = rewriter.create(loc, rewriter.getF32Type(), value); + } else if (!type.isF32()) { + return {}; + } + + Value bitsI32 = rewriter.create(loc, rewriter.getI32Type(), f32Value); + return rewriter.create(loc, rewriter.getI64Type(), bitsI32); +} + +static Value materializeAccStoreClipPayload(Value value, Type destinationElementType, + PatternRewriter &rewriter, + Location loc) { + if (!value) + return {}; + + if (value.getType().isF16()) { + Value bitsI16 = + rewriter.create(loc, rewriter.getI16Type(), value); + return rewriter.create(loc, rewriter.getI64Type(), bitsI16); + } + + auto intType = dyn_cast(value.getType()); + if (!intType) + return {}; + + Value widened; + if (auto dstIntType = dyn_cast(destinationElementType); + dstIntType && dstIntType.isUnsignedInteger(8)) { + widened = rewriter.create(loc, rewriter.getI64Type(), value); + } else { + widened = rewriter.create(loc, rewriter.getI64Type(), value); + } + + Value mask = rewriter.create(loc, 0xFFFF, 64); + return rewriter.create(loc, widened, mask); +} + +static Value getI64Constant(Location loc, PatternRewriter &rewriter, + uint64_t value) { + return rewriter.create(loc, value, 64); +} + +static Value buildAccStoreOptionalEnumValue(Location loc, + std::optional value, + PatternRewriter &rewriter) { + return getI64Constant(loc, rewriter, value.value_or(0)); +} + +static Value buildAccStoreFpcValue(Location loc, Value preQuant, + std::optional preQuantMode, + Value preRelu, + std::optional preReluMode, + PatternRewriter &rewriter) { + auto encodeFixpipeBufferAddr = [&](Value addr, uint64_t unitShift) -> Value { + Value segmentMask = getI64Constant(loc, rewriter, 0xffff); + Value fieldMask = getI64Constant(loc, rewriter, 0xff); + Value segmentOffset = rewriter.create(loc, addr, segmentMask); + Value scaledAddr = rewriter.create( + loc, segmentOffset, getI64Constant(loc, rewriter, unitShift)); + return rewriter.create(loc, scaledAddr, fieldMask); + }; + + Value quantAddr; + if (preQuantMode) { + switch (*preQuantMode) { + case pto::AccStoreQuantPreMode::QF322HIF8PreVec: + case pto::AccStoreQuantPreMode::QF322HIF8PreHybridVec: + case pto::AccStoreQuantPreMode::DEQS32IntVec: + case pto::AccStoreQuantPreMode::REQ8Vec: + case pto::AccStoreQuantPreMode::DEQF16Vec: + case pto::AccStoreQuantPreMode::QF322FP8PreVec: + case pto::AccStoreQuantPreMode::QF322F32PreVec: + case pto::AccStoreQuantPreMode::QF162B8PreVec: + case pto::AccStoreQuantPreMode::QF162S4PreVec: + case pto::AccStoreQuantPreMode::REQ4Vec: + case pto::AccStoreQuantPreMode::QF322B8PreVec: + case pto::AccStoreQuantPreMode::QF322S4PreVec: + case pto::AccStoreQuantPreMode::DEQS16Vec: + case pto::AccStoreQuantPreMode::QF162S16PreVec: + case pto::AccStoreQuantPreMode::QF322F16PreVec: + case pto::AccStoreQuantPreMode::QF322BF16PreVec: + case pto::AccStoreQuantPreMode::QS322BF16PreVec: + if (Value quantPtr = materializeI64Value(preQuant, rewriter, loc)) + quantAddr = encodeFixpipeBufferAddr(quantPtr, /*unitShift=*/7); + break; + default: + break; + } + } + + Value reluAddr; + if (preReluMode && *preReluMode == pto::ReluPreMode::VectorRelu) { + if (Value reluPtr = materializeI64Value(preRelu, rewriter, loc)) + reluAddr = encodeFixpipeBufferAddr(reluPtr, /*unitShift=*/6); + } + + if (!quantAddr && !reluAddr) + return {}; + + Value mask = getI64Constant(loc, rewriter, 0xff); + Value fpc = getI64Constant(loc, rewriter, 0); + if (quantAddr) { + Value quantShift = getI64Constant(loc, rewriter, 8); + Value quantBits = rewriter.create(loc, quantAddr, quantShift); + fpc = rewriter.create(loc, fpc, quantBits); + } + if (reluAddr) { + Value reluBits = rewriter.create(loc, reluAddr, mask); + fpc = rewriter.create(loc, fpc, reluBits); + } + return fpc; +} + +static void configureAccStoreScalarPreOps(Location loc, Value preQuant, + std::optional preQuantMode, + Value preRelu, + std::optional preReluMode, + Value clipValue, + Type destinationElementType, + PatternRewriter &rewriter) { + auto isVectorQuantMode = [](pto::AccStoreQuantPreMode mode) { + switch (mode) { + case pto::AccStoreQuantPreMode::QF322HIF8PreVec: + case pto::AccStoreQuantPreMode::QF322HIF8PreHybridVec: + case pto::AccStoreQuantPreMode::DEQS32IntVec: + case pto::AccStoreQuantPreMode::REQ8Vec: + case pto::AccStoreQuantPreMode::DEQF16Vec: + case pto::AccStoreQuantPreMode::QF322FP8PreVec: + case pto::AccStoreQuantPreMode::QF322F32PreVec: + case pto::AccStoreQuantPreMode::QF162B8PreVec: + case pto::AccStoreQuantPreMode::QF162S4PreVec: + case pto::AccStoreQuantPreMode::REQ4Vec: + case pto::AccStoreQuantPreMode::QF322B8PreVec: + case pto::AccStoreQuantPreMode::QF322S4PreVec: + case pto::AccStoreQuantPreMode::DEQS16Vec: + case pto::AccStoreQuantPreMode::QF162S16PreVec: + case pto::AccStoreQuantPreMode::QF322F16PreVec: + case pto::AccStoreQuantPreMode::QF322BF16PreVec: + case pto::AccStoreQuantPreMode::QS322BF16PreVec: + return true; + default: + return false; + } + }; + + if (preQuantMode && !isVectorQuantMode(*preQuantMode)) { + if (Value quantValue = materializeAccStoreScalarPayload(preQuant, rewriter, loc)) + rewriter.create(loc, quantValue); + } + if (preReluMode && *preReluMode == pto::ReluPreMode::ScalarRelu) { + if (Value reluAlpha = materializeAccStoreScalarPayload(preRelu, rewriter, loc)) + rewriter.create(loc, reluAlpha); + } + if (clipValue) { + if (Value clip = materializeAccStoreClipPayload(clipValue, + destinationElementType, + rewriter, loc)) + rewriter.create(loc, clip); + } +} + +static Value configureAccStoreCtrl(Location loc, bool allowAtomic, + std::optional atomicType, + std::optional atomicOp, + std::optional satMode, + PatternRewriter &rewriter) { + if ((!allowAtomic || !atomicType || !atomicOp) && !satMode) + return {}; + + Value originalCtrl = rewriter.create(loc); + Value ctrl = originalCtrl; + uint64_t clearMaskValue = 0; + if (allowAtomic && atomicType && atomicOp) + clearMaskValue |= (static_cast(0x7) << 6) | + (static_cast(0x3) << 9); + if (satMode) + clearMaskValue |= (static_cast(1) << 48) | + (static_cast(1) << 50); + Value clearMask = getI64Constant(loc, rewriter, clearMaskValue); + Value fullMask = getI64Constant(loc, rewriter, ~static_cast(0)); + Value keepMask = rewriter.create(loc, clearMask, fullMask); + ctrl = rewriter.create(loc, ctrl, keepMask); + + if (allowAtomic && atomicType && atomicOp) { + uint64_t atomicBits = (static_cast(static_cast(*atomicType)) << 6) | + (static_cast(static_cast(*atomicOp)) << 9); + ctrl = rewriter.create(loc, ctrl, + getI64Constant(loc, rewriter, atomicBits)); + } + if (satMode && *satMode == pto::AccStoreSatMode::NoSat) { + ctrl = rewriter.create( + loc, ctrl, getI64Constant(loc, rewriter, + static_cast(1) << 48)); + } + if (satMode && *satMode == pto::AccStoreSatMode::SatPreserveNan) { + ctrl = rewriter.create( + loc, ctrl, getI64Constant(loc, rewriter, + static_cast(1) << 50)); + } + rewriter.create(loc, ctrl); + return originalCtrl; +} + +static Value buildAccumulatedByteOffset(Location loc, Value baseOffset, + Value indexI64, Value stride, + PatternRewriter &rewriter) { + Value delta = rewriter.create(loc, indexI64, stride); + return rewriter.create(loc, baseOffset, delta); +} + +static Value packLoopPair(Location loc, Value low, Value high, + PatternRewriter &rewriter) { + Value shift = rewriter.create(loc, 40, 64); + Value highShifted = rewriter.create(loc, high, shift); + return rewriter.create(loc, highShifted, low); +} + +static Value packLoopSize(Location loc, Value loop2, Value loop1, + PatternRewriter &rewriter) { + Value shift = rewriter.create(loc, 21, 64); + Value loop2Shifted = rewriter.create(loc, loop2, shift); + return rewriter.create(loc, loop2Shifted, loop1); +} + +static Value castIntegerLikeTo(Location loc, Value value, Type targetType, + PatternRewriter &rewriter) { + if (value.getType() == targetType) + return value; + + auto targetInt = dyn_cast(targetType); + if (value.getType().isIndex() && targetInt) + return rewriter.create(loc, targetType, value); + if (auto sourceInt = dyn_cast(value.getType())) { + if (targetInt) { + if (sourceInt.getWidth() < targetInt.getWidth()) + return rewriter.create(loc, targetType, value); + if (sourceInt.getWidth() > targetInt.getWidth()) + return rewriter.create(loc, targetType, value); + return value; + } + if (targetType.isIndex()) + return rewriter.create(loc, targetType, value); + } + + return {}; +} + +static FailureOr packMadXt(Location loc, Value m, Value n, Value k, + std::optional unitFlagMode, + bool disableGemv, bool cmatrixSource, + bool cmatrixInit, + PatternRewriter &rewriter) { + Type i64Ty = rewriter.getI64Type(); + Value mI64 = castIntegerLikeTo(loc, m, i64Ty, rewriter); + Value nI64 = castIntegerLikeTo(loc, n, i64Ty, rewriter); + Value kI64 = castIntegerLikeTo(loc, k, i64Ty, rewriter); + if (!mI64 || !nI64 || !kI64) + return failure(); + + auto constant = [&](uint64_t value) -> Value { + return rewriter.create(loc, value, 64); + }; + auto shl = [&](Value value, uint64_t amount) -> Value { + return rewriter.create(loc, value, constant(amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return rewriter.create(loc, lhs, rhs); + }; + + Value xt = mI64; + xt = bitOr(xt, shl(kI64, 12)); + xt = bitOr(xt, shl(nI64, 24)); + if (unitFlagMode) { + uint64_t unitFlagCtrl = + *unitFlagMode == pto::MadUnitFlagMode::CheckOnly ? 2 : 3; + xt = bitOr(xt, shl(constant(unitFlagCtrl), 55)); + } + if (disableGemv) + xt = bitOr(xt, shl(constant(1), 61)); + if (cmatrixSource) + xt = bitOr(xt, shl(constant(1), 62)); + if (cmatrixInit) + xt = bitOr(xt, shl(constant(1), 63)); + return xt; +} + +static Value setCtrlBit(Location loc, Value ctrl, unsigned bitIndex, bool value, + PatternRewriter &rewriter) { + Value bit = rewriter.create(loc, bitIndex, 64); + if (value) + return rewriter.create(loc, ctrl, bit).getResult(); + return rewriter.create(loc, ctrl, bit).getResult(); +} + +static Value buildMadSemanticCtrl(Location loc, Value ctrl, + bool isHif8, + std::optional tf32Mode, + std::optional satMode, + bool hasNDir, + PatternRewriter &rewriter) { + ctrl = setCtrlBit(loc, ctrl, 45, isHif8, rewriter); + if (tf32Mode) { + ctrl = setCtrlBit(loc, ctrl, 46, true, rewriter); + ctrl = setCtrlBit(loc, ctrl, 47, + *tf32Mode == pto::Tf32Mode::RoundAway, rewriter); + } else { + ctrl = setCtrlBit(loc, ctrl, 46, false, rewriter); + ctrl = setCtrlBit(loc, ctrl, 47, false, rewriter); + } + if (satMode) + ctrl = setCtrlBit(loc, ctrl, 48, *satMode == pto::MadSatMode::NoSat, + rewriter); + ctrl = setCtrlBit(loc, ctrl, 51, hasNDir, rewriter); + return ctrl; +} + +static Value packMte2NzPara(Location loc, Value groupCount, Value dstLoop2Stride, + Value dstLoop3Stride, Value dstLoop4Stride, + PatternRewriter &rewriter) { + Value shift16 = rewriter.create(loc, 16, 64); + Value shift32 = rewriter.create(loc, 32, 64); + Value shift48 = rewriter.create(loc, 48, 64); + Value loop2Bits = + rewriter.create(loc, dstLoop2Stride, shift16); + Value loop3Bits = + rewriter.create(loc, dstLoop3Stride, shift32); + Value loop4Bits = + rewriter.create(loc, dstLoop4Stride, shift48); + Value low = rewriter.create(loc, groupCount, loop2Bits); + Value high = rewriter.create(loc, loop3Bits, loop4Bits); + return rewriter.create(loc, low, high); +} + +static Value packCopyMatrixCcToGmXm(Location loc, Value sid, Value nSize, + Value mSize, Value dstStride, + PatternRewriter &rewriter) { + Value nShift4 = rewriter.create(loc, 4, 64); + Value mShift16 = rewriter.create(loc, 16, 64); + Value dstShift32 = rewriter.create(loc, 32, 64); + Value nBits = rewriter.create(loc, nSize, nShift4); + Value mBits = rewriter.create(loc, mSize, mShift16); + Value dstStrideBits = rewriter.create(loc, dstStride, dstShift32); + Value sidMask = rewriter.create(loc, 0xf, 64); + Value sidBits = rewriter.create(loc, sid, sidMask); + Value xmLow = rewriter.create(loc, sidBits, nBits); + xmLow = rewriter.create(loc, xmLow, mBits); + return rewriter.create(loc, xmLow, dstStrideBits); +} + +static Value packCopyMatrixCcToGmXt(Location loc, Value srcStride, + Value clipReluPre, Value unitFlagCtrl, + Value quantPre, Value reluPreMode, + Value l2CacheCtrl, + Value nz2ndEn, Value channelSplitEn, + Value nz2dnEn, + PatternRewriter &rewriter) { + Value l2CacheShift16 = rewriter.create(loc, 16, 64); + Value clipReluShift30 = rewriter.create(loc, 30, 64); + Value unitFlagShift32 = rewriter.create(loc, 32, 64); + Value quantBlockBitShift29 = + rewriter.create(loc, 29, 64); + Value quantFieldShift34 = rewriter.create(loc, 34, 64); + Value reluShift39 = rewriter.create(loc, 39, 64); + Value channelSplitShift42 = + rewriter.create(loc, 42, 64); + Value nz2ndShift43 = rewriter.create(loc, 43, 64); + Value nz2dnShift62 = rewriter.create(loc, 62, 64); + + Value quantShift5 = rewriter.create(loc, 5, 64); + Value quantLowMask = rewriter.create(loc, 0x1f, 64); + Value quantBitMask = rewriter.create(loc, 0x1, 64); + Value clipReluMask = rewriter.create(loc, 0x3, 64); + Value l2CacheMask = rewriter.create(loc, 0xf, 64); + Value unitFlagMask = rewriter.create(loc, 0x3, 64); + Value reluMask = rewriter.create(loc, 0x7, 64); + + Value l2CacheBits = rewriter.create(loc, l2CacheCtrl, l2CacheMask); + l2CacheBits = + rewriter.create(loc, l2CacheBits, l2CacheShift16); + + Value clipReluBits = + rewriter.create(loc, clipReluPre, clipReluMask); + clipReluBits = + rewriter.create(loc, clipReluBits, clipReluShift30); + + Value unitFlagBits = rewriter.create(loc, unitFlagCtrl, unitFlagMask); + unitFlagBits = + rewriter.create(loc, unitFlagBits, unitFlagShift32); + + Value quantBlockBit = rewriter.create(loc, quantPre, quantShift5); + quantBlockBit = + rewriter.create(loc, quantBlockBit, quantBitMask); + quantBlockBit = rewriter.create(loc, quantBlockBit, + quantBlockBitShift29); + + Value quantField = rewriter.create(loc, quantPre, quantLowMask); + quantField = + rewriter.create(loc, quantField, quantFieldShift34); + + Value reluBits = rewriter.create(loc, reluPreMode, reluMask); + reluBits = rewriter.create(loc, reluBits, reluShift39); + + Value channelSplitBits = + rewriter.create(loc, channelSplitEn, quantBitMask); + channelSplitBits = rewriter.create(loc, channelSplitBits, + channelSplitShift42); + + Value nz2ndBits = rewriter.create(loc, nz2ndEn, quantBitMask); + nz2ndBits = + rewriter.create(loc, nz2ndBits, nz2ndShift43); + + Value nz2dnBits = rewriter.create(loc, nz2dnEn, quantBitMask); + nz2dnBits = + rewriter.create(loc, nz2dnBits, nz2dnShift62); + + Value xt = rewriter.create(loc, srcStride, l2CacheBits); + xt = rewriter.create(loc, xt, clipReluBits); + xt = rewriter.create(loc, xt, unitFlagBits); + xt = rewriter.create(loc, xt, quantBlockBit); + xt = rewriter.create(loc, xt, quantField); + xt = rewriter.create(loc, xt, reluBits); + xt = rewriter.create(loc, xt, channelSplitBits); + xt = rewriter.create(loc, xt, nz2ndBits); + return rewriter.create(loc, xt, nz2dnBits); +} + +static Value packCopyMatrixCcToUbConfig1(Location loc, Value srcStride, + Value dualDstMode, Value subBlockId, + Value clipReluPre, Value unitFlagCtrl, + Value quantPre, Value reluPreMode, + Value nz2ndEn, Value channelSplitEn, + Value nz2dnEn, + PatternRewriter &rewriter) { + Value dualDstShift16 = rewriter.create(loc, 16, 64); + Value subBlockShift18 = rewriter.create(loc, 18, 64); + Value clipReluShift30 = rewriter.create(loc, 30, 64); + Value unitFlagShift32 = rewriter.create(loc, 32, 64); + Value quantBlockBitShift29 = + rewriter.create(loc, 29, 64); + Value quantFieldShift34 = rewriter.create(loc, 34, 64); + Value reluShift39 = rewriter.create(loc, 39, 64); + Value channelSplitShift42 = + rewriter.create(loc, 42, 64); + Value nz2ndShift43 = rewriter.create(loc, 43, 64); + Value nz2dnShift62 = rewriter.create(loc, 62, 64); + + Value dualDstMask = rewriter.create(loc, 0x3, 64); + Value subBlockMask = rewriter.create(loc, 0x1, 64); + Value quantShift5 = rewriter.create(loc, 5, 64); + Value quantLowMask = rewriter.create(loc, 0x1f, 64); + Value quantBitMask = rewriter.create(loc, 0x1, 64); + Value clipReluMask = rewriter.create(loc, 0x3, 64); + Value unitFlagMask = rewriter.create(loc, 0x3, 64); + Value reluMask = rewriter.create(loc, 0x7, 64); + + Value dualDstBits = rewriter.create(loc, dualDstMode, dualDstMask); + dualDstBits = + rewriter.create(loc, dualDstBits, dualDstShift16); + + Value subBlockBits = rewriter.create(loc, subBlockId, subBlockMask); + subBlockBits = + rewriter.create(loc, subBlockBits, subBlockShift18); + + Value clipReluBits = + rewriter.create(loc, clipReluPre, clipReluMask); + clipReluBits = + rewriter.create(loc, clipReluBits, clipReluShift30); + + Value unitFlagBits = rewriter.create(loc, unitFlagCtrl, unitFlagMask); + unitFlagBits = + rewriter.create(loc, unitFlagBits, unitFlagShift32); + + Value quantBlockBit = rewriter.create(loc, quantPre, quantShift5); + quantBlockBit = + rewriter.create(loc, quantBlockBit, quantBitMask); + quantBlockBit = rewriter.create(loc, quantBlockBit, + quantBlockBitShift29); + + Value quantField = rewriter.create(loc, quantPre, quantLowMask); + quantField = + rewriter.create(loc, quantField, quantFieldShift34); + + Value reluBits = rewriter.create(loc, reluPreMode, reluMask); + reluBits = rewriter.create(loc, reluBits, reluShift39); + + Value channelSplitBits = + rewriter.create(loc, channelSplitEn, quantBitMask); + channelSplitBits = rewriter.create(loc, channelSplitBits, + channelSplitShift42); + + Value nz2ndBits = rewriter.create(loc, nz2ndEn, quantBitMask); + nz2ndBits = + rewriter.create(loc, nz2ndBits, nz2ndShift43); + + Value nz2dnBits = rewriter.create(loc, nz2dnEn, quantBitMask); + nz2dnBits = + rewriter.create(loc, nz2dnBits, nz2dnShift62); + + Value config1 = rewriter.create(loc, srcStride, dualDstBits); + config1 = rewriter.create(loc, config1, subBlockBits); + config1 = rewriter.create(loc, config1, clipReluBits); + config1 = rewriter.create(loc, config1, unitFlagBits); + config1 = rewriter.create(loc, config1, quantBlockBit); + config1 = rewriter.create(loc, config1, quantField); + config1 = rewriter.create(loc, config1, reluBits); + config1 = rewriter.create(loc, config1, channelSplitBits); + config1 = rewriter.create(loc, config1, nz2ndBits); + return rewriter.create(loc, config1, nz2dnBits); +} + +static Value packLoop3Config(Location loc, Value count, Value srcStride, + Value dstStride, PatternRewriter &rewriter) { + Value srcShift16 = rewriter.create(loc, 16, 64); + Value dstShift32 = rewriter.create(loc, 32, 64); + Value srcBits = rewriter.create(loc, srcStride, srcShift16); + Value dstBits = rewriter.create(loc, dstStride, dstShift32); + Value low = rewriter.create(loc, count, srcBits); + return rewriter.create(loc, low, dstBits); +} + +static Value packChannelConfig(Location loc, Value loop0SrcStride, + PatternRewriter &rewriter) { + Value shift48 = rewriter.create(loc, 48, 64); + return rewriter.create(loc, loop0SrcStride, shift48); +} + +struct LoadCbufToCbControl { + Value mStart; + Value kStart; + Value mStep; + Value kStep; + Value srcStride; + Value dstStride; +}; + +static FailureOr +deriveLoadCbufToCbControl(Location loc, Value k, Value n, Type elementType, + bool transpose, PatternRewriter &rewriter) { + unsigned elemBitWidth = elementType.getIntOrFloatBitWidth(); + if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + return failure(); + uint64_t elemBytes = elemBitWidth / 8; + + auto constant = [&](uint64_t value) -> Value { + return rewriter.create(loc, value, 64); + }; + auto ceilDivConst = [&](Value value, uint64_t divisor) -> Value { + Value bias = constant(divisor - 1); + Value sum = rewriter.create(loc, value, bias); + return rewriter.create(loc, sum, constant(divisor)); + }; + + Value zero = constant(0); + if (!transpose) { + Value mStep = ceilDivConst(n, 16); + Value kBytes = rewriter.create(loc, k, constant(elemBytes)); + Value kStep = ceilDivConst(kBytes, 32); + Value stride = ceilDivConst(n, 16); + return LoadCbufToCbControl{zero, zero, mStep, kStep, stride, stride}; + } + + uint64_t c0Size = std::max(16, 32 / elemBytes); + Value kAlign = ceilDivConst(k, c0Size); + kAlign = rewriter.create(loc, kAlign, constant(c0Size)); + Value nAlign = ceilDivConst(n, c0Size); + nAlign = rewriter.create(loc, nAlign, constant(c0Size)); + Value mStep = ceilDivConst(kAlign, 16); + Value nBytes = rewriter.create(loc, nAlign, constant(elemBytes)); + Value kStep = ceilDivConst(nBytes, 32); + Value srcStride = ceilDivConst(kAlign, 16); + Value dstStride = ceilDivConst(nAlign, 16); + return LoadCbufToCbControl{zero, zero, mStep, kStep, srcStride, dstStride}; +} + +static FailureOr +deriveLoadCbufToCaControl(Location loc, Value m, Value k, Type elementType, + bool transpose, PatternRewriter &rewriter) { + unsigned elemBitWidth = elementType.getIntOrFloatBitWidth(); + if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + return failure(); + uint64_t elemBytes = elemBitWidth / 8; + + auto constant = [&](uint64_t value) -> Value { + return rewriter.create(loc, value, 64); + }; + auto ceilDivConst = [&](Value value, uint64_t divisor) -> Value { + Value bias = constant(divisor - 1); + Value sum = rewriter.create(loc, value, bias); + return rewriter.create(loc, sum, constant(divisor)); + }; + + Value zero = constant(0); + if (!transpose) { + Value mStep = ceilDivConst(m, 16); + Value kBytes = rewriter.create(loc, k, constant(elemBytes)); + Value kStep = ceilDivConst(kBytes, 32); + Value stride = ceilDivConst(m, 16); + return LoadCbufToCbControl{zero, zero, mStep, kStep, stride, stride}; + } + + uint64_t c0Size = std::max(16, 32 / elemBytes); + Value mAlign = ceilDivConst(m, c0Size); + mAlign = rewriter.create(loc, mAlign, constant(c0Size)); + Value kAlign = ceilDivConst(k, c0Size); + kAlign = rewriter.create(loc, kAlign, constant(c0Size)); + Value mStep = ceilDivConst(kAlign, 16); + Value mBytes = rewriter.create(loc, mAlign, constant(elemBytes)); + Value kStep = ceilDivConst(mBytes, 32); + Value srcStride = ceilDivConst(kAlign, 16); + Value dstStride = ceilDivConst(mAlign, 16); + return LoadCbufToCbControl{zero, zero, mStep, kStep, srcStride, dstStride}; +} + +static Value extractConfigLow40(Location loc, Value packed, + PatternRewriter &rewriter) { + Value lowMask = + rewriter.create(loc, 0xffffffffffULL, 64); + return rewriter.create(loc, packed, lowMask); +} + +static Value extractConfigHigh24(Location loc, Value packed, + PatternRewriter &rewriter) { + Value shift40 = rewriter.create(loc, 40, 64); + return rewriter.create(loc, packed, shift40); +} + +template +static void buildSoftwareLoopNest(PatternRewriter &rewriter, Location loc, + ArrayRef loops, + Value srcOffset, Value dstOffset, + BodyBuilder &&buildLeaf) { + if (loops.empty()) { + buildLeaf(srcOffset, dstOffset); + return; + } + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value count = rewriter.create(loc, rewriter.getIndexType(), + loops.front().count); + scf::ForOp forOp = rewriter.create(loc, c0, count, c1); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(forOp.getBody()); + Value ivI64 = + rewriter.create(loc, rewriter.getI64Type(), + forOp.getInductionVar()); + Value nextSrcOffset = buildAccumulatedByteOffset( + loc, srcOffset, ivI64, loops.front().srcStride, rewriter); + Value nextDstOffset = buildAccumulatedByteOffset( + loc, dstOffset, ivI64, loops.front().dstStride, rewriter); + buildSoftwareLoopNest(rewriter, loc, loops.drop_front(), nextSrcOffset, + nextDstOffset, buildLeaf); + } +} + +struct ExpandUvldPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::UvldOp op, + PatternRewriter &rewriter) const override { + auto vecType = dyn_cast(op.getResult().getType()); + if (!vecType) + return failure(); + + Value basePtr = materializeBufferPointer(op.getSource(), rewriter, op.getLoc()); + if (!basePtr) + return op.emitOpError( + "requires a recoverable pointer base for uvld expansion"); + + Value loadPtr = offsetBufferPointer(basePtr, vecType.getElementType(), + op.getOffset(), rewriter, op.getLoc()); + auto alignType = pto::AlignType::get(rewriter.getContext()); + Value align = + rewriter.create(op.getLoc(), alignType, loadPtr); + auto load = rewriter.create( + op.getLoc(), TypeRange{vecType, alignType}, + ValueRange{loadPtr, align}); + rewriter.replaceOp(op, load.getResult()); + return success(); + } +}; + +enum class MadRawKind { Ordinary, OrdinaryBias, Mx, MxBias }; + +static MadRawKind deriveMadRawKind(pto::MadSemanticOpInterface op) { + if (op.isMadMxFamily()) + return op.hasBiasOperand() ? MadRawKind::MxBias : MadRawKind::Mx; + return op.hasBiasOperand() ? MadRawKind::OrdinaryBias + : MadRawKind::Ordinary; +} + +static LogicalResult emitMadRawOp(pto::MadSemanticOpInterface op, + MadRawKind kind, Value xt, + PatternRewriter &rewriter) { + Location loc = op->getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value dst = op.getDst(); + switch (kind) { + case MadRawKind::Ordinary: + rewriter.create(loc, lhs, rhs, dst, xt); + return success(); + case MadRawKind::OrdinaryBias: + rewriter.create(loc, lhs, rhs, dst, op.getBiasOrNull(), + xt); + return success(); + case MadRawKind::Mx: + rewriter.create(loc, lhs, rhs, dst, xt); + return success(); + case MadRawKind::MxBias: + rewriter.create(loc, lhs, rhs, dst, + op.getBiasOrNull(), xt); + return success(); + } + return failure(); +} + +static LogicalResult lowerMadSemanticOp(pto::MadSemanticOpInterface op, + PatternRewriter &rewriter) { + std::optional unitFlagMode; + if (auto unitFlagModeAttr = + dyn_cast_or_null(op.getUnitFlagModeAttr())) + unitFlagMode = unitFlagModeAttr.getValue(); + + std::optional tf32Mode; + if (op.supportsTf32Mode()) { + if (auto tf32ModeAttr = + dyn_cast_or_null(op.getTf32ModeAttr())) + tf32Mode = tf32ModeAttr.getValue(); + } + + std::optional satMode; + if (auto satModeAttr = + dyn_cast_or_null(op.getSatModeAttr())) + satMode = satModeAttr.getValue(); + + bool isHif8 = false; + if (auto lhsPtr = dyn_cast(op.getLhs().getType())) + isHif8 = pto::isPTOHiFloat8Type(lhsPtr.getElementType()); + + Location loc = op->getLoc(); + Value ctrlSaved = rewriter.create(loc).getResult(); + Value ctrlForOp = buildMadSemanticCtrl(loc, ctrlSaved, isHif8, tf32Mode, + satMode, op.getNDir(), rewriter); + rewriter.create(loc, ctrlForOp); + + FailureOr xt = + packMadXt(loc, op.getM(), op.getN(), op.getK(), unitFlagMode, + op.getDisableGemv(), op.initializesAccumulatorWithBias(), + op.initializesAccumulatorWithZero(), rewriter); + if (failed(xt)) + return rewriter.notifyMatchFailure(op, "failed to pack mad xt"); + + if (failed(emitMadRawOp(op, deriveMadRawKind(op), *xt, rewriter))) + return rewriter.notifyMatchFailure(op, "failed to emit mad raw op"); + + rewriter.create(loc, ctrlSaved); + rewriter.eraseOp(op); + return success(); +} + +template +class ExpandMadSemanticPattern final : public OpRewritePattern { +public: + explicit ExpandMadSemanticPattern(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(SemanticOp op, + PatternRewriter &rewriter) const override { + auto semantic = dyn_cast(op.getOperation()); + if (!semantic) + return failure(); + return lowerMadSemanticOp(semantic, rewriter); + } +}; + +struct ExpandDmaLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteGmUbOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0, 64); + Value one = rewriter.create(loc, 1, 64); + SmallVector loops = + collectLoopConfigs(op.getLoopCounts(), op.getLoopSrcStrides(), + op.getLoopDstStrides()); + ArrayRef hwLoops = ArrayRef(loops).take_front(2); + ArrayRef swLoops = ArrayRef(loops).drop_front(hwLoops.size()); + + Value loop1Count; + Value loop2Size = one; + if (hwLoops.size() == 2) { + rewriter.create( + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); + loop2Size = hwLoops[0].count; + loop1Count = hwLoops[1].count; + rewriter.create( + loc, hwLoops[1].srcStride, hwLoops[1].dstStride); + rewriter.create(loc, loop2Size, loop1Count); + } else if (hwLoops.size() == 1) { + loop1Count = hwLoops[0].count; + rewriter.create( + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); + rewriter.create(loc, loop2Size, loop1Count); + } + + Value leftPadding = op.getLeftPaddingCount(); + if (!leftPadding) + leftPadding = rewriter.create(loc, 0, 64); + Value rightPadding = op.getRightPaddingCount(); + if (!rightPadding) + rightPadding = rewriter.create(loc, 0, 64); + Value dataSelect = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getBoolAttr(static_cast(op.getPadValue()))); + + if (Value padValue = op.getPadValue()) + rewriter.create(loc, padValue); + + buildSoftwareLoopNest( + rewriter, loc, swLoops, zero, zero, + [&](Value srcOffset, Value dstOffset) { + Value source = offsetPointerByBytes(op.getSource(), srcOffset, rewriter, loc); + Value destination = + offsetPointerByBytes(op.getDestination(), dstOffset, rewriter, loc); + rewriter.create( + loc, source, destination, zero, op.getNBurst(), op.getLenBurst(), + leftPadding, rightPadding, dataSelect, op.getL2CacheCtl(), + op.getNburstSrcStride(), op.getNburstDstStride()); + }); + if (shouldRestoreDmaLoopSize(loop1Count, loop2Size)) + rewriter.create(loc, one, one); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandDmaStorePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteUbGmOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0, 64); + Value one = rewriter.create(loc, 1, 64); + SmallVector loops = + collectLoopConfigs(op.getLoopCounts(), op.getLoopSrcStrides(), + op.getLoopDstStrides()); + ArrayRef hwLoops = + ArrayRef(loops).take_front(2); + ArrayRef swLoops = + ArrayRef(loops).drop_front(hwLoops.size()); + + Value loop1Count; + Value loop2Size = one; + if (hwLoops.size() == 2) { + rewriter.create( + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); + loop2Size = hwLoops[0].count; + loop1Count = hwLoops[1].count; + rewriter.create( + loc, hwLoops[1].srcStride, hwLoops[1].dstStride); + rewriter.create(loc, loop2Size, loop1Count); + } else if (hwLoops.size() == 1) { + loop1Count = hwLoops[0].count; + rewriter.create( + loc, hwLoops[0].srcStride, hwLoops[0].dstStride); + rewriter.create(loc, loop2Size, loop1Count); + } + + buildSoftwareLoopNest( + rewriter, loc, swLoops, zero, zero, + [&](Value srcOffset, Value dstOffset) { + Value source = offsetPointerByBytes(op.getSource(), srcOffset, rewriter, loc); + Value destination = + offsetPointerByBytes(op.getDestination(), dstOffset, rewriter, loc); + rewriter.create( + loc, source, destination, zero, op.getNBurst(), op.getLenBurst(), + zero, op.getNburstDstStride(), op.getNburstSrcStride()); + }); + if (shouldRestoreDmaLoopSize(loop1Count, loop2Size)) + rewriter.create(loc, one, one); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandMteUbUbPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteUbUbOp op, + PatternRewriter &rewriter) const override { + Value zero = rewriter.create(op.getLoc(), 0, 64); + rewriter.replaceOpWithNewOp( + op, op.getSource(), op.getDestination(), zero, op.getNBurst(), + op.getLenBurst(), op.getSrcStride(), op.getDstStride()); + return success(); + } +}; + +struct ExpandMteUbL1Pattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteUbL1Op op, + PatternRewriter &rewriter) const override { + Value zero = rewriter.create(op.getLoc(), 0, 64); + rewriter.replaceOpWithNewOp( + op, op.getSource(), op.getDestination(), zero, op.getNBurst(), + op.getLenBurst(), op.getSrcStride(), op.getDstStride()); + return success(); + } +}; + +struct ExpandCubeLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteGmL1Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0, 64); + Value one = rewriter.create(loc, 1, 64); + SmallVector loops = + collectLoopConfigs(op.getLoopCounts(), op.getLoopSrcStrides(), + op.getLoopDstStrides()); + ArrayRef hwLoops = + ArrayRef(loops).take_front(2); + ArrayRef swLoops = + ArrayRef(loops).drop_front(hwLoops.size()); + + Value loop1Count; + Value loop2Count = one; + if (hwLoops.size() == 2) { + rewriter.create( + loc, + packLoopPair(loc, hwLoops[0].srcStride, hwLoops[0].dstStride, + rewriter)); + loop2Count = hwLoops[0].count; + loop1Count = hwLoops[1].count; + rewriter.create( + loc, + packLoopPair(loc, hwLoops[1].srcStride, hwLoops[1].dstStride, + rewriter)); + rewriter.create( + loc, packLoopSize(loc, loop2Count, loop1Count, rewriter)); + } else if (hwLoops.size() == 1) { + loop1Count = hwLoops[0].count; + rewriter.create( + loc, + packLoopPair(loc, hwLoops[0].srcStride, hwLoops[0].dstStride, + rewriter)); + rewriter.create( + loc, packLoopSize(loc, loop2Count, loop1Count, rewriter)); + } + + SmallVector swLoopNestOrder(swLoops.rbegin(), + swLoops.rend()); + buildSoftwareLoopNest( + rewriter, loc, swLoopNestOrder, zero, zero, + [&](Value srcOffset, Value dstOffset) { + Value source = + offsetPointerByBytes(op.getSource(), srcOffset, rewriter, loc); + Value destination = offsetPointerByBytes(op.getDestination(), dstOffset, + rewriter, loc); + rewriter.create( + loc, source, destination, op.getNBurst(), op.getLenBurst(), + op.getNburstSrcStride(), op.getNburstDstStride()); + }); + if (loop1Count && (!isKnownOne(loop1Count) || !isKnownOne(loop2Count))) + rewriter.create( + loc, packLoopSize(loc, one, one, rewriter)); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandCubeStorePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL1UbOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0, 64); + SmallVector loops = + collectLoopConfigs(op.getLoopCounts(), op.getLoopSrcStrides(), + op.getLoopDstStrides()); + SmallVector swLoopNestOrder(loops.rbegin(), + loops.rend()); + buildSoftwareLoopNest( + rewriter, loc, swLoopNestOrder, zero, zero, + [&](Value srcOffset, Value dstOffset) { + Value source = + offsetPointerByBytes(op.getSource(), srcOffset, rewriter, loc); + Value destination = + offsetPointerByBytes(op.getDestination(), dstOffset, rewriter, loc); + rewriter.create( + loc, source, destination, zero, op.getNBurst(), op.getLenBurst(), + op.getNburstSrcStride(), op.getNburstDstStride()); + }); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandBiasLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL1BtOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto sourceType = dyn_cast( + materializeBufferPointer(op.getSource(), rewriter, loc).getType()); + if (!sourceType) + return rewriter.notifyMatchFailure(op, "expected pointer-like source"); + + Value convControl = rewriter.create( + loc, sourceType.getElementType().isF16() ? 1 : 0, 1); + rewriter.replaceOpWithNewOp( + op, op.getSource(), op.getDestination(), convControl, op.getNBurst(), + op.getLenBurst(), op.getNburstSrcGap(), op.getNburstDstGap()); + return success(); + } +}; + +struct ExpandFpLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL1FbOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value source = materializeBufferPointer(op.getSource(), rewriter, loc); + Value destination = + materializeBufferPointer(op.getDestination(), rewriter, loc); + if (!source || !destination) + return rewriter.notifyMatchFailure(op, "expected pointer-like operands"); + + rewriter.replaceOpWithNewOp( + op, source, destination, op.getNBurst(), + op.getLenBurst(), op.getNburstSrcGap(), op.getNburstDstGap()); + return success(); + } +}; + +struct ExpandCubeLoadFracPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteGmL1FracOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0, 64); + Value mte2NzPara = packMte2NzPara( + loc, op.getGroupCount(), op.getDstLoop2Stride(), op.getDstLoop3Stride(), + op.getDstLoop4Stride(), rewriter); + rewriter.create(loc, mte2NzPara); + + Value srcOuterStride = op.getSrcOuterStride() ? op.getSrcOuterStride() : zero; + Value source = materializeBufferPointer(op.getSource(), rewriter, loc); + Value destination = + materializeBufferPointer(op.getDestination(), rewriter, loc); + switch (op.getMode()) { + case pto::CubeLoadFracMode::Nd2nz: + rewriter.create( + loc, source, destination, zero, op.getSrcInnerStride(), + op.getL2CacheCtrl(), op.getNValue(), op.getDValue(), srcOuterStride, + op.getSmallc0En()); + break; + case pto::CubeLoadFracMode::Dn2nz: + rewriter.create( + loc, source, destination, zero, op.getSrcInnerStride(), + op.getL2CacheCtrl(), op.getNValue(), op.getDValue(), srcOuterStride, + op.getSmallc0En()); + break; + } + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandLeftLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL1L0aOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure(op, "expected typed L1 source"); + FailureOr control = deriveLoadCbufToCaControl( + loc, op.getM(), op.getK(), sourceType.getElementType(), + op.getTranspose(), rewriter); + if (failed(control)) + return rewriter.notifyMatchFailure(op, + "failed to derive load_cbuf_to_ca control"); + auto load = rewriter.create( + loc, op.getSource(), op.getDestination(), control->mStart, + control->kStart, control->mStep, control->kStep, control->srcStride, + control->dstStride); + load->setAttr("transpose", rewriter.getBoolAttr(op.getTranspose())); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandRightLoadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL1L0bOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure(op, "expected typed L1 source"); + FailureOr control = deriveLoadCbufToCbControl( + loc, op.getK(), op.getN(), sourceType.getElementType(), + op.getTranspose(), rewriter); + if (failed(control)) + return rewriter.notifyMatchFailure(op, + "failed to derive load_cbuf_to_cb control"); + auto load = rewriter.create( + loc, op.getSource(), op.getDestination(), control->mStart, + control->kStart, control->mStep, control->kStep, control->srcStride, + control->dstStride); + load->setAttr("transpose", rewriter.getBoolAttr(op.getTranspose())); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandLeftLoadMxPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL1L0aMxOp op, + PatternRewriter &rewriter) const override { + rewriter.create(op.getLoc(), op.getSource(), + op.getDestination(), op.getM(), + op.getK()); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandRightLoadMxPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL1L0bMxOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure(op, "expected typed L1 source"); + + unsigned elemBitWidth = sourceType.getElementType().getIntOrFloatBitWidth(); + if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + return rewriter.notifyMatchFailure(op, "unsupported element type"); + uint64_t elemBytes = elemBitWidth / 8; + + auto constant = [&](uint64_t value) -> Value { + return rewriter.create(loc, value, 64); + }; + auto ceilDivConst = [&](Value value, uint64_t divisor) -> Value { + Value bias = constant(divisor - 1); + Value sum = rewriter.create(loc, value, bias); + return rewriter.create(loc, sum, constant(divisor)); + }; + + Value zero = constant(0); + Value one = constant(1); + Value yStep = ceilDivConst( + rewriter.create(loc, op.getK(), constant(elemBytes)), 32); + Value stride = ceilDivConst(op.getN(), 16); + + rewriter.create( + loc, op.getSource(), op.getDestination(), zero, zero, one, yStep, stride, + stride); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandAccStorePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL0cL1Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = getI64Constant(loc, rewriter, 0); + Value one = getI64Constant(loc, rewriter, 1); + configureAccStoreScalarPreOps(loc, op.getPreQuant(), op.getPreQuantMode(), + op.getPreRelu(), op.getPreReluMode(), + op.getClipValue(), + getBufferElementType(op.getDestination().getType()), + rewriter); + if (Value fpc = buildAccStoreFpcValue(loc, op.getPreQuant(), + op.getPreQuantMode(), + op.getPreRelu(), + op.getPreReluMode(), rewriter)) + rewriter.create(loc, fpc); + Value originalCtrl = + configureAccStoreCtrl(loc, /*allowAtomic=*/false, std::nullopt, + std::nullopt, op.getSatMode(), rewriter); + pto::DmaLoopConfig hwLoop{one, zero, zero}; + if (Value loop3Count = op.getLoop3Count()) { + hwLoop = {loop3Count, op.getLoop3SrcStride(), op.getLoop3DstStride()}; + } + + Value channelLoop0Stride = zero; + Value nz2ndEn = zero; + Value channelSplitEn = zero; + Value nz2dnEn = zero; + if (auto mode = op.getMode()) { + switch (*mode) { + case pto::AccStoreMode::Nz2nd: + nz2ndEn = one; + break; + case pto::AccStoreMode::Nz2dn: + nz2dnEn = one; + channelLoop0Stride = op.getLoop0SrcStride() ? op.getLoop0SrcStride() : one; + break; + case pto::AccStoreMode::Nz2nz: + channelSplitEn = op.getSplit() ? op.getSplit() : zero; + break; + } + } else { + nz2ndEn = one; + } + + Value loop3Config = packLoop3Config(loc, hwLoop.count, hwLoop.srcStride, + hwLoop.dstStride, rewriter); + Value channelConfig = + packChannelConfig(loc, channelLoop0Stride, rewriter); + rewriter.create( + loc, extractConfigLow40(loc, loop3Config, rewriter), + extractConfigHigh24(loc, loop3Config, rewriter)); + rewriter.create( + loc, extractConfigLow40(loc, channelConfig, rewriter), + extractConfigHigh24(loc, channelConfig, rewriter)); + Value clipReluPre = getI64Constant(loc, rewriter, op.getClipValue() ? 1 : 0); + Value unitFlagCtrl = buildAccStoreOptionalEnumValue( + loc, + op.getUnitFlag() + ? std::optional(static_cast(*op.getUnitFlag())) + : std::nullopt, + rewriter); + Value quantPreMode = buildAccStoreOptionalEnumValue( + loc, + op.getPreQuantMode() + ? std::optional(static_cast(*op.getPreQuantMode())) + : std::nullopt, + rewriter); + Value reluPreMode = buildAccStoreOptionalEnumValue( + loc, + op.getPreReluMode() + ? std::optional(static_cast(*op.getPreReluMode())) + : std::nullopt, + rewriter); + Value xm = + packCopyMatrixCcToGmXm(loc, zero, op.getN(), op.getM(), + op.getDstStride(), rewriter); + Value xt = packCopyMatrixCcToGmXt( + loc, op.getSrcStride(), clipReluPre, unitFlagCtrl, quantPreMode, + reluPreMode, zero, nz2ndEn, channelSplitEn, nz2dnEn, + rewriter); + rewriter.create(loc, op.getSource(), + op.getDestination(), xm, xt); + if (originalCtrl) + rewriter.create(loc, originalCtrl); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandAccStoreGmPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL0cGmOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = getI64Constant(loc, rewriter, 0); + Value one = getI64Constant(loc, rewriter, 1); + configureAccStoreScalarPreOps(loc, op.getPreQuant(), op.getPreQuantMode(), + op.getPreRelu(), op.getPreReluMode(), + op.getClipValue(), + getBufferElementType(op.getDestination().getType()), + rewriter); + if (Value fpc = buildAccStoreFpcValue(loc, op.getPreQuant(), + op.getPreQuantMode(), + op.getPreRelu(), + op.getPreReluMode(), rewriter)) + rewriter.create(loc, fpc); + Value originalCtrl = + configureAccStoreCtrl(loc, /*allowAtomic=*/true, op.getAtomicType(), + op.getAtomicOp(), op.getSatMode(), rewriter); + pto::DmaLoopConfig hwLoop{one, zero, zero}; + if (Value loop3Count = op.getLoop3Count()) { + hwLoop = {loop3Count, op.getLoop3SrcStride(), op.getLoop3DstStride()}; + } + + Value channelLoop0Stride = zero; + Value nz2ndEn = zero; + Value channelSplitEn = zero; + Value nz2dnEn = zero; + if (auto mode = op.getMode()) { + switch (*mode) { + case pto::AccStoreMode::Nz2nd: + nz2ndEn = one; + break; + case pto::AccStoreMode::Nz2dn: + nz2dnEn = one; + channelLoop0Stride = + op.getLoop0SrcStride() ? op.getLoop0SrcStride() : one; + break; + case pto::AccStoreMode::Nz2nz: + channelSplitEn = op.getSplit() ? op.getSplit() : zero; + break; + } + } else { + nz2ndEn = one; + } + + Value loop3Config = packLoop3Config(loc, hwLoop.count, hwLoop.srcStride, + hwLoop.dstStride, rewriter); + Value channelConfig = packChannelConfig(loc, channelLoop0Stride, rewriter); + rewriter.create( + loc, extractConfigLow40(loc, loop3Config, rewriter), + extractConfigHigh24(loc, loop3Config, rewriter)); + rewriter.create( + loc, extractConfigLow40(loc, channelConfig, rewriter), + extractConfigHigh24(loc, channelConfig, rewriter)); + Value clipReluPre = getI64Constant(loc, rewriter, op.getClipValue() ? 1 : 0); + Value unitFlagCtrl = buildAccStoreOptionalEnumValue( + loc, + op.getUnitFlag() + ? std::optional(static_cast(*op.getUnitFlag())) + : std::nullopt, + rewriter); + Value quantPreMode = buildAccStoreOptionalEnumValue( + loc, + op.getPreQuantMode() + ? std::optional(static_cast(*op.getPreQuantMode())) + : std::nullopt, + rewriter); + Value reluPreMode = buildAccStoreOptionalEnumValue( + loc, + op.getPreReluMode() + ? std::optional(static_cast(*op.getPreReluMode())) + : std::nullopt, + rewriter); + Value xm = packCopyMatrixCcToGmXm(loc, op.getSid(), op.getN(), op.getM(), + op.getDstStride(), rewriter); + Value xt = packCopyMatrixCcToGmXt( + loc, op.getSrcStride(), clipReluPre, unitFlagCtrl, quantPreMode, + reluPreMode, op.getL2CacheCtrl(), nz2ndEn, channelSplitEn, + nz2dnEn, rewriter); + rewriter.create(loc, op.getSource(), + op.getDestination(), xm, xt); + if (originalCtrl) + rewriter.create(loc, originalCtrl); + rewriter.eraseOp(op); + return success(); + } +}; + +struct ExpandAccStoreUbPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pto::MteL0cUbOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value zero = getI64Constant(loc, rewriter, 0); + Value one = getI64Constant(loc, rewriter, 1); + configureAccStoreScalarPreOps(loc, op.getPreQuant(), op.getPreQuantMode(), + op.getPreRelu(), op.getPreReluMode(), + op.getClipValue(), + getBufferElementType(op.getDestination().getType()), + rewriter); + if (Value fpc = buildAccStoreFpcValue(loc, op.getPreQuant(), + op.getPreQuantMode(), + op.getPreRelu(), + op.getPreReluMode(), rewriter)) + rewriter.create(loc, fpc); + Value originalCtrl = + configureAccStoreCtrl(loc, /*allowAtomic=*/false, std::nullopt, + std::nullopt, op.getSatMode(), rewriter); + pto::DmaLoopConfig hwLoop{one, zero, zero}; + if (Value loop3Count = op.getLoop3Count()) { + hwLoop = {loop3Count, op.getLoop3SrcStride(), op.getLoop3DstStride()}; + } + + Value channelLoop0Stride = zero; + Value nz2ndEn = zero; + Value channelSplitEn = zero; + Value nz2dnEn = zero; + if (auto mode = op.getMode()) { + switch (*mode) { + case pto::AccStoreMode::Nz2nd: + nz2ndEn = one; + break; + case pto::AccStoreMode::Nz2dn: + nz2dnEn = one; + channelLoop0Stride = op.getLoop0SrcStride() ? op.getLoop0SrcStride() : one; + break; + case pto::AccStoreMode::Nz2nz: + channelSplitEn = op.getSplit() ? op.getSplit() : zero; + break; + } + } else { + nz2ndEn = one; + } + + Value loop3Config = packLoop3Config(loc, hwLoop.count, hwLoop.srcStride, + hwLoop.dstStride, rewriter); + Value channelConfig = + packChannelConfig(loc, channelLoop0Stride, rewriter); + rewriter.create( + loc, extractConfigLow40(loc, loop3Config, rewriter), + extractConfigHigh24(loc, loop3Config, rewriter)); + rewriter.create( + loc, extractConfigLow40(loc, channelConfig, rewriter), + extractConfigHigh24(loc, channelConfig, rewriter)); + Value clipReluPre = getI64Constant(loc, rewriter, op.getClipValue() ? 1 : 0); + Value unitFlagCtrl = buildAccStoreOptionalEnumValue( + loc, + op.getUnitFlag() + ? std::optional(static_cast(*op.getUnitFlag())) + : std::nullopt, + rewriter); + Value quantPreMode = buildAccStoreOptionalEnumValue( + loc, + op.getPreQuantMode() + ? std::optional(static_cast(*op.getPreQuantMode())) + : std::nullopt, + rewriter); + Value reluPreMode = buildAccStoreOptionalEnumValue( + loc, + op.getPreReluMode() + ? std::optional(static_cast(*op.getPreReluMode())) + : std::nullopt, + rewriter); + + Value dualDstMode = + getI64Constant(loc, rewriter, static_cast(op.getDstMode())); + Value subBlockId = op.getSubBlockid() ? op.getSubBlockid() : zero; + Value config0 = packCopyMatrixCcToGmXm(loc, zero, op.getN(), op.getM(), + op.getDstStride(), rewriter); + Value config1 = packCopyMatrixCcToUbConfig1( + loc, op.getSrcStride(), dualDstMode, subBlockId, + clipReluPre, unitFlagCtrl, quantPreMode, reluPreMode, nz2ndEn, + channelSplitEn, nz2dnEn, rewriter); + rewriter.create(loc, op.getSource(), + op.getDestination(), config0, + config1); + if (originalCtrl) + rewriter.create(loc, originalCtrl); + rewriter.eraseOp(op); + return success(); + } +}; + +struct VPTOExpandWrapperOpsPass + : public pto::impl::VPTOExpandWrapperOpsBase { + using pto::impl::VPTOExpandWrapperOpsBase< + VPTOExpandWrapperOpsPass>::VPTOExpandWrapperOpsBase; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (func.isExternal()) + return; + + RewritePatternSet patterns(&getContext()); + patterns.add, + ExpandMadSemanticPattern, + ExpandMadSemanticPattern, + ExpandMadSemanticPattern, + ExpandMadSemanticPattern, + ExpandMadSemanticPattern>(&getContext()); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTOExpandWrapperOpsPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp new file mode 100644 index 000000000..9fcc97c77 --- /dev/null +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -0,0 +1,8040 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 +#pragma GCC diagnostic ignored "-Woverloaded-virtual" + +#include "PTO/Transforms/VPTOLLVMEmitter.h" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/PTOSyncUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::pto { + +void materializeVecScopeCarrierLoops(ModuleOp module); +LogicalResult applyQueriedTargetAttrs(ModuleOp module, + const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS); +LogicalResult attachAIVectorScopeMetadata(llvm::Module &llvmModule, + llvm::raw_ostream &diagOS); +void attachHIVMKernelAnnotations(llvm::Module &llvmModule); + +namespace { + +constexpr llvm::StringLiteral kPTOKernelAttrName = "pto.kernel"; +constexpr llvm::StringLiteral kLegacyPTOAICoreAttrName = "pto.aicore"; +constexpr llvm::StringLiteral kVectorSuffix = "_mix_aiv"; +constexpr llvm::StringLiteral kCubeSuffix = "_mix_aic"; + +static bool hasVPTOKernelAttr(Operation *op) { + return op->hasAttr(kPTOKernelAttrName) || + op->hasAttr(kLegacyPTOAICoreAttrName); +} + +static std::string getElementTypeFragment(Type type); +static Type getElementTypeFromVectorLike(Type type); +static std::optional getElementCountFromVectorLike(Type type); + +static Type normalizeIntegerTypeForLLVMLowering(Type type, Builder &builder) { + if (auto intType = dyn_cast(type)) { + if (!intType.isSignless()) + return builder.getIntegerType(intType.getWidth()); + return type; + } + + if (auto vecType = dyn_cast(type)) { + Type normalizedElement = + normalizeIntegerTypeForLLVMLowering(vecType.getElementType(), builder); + if (normalizedElement == vecType.getElementType()) + return type; + return VectorType::get(vecType.getShape(), normalizedElement, + vecType.getScalableDims()); + } + + return type; +} + +static Type convertVPTOType(Type type, Builder &builder) { + if (auto vecType = dyn_cast(type)) { + Type elementType = + normalizeIntegerTypeForLLVMLowering(vecType.getElementType(), builder); + return VectorType::get({vecType.getElementCount()}, elementType); + } + if (isa(type)) + return VectorType::get({256}, builder.getI1Type()); + if (isa(type)) + return VectorType::get({32}, builder.getI8Type()); + if (auto ptrType = dyn_cast(type)) { + return LLVM::LLVMPointerType::get( + builder.getContext(), + static_cast(ptrType.getMemorySpace().getAddressSpace())); + } + return normalizeIntegerTypeForLLVMLowering(type, builder); +} + +static bool hasVPTOConvertibleType(Type type) { + return isa(type); +} + +static bool hasVPTOConvertibleType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return hasVPTOConvertibleType(type); }); +} + +static Value materializeVPTOCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); +} + +class VPTOTypeConverter final : public TypeConverter { +public: + explicit VPTOTypeConverter(MLIRContext *context) { + addConversion([](Type type) { return type; }); + addConversion([](Type type) -> Type { + // The conversion callback outlives this constructor, so build on demand + // from the current type context instead of capturing a local Builder. + Builder builder(type.getContext()); + return convertVPTOType(type, builder); + }); + addSourceMaterialization(materializeVPTOCast); + addTargetMaterialization(materializeVPTOCast); + addArgumentMaterialization(materializeVPTOCast); + } +}; + +struct PlannedDecl { + std::string name; + FunctionType type; +}; + +struct LoweringState { + SmallVector plannedDecls; +}; + +enum class VcvtElemKind { + Invalid, + F16, + BF16, + F32, + S8, + U8, + S16, + U16, + S32, + U32, + S64, +}; + +struct VcvtContract { + const char *intrinsic; + bool requiresRnd; + bool requiresSat; + bool requiresPart; + unsigned maskBitWidth; + bool satBeforeRnd = false; +}; + +static Value getI64Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI64IntegerAttr(value)) + .getResult(); +} + +static Value getI32Constant(OpBuilder &builder, Location loc, uint64_t value) { + return builder.create(loc, builder.getI32IntegerAttr(value)) + .getResult(); +} + +[[maybe_unused]] static Value getI1Constant(OpBuilder &builder, Location loc, + bool value) { + return builder + .create( + loc, builder.getIntegerAttr(builder.getI1Type(), value ? 1 : 0)) + .getResult(); +} + +static bool isMxElementType(Type ty) { + if (auto floatType = dyn_cast(ty)) + return floatType.getWidth() == 8; + std::string typeText; + llvm::raw_string_ostream os(typeText); + ty.print(os); + os.flush(); + return StringRef(typeText).starts_with("f8"); +} + +static std::string getMadMxElementFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3")) + return "e4m3"; + if (StringRef(lower).contains("e5m2")) + return "e5m2"; + if (StringRef(lower).contains("hif4")) + return "hif4"; + if (StringRef(lower).contains("e2m1x2")) + return "e2m1x2"; + if (StringRef(lower).contains("e1m2x2")) + return "e1m2x2"; + return {}; +} + +static FailureOr buildMadMxCalleeName(MLIRContext *context, + Type lhsElem, Type rhsElem) { + std::string lhs = getMadMxElementFragment(lhsElem); + std::string rhs = getMadMxElementFragment(rhsElem); + if (lhs.empty() || rhs.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.MMAD.MX." + lhs + rhs).getValue(); +} + +static std::string getMadRhsFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isSigned() && intType.getWidth() == 4) + return "s4"; + if (intType.isSigned() && intType.getWidth() == 8) + return "s8"; + if (intType.isUnsigned() && intType.getWidth() == 2) + return "u2"; + } + + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e8m0")) + return "e8m0"; + return {}; +} + +static bool isMadE4M3ElementType(Type type) { + return type.isFloat8E4M3() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ(); +} + +static std::string getMadDstFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) { + if (intType.isSigned() && intType.getWidth() == 32) + return "s32"; + } + return {}; +} + +static FailureOr buildMadTypedCalleeName(MLIRContext *context, + Type lhsElem, Type rhsElem, + Type dstElem) { + std::string rhs = getMadRhsFragment(rhsElem); + std::string dst = getMadDstFragment(dstElem); + if (lhsElem.isF16() && rhs == "f16" && dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.f162f32.c310").getValue(); + if (lhsElem.isF16() && rhs == "f16" && dst == "f16") + return StringAttr::get(context, "llvm.hivm.MAD.f162f16").getValue(); + if (lhsElem.isF16() && rhs == "f16" && dst == "s32") + return StringAttr::get(context, "llvm.hivm.MAD.f162s32.1952").getValue(); + if (lhsElem.isBF16() && rhs == "bf16" && dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.bf162f32.c310").getValue(); + if (lhsElem.isF32() && rhs == "f32" && dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.f322f32.c310").getValue(); + if (isMadE4M3ElementType(lhsElem) && isMadE4M3ElementType(rhsElem) && + dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.e4m3e4m3.c310").getValue(); + if (pto::isPTOHiFloat8Type(lhsElem) && pto::isPTOHiFloat8Type(rhsElem) && + dst == "f32") + return StringAttr::get(context, "llvm.hivm.MAD.e4m3e4m3.c310").getValue(); + if (lhsElem.isF16() && rhs == "s4") + return StringAttr::get(context, "llvm.hivm.MAD.f16s4.c310").getValue(); + if (lhsElem.isF16() && rhs == "s8") + return StringAttr::get(context, "llvm.hivm.MAD.f16s8.c310").getValue(); + if (lhsElem.isF16() && rhs == "u2") + return StringAttr::get(context, "llvm.hivm.MAD.f16u2").getValue(); + if (lhsElem.isF16() && rhs == "e8m0") + return StringAttr::get(context, "llvm.hivm.MAD.f16e8m0.c310").getValue(); + return failure(); +} + +static FailureOr buildLaneTypedCallee(MLIRContext *context, + Type resultType, + StringRef stem, + StringRef suffix) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec + + suffix.str()) + .getValue(); +} + +static FailureOr buildLaneTypedCalleeFromInput(MLIRContext *context, + Type inputType, + StringRef stem, + StringRef suffix) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + auto lanes = getElementCountFromVectorLike(inputType); + if (vec.empty() || !lanes) + return failure(); + + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec + + suffix.str()) + .getValue(); +} + +static std::string getElementTypeFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); + return {}; +} + +static std::string getL0LoadElementFragment(Type type) { + std::string elem = getElementTypeFragment(type); + if (!elem.empty()) + return elem; + + std::string typeText; + llvm::raw_string_ostream os(typeText); + type.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3") || + StringRef(lower).contains("e5m2") || + StringRef(lower).contains("e8m0") || + StringRef(lower).contains("hif8")) + return "s8"; + return {}; +} + +static std::string getVbrScalarFragment(Type type) { + if (type.isF16()) + return "f16"; + if (type.isBF16()) + return "bf16"; + if (type.isF32()) + return "f32"; + if (auto intType = dyn_cast(type)) + return (intType.isUnsigned() ? "u" : "s") + std::to_string(intType.getWidth()); + return {}; +} + +static Type getElementTypeFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + return {}; +} + +static std::optional getElementCountFromVectorLike(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementCount(); + if (auto vecType = dyn_cast(type)) { + if (vecType.getRank() != 1) + return std::nullopt; + return vecType.getShape().front(); + } + return std::nullopt; +} + +static Value castIntegerLikeTo(Operation *anchor, Value value, Type targetType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + if (value.getType() == targetType) + return value; + + auto targetInt = dyn_cast(targetType); + if (value.getType().isIndex() && targetInt) + return builder.create(anchor->getLoc(), targetType, value); + if (auto sourceInt = dyn_cast(value.getType())) { + if (targetInt) { + if (sourceInt.getWidth() < targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + if (sourceInt.getWidth() > targetInt.getWidth()) + return builder.create(anchor->getLoc(), targetType, value); + return value; + } + if (targetType.isIndex()) + return builder.create(anchor->getLoc(), targetType, value); + } + + return {}; +} + +static FailureOr reinterpretPointerToAddrSpace(Operation *anchor, + Value value, + unsigned targetAddressSpace) { + auto sourcePtrType = dyn_cast(value.getType()); + if (!sourcePtrType) + return failure(); + if (sourcePtrType.getAddressSpace() == targetAddressSpace) + return value; + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + Value asInt = builder.create(loc, builder.getI64Type(), value); + Type targetPtrType = + LLVM::LLVMPointerType::get(anchor->getContext(), targetAddressSpace); + return builder.create(loc, targetPtrType, asInt).getResult(); +} + +static FailureOr normalizeVdupScalarOperand(OpBuilder &builder, Location loc, + Value input, + Type resultType) { + auto intType = dyn_cast(input.getType()); + if (!intType || intType.getWidth() != 8) + return input; + + Type resultElemType = getElementTypeFromVectorLike(resultType); + std::string resultElemFragment = getElementTypeFragment(resultElemType); + if (resultElemFragment != "s8" && resultElemFragment != "u8") + return input; + + if (intType.isSignless()) + return input; + + Type signlessType = builder.getIntegerType(intType.getWidth()); + return builder + .create(loc, TypeRange{signlessType}, input) + .getResult(0); +} + +static Value normalizeByteScalarOperandForHivmCall(OpBuilder &builder, Location loc, + Value input, + Type semanticElementType) { + auto intType = dyn_cast(input.getType()); + if (!intType || intType.getWidth() != 8) + return input; + + Type i16Type = builder.getIntegerType(16); + auto semanticIntType = dyn_cast(semanticElementType); + if (semanticIntType && semanticIntType.isUnsigned()) + return builder.create(loc, i16Type, input).getResult(); + return builder.create(loc, i16Type, input).getResult(); +} + +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + +static std::string getCopyElementFragment(Type elementType) { + if (!elementType) + return {}; + if (elementType.isF16()) + return "f16"; + if (elementType.isBF16()) + return "bf16"; + if (elementType.isF32()) + return "f32"; + // Handle FP8 family (e4m3/e5m2/e8m0/hif8) used by cube-matmul/mad_mx. + std::string typeText; + llvm::raw_string_ostream os(typeText); + elementType.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3")) + return "e4m3"; + if (StringRef(lower).contains("e5m2")) + return "e5m2"; + if (StringRef(lower).contains("e8m0")) + return "e8m0"; + if (StringRef(lower).contains("hif8")) + return "hif8"; + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? "u8" : "s8"; + case 16: + return intType.isUnsigned() ? "u16" : "s16"; + case 32: + return intType.isUnsigned() ? "u32" : "s32"; + default: + return {}; + } + } + return {}; +} + +static std::string getNd2NzCopyElementFragment(Type elementType) { + if (!elementType) + return {}; + std::string typeText; + llvm::raw_string_ostream os(typeText); + elementType.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3") || StringRef(lower).contains("e5m2") || + StringRef(lower).contains("e8m0") || StringRef(lower).contains("hif8")) + return "U8"; + + if (elementType.isF16() || elementType.isBF16()) + return "U16"; + if (elementType.isF32()) + return "U32"; + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return "U8"; + case 16: + return "U16"; + case 32: + return "U32"; + default: + return {}; + } + } + return {}; +} + +static std::optional parsePredicatePatternImmediate(StringRef pattern) { + if (pattern == "PAT_ALL") + return 0; + if (pattern == "PAT_VL1") + return 1; + if (pattern == "PAT_VL2") + return 2; + if (pattern == "PAT_VL3") + return 3; + if (pattern == "PAT_VL4") + return 4; + if (pattern == "PAT_VL8") + return 5; + if (pattern == "PAT_VL16") + return 6; + if (pattern == "PAT_VL32") + return 7; + if (pattern == "PAT_VL64") + return 8; + if (pattern == "PAT_VL128") + return 9; + if (pattern == "PAT_M3") + return 10; + if (pattern == "PAT_M4") + return 11; + if (pattern == "PAT_H") + return 12; + if (pattern == "PAT_Q") + return 13; + if (pattern == "PAT_ALLF") + return 15; + return std::nullopt; +} + +static std::optional parseHiLoPartImmediate(StringRef part) { + if (part == "LOWER") + return 0; + if (part == "HIGHER") + return 1; + return std::nullopt; +} + +static std::optional parseRoundModeImmediate(StringRef roundMode) { + if (roundMode == "R" || roundMode == "ROUND_R") + return 0; + if (roundMode == "A" || roundMode == "ROUND_A") + return 1; + if (roundMode == "F" || roundMode == "ROUND_F") + return 2; + if (roundMode == "C" || roundMode == "ROUND_C") + return 3; + if (roundMode == "Z" || roundMode == "ROUND_Z") + return 4; + if (roundMode == "O" || roundMode == "ROUND_O") + return 5; + return std::nullopt; +} + +static std::optional parseSaturationImmediate(StringRef sat) { + if (sat == "SAT") + return 1; + if (sat == "NOSAT") + return 0; + return std::nullopt; +} + +static std::optional parsePartImmediate(StringRef part) { + if (part == "EVEN" || part == "PART_EVEN") + return 0; + if (part == "ODD" || part == "PART_ODD") + return 1; + return std::nullopt; +} + +static std::optional parseVcvtPartImmediate(StringRef part) { + if (part == "EVEN" || part == "PART_EVEN" || part == "P0" || + part == "PART_P0") + return 0; + if (part == "ODD" || part == "PART_ODD" || part == "P1" || + part == "PART_P1") + return 1; + if (part == "P2" || part == "PART_P2") + return 2; + if (part == "P3" || part == "PART_P3") + return 3; + return std::nullopt; +} + +static std::optional parsePredicateStoreDistImmediate(StringRef dist) { + if (dist == "NORM") + return 0; + if (dist == "PK") + return 1; + return std::nullopt; +} + +static std::optional parsePredicateLoadDistImmediate(StringRef dist) { + if (dist.empty() || dist == "NORM") + return 0; + if (dist == "US") + return 1; + if (dist == "DS") + return 2; + return std::nullopt; +} + +static std::optional parsePostModeImmediate(StringRef mode) { + if (mode == "NO_POST_UPDATE") + return 0; + if (mode == "POST_UPDATE") + return 1; + return std::nullopt; +} + +static std::optional parsePipeImmediate(StringRef pipe) { + if (pipe == "PIPE_S") + return 0; + if (pipe == "PIPE_V") + return 1; + if (pipe == "PIPE_M") + return 2; + if (pipe == "PIPE_MTE1") + return 3; + if (pipe == "PIPE_MTE2") + return 4; + if (pipe == "PIPE_MTE3") + return 5; + if (pipe == "PIPE_ALL") + return 6; + if (pipe == "PIPE_MTE4") + return 7; + if (pipe == "PIPE_MTE5") + return 8; + if (pipe == "PIPE_V2") + return 9; + if (pipe == "PIPE_FIX") + return 10; + if (pipe == "VIRTUAL_PIPE_MTE2_L1A") + return 11; + if (pipe == "VIRTUAL_PIPE_MTE2_L1B") + return 12; + return std::nullopt; +} + +static std::optional parseEventImmediate(StringRef event) { + if (!event.consume_front("EVENT_ID")) + return std::nullopt; + uint64_t value = 0; + if (event.getAsInteger(10, value)) + return std::nullopt; + return value; +} + +static std::optional parseSprImmediate(StringRef spr) { + if (spr == "AR") + return 74; + return std::nullopt; +} + +static std::optional getDistElementWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (type.isF16() || type.isBF16()) + return 16; + if (type.isF32()) + return 32; + if (type.isF64()) + return 64; + return std::nullopt; +} + +static VcvtElemKind classifyVcvtElemType(Type type) { + if (type.isF16()) + return VcvtElemKind::F16; + if (type.isBF16()) + return VcvtElemKind::BF16; + if (type.isF32()) + return VcvtElemKind::F32; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 8: + return intType.isUnsigned() ? VcvtElemKind::U8 : VcvtElemKind::S8; + case 16: + return intType.isUnsigned() ? VcvtElemKind::U16 : VcvtElemKind::S16; + case 32: + return intType.isUnsigned() ? VcvtElemKind::U32 : VcvtElemKind::S32; + case 64: + return intType.isUnsigned() ? VcvtElemKind::Invalid : VcvtElemKind::S64; + default: + return VcvtElemKind::Invalid; + } + } + return VcvtElemKind::Invalid; +} + +static std::optional lookupVcvtContract(VcvtElemKind src, + VcvtElemKind dst) { + switch (src) { + case VcvtElemKind::F32: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtff.f322f16.x", true, true, true, 32}; + case VcvtElemKind::BF16: + return VcvtContract{"llvm.hivm.vcvtff.f322bf16.x", true, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f322s16.x", true, true, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f322s32.x", true, true, false, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtfi.f322s64.x", true, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::F16: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.f162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.f162s32.x", true, false, true, 16}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtfi.f162s16.x", true, true, false, 16}; + case VcvtElemKind::S8: + return VcvtContract{"llvm.hivm.vcvtfi.f162s8.x", true, true, true, 16}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtfi.f162u8.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::BF16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtff.bf162f16.x", true, true, false, 16, + true}; + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtff.bf162f32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtfi.bf162s32.x", true, true, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::U8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.u82f16.x", false, false, true, 8}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u82u16.x", false, false, true, 8}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u82u32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::S8: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s82f16.x", false, false, true, 8}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s82s16.x", false, false, true, 8}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s82s32.x", false, false, true, 8}; + default: + return std::nullopt; + } + case VcvtElemKind::U16: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.u162u32.x", false, false, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::S16: + switch (dst) { + case VcvtElemKind::F16: + return VcvtContract{"llvm.hivm.vcvtif.s162f16.x", true, false, false, 16}; + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s162f32.x", false, false, true, 16}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s162u8.x", false, true, true, 16}; + case VcvtElemKind::U32: + return VcvtContract{"llvm.hivm.vcvtii.s162u32.x", false, false, true, 16}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s162s32.x", false, false, true, 16}; + default: + return std::nullopt; + } + case VcvtElemKind::U32: + switch (dst) { + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.u322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.u322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.u322s16.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S32: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s322f32.x", true, false, false, 32}; + case VcvtElemKind::U8: + return VcvtContract{"llvm.hivm.vcvtii.s322u8.x", false, true, true, 32}; + case VcvtElemKind::U16: + return VcvtContract{"llvm.hivm.vcvtii.s322u16.x", false, true, true, 32}; + case VcvtElemKind::S16: + return VcvtContract{"llvm.hivm.vcvtii.s322s16.x", false, true, true, 32}; + case VcvtElemKind::S64: + return VcvtContract{"llvm.hivm.vcvtii.s322s64.x", false, false, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::S64: + switch (dst) { + case VcvtElemKind::F32: + return VcvtContract{"llvm.hivm.vcvtif.s642f32.x", true, false, true, 32}; + case VcvtElemKind::S32: + return VcvtContract{"llvm.hivm.vcvtii.s642s32.x", false, true, true, 32}; + default: + return std::nullopt; + } + case VcvtElemKind::Invalid: + return std::nullopt; + } + return std::nullopt; +} + +// VSQZ #st hint must only be set when the compacted vector feeds VSTUR. +// Emitting #st=1 without a matching VSTUR consumer can deadlock hardware queues. +static uint64_t determineVsqzStoreHint(pto::VsqzOp vsqz) { + Value result = vsqz.getResult(); + for (Operation *user : result.getUsers()) { + auto vstur = dyn_cast(user); + if (!vstur) + continue; + if (vstur.getValue() == result) + return 1; + } + return 0; +} + +static std::optional parseLoadDistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist.empty() || dist == "NORM") + return 0; + if (!width) + return std::nullopt; + if (dist == "BRC_B8") + return std::optional(1); + if (dist == "BRC_B16") + return std::optional(2); + if (dist == "BRC_B32") + return std::optional(3); + if (dist == "US_B8") + return std::optional(6); + if (dist == "US_B16") + return std::optional(7); + if (dist == "DS_B8") + return std::optional(8); + if (dist == "DS_B16") + return std::optional(9); + if (dist == "UNPK_B8") + return std::optional(13); + if (dist == "UNPK_B16") + return std::optional(14); + if (dist == "UNPK_B32") + return std::optional(18); + if (dist == "BRC_BLK") + return 15; + if (dist == "E2B_B16") + return std::optional(16); + if (dist == "E2B_B32") + return std::optional(17); + if (dist == "UNPK4") + return *width == 8 ? std::optional(20) : std::nullopt; + if (dist == "SPLT4CHN") + return *width == 8 ? std::optional(21) : std::nullopt; + if (dist == "SPLT2CHN_B8") + return std::optional(22); + if (dist == "SPLT2CHN_B16") + return std::optional(23); + return std::nullopt; +} + +static std::optional parseLoadX2DistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist == "BDINTLV") + return 10; + if (!width) + return std::nullopt; + if (dist == "DINTLV_B8") + return std::optional(11); + if (dist == "DINTLV_B16") + return std::optional(12); + if (dist == "DINTLV_B32") + return std::optional(19); + return std::nullopt; +} + +static std::optional parseStoreDistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (dist.empty()) { + if (!width) + return std::nullopt; + if (*width == 8) + return 0; + if (*width == 16) + return 1; + if (*width == 32) + return 2; + return std::nullopt; + } + if (dist == "NORM_B8") + return std::optional(0); + if (dist == "NORM_B16") + return std::optional(1); + if (dist == "NORM_B32") + return std::optional(2); + if (dist == "1PT_B8") + return std::optional(3); + if (dist == "1PT_B16") + return std::optional(4); + if (dist == "1PT_B32") + return std::optional(5); + if (dist == "PK_B16") + return std::optional(6); + if (dist == "PK_B32") + return std::optional(7); + if (dist == "PK_B64") + return std::optional(10); + if (dist == "PK4_B32") + return std::optional(12); + if (dist == "MRG4CHN_B8") + return std::optional(13); + if (dist == "MRG2CHN_B8") + return std::optional(14); + if (dist == "MRG2CHN_B16") + return std::optional(15); + return std::nullopt; +} + +static std::optional parseStoreX2DistImmediate(StringRef dist, + Type elementType) { + auto width = getDistElementWidth(elementType); + if (!width) + return std::nullopt; + if (dist == "INTLV_B8") + return std::optional(8); + if (dist == "INTLV_B16") + return std::optional(9); + if (dist == "INTLV_B32") + return std::optional(11); + return std::nullopt; +} + +static Value packBlockRepeatStride(Operation *anchor, Value blockStride, + Value repeatStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value blockI32 = castIntegerLikeTo(anchor, blockStride, builder.getI32Type()); + Value repeatI32 = + castIntegerLikeTo(anchor, repeatStride, builder.getI32Type()); + if (!blockI32 || !repeatI32) + return {}; + + auto c16 = builder.create(anchor->getLoc(), 16, 32); + auto blockShifted = + builder.create(anchor->getLoc(), blockI32, c16); + return builder + .create(anchor->getLoc(), blockShifted, repeatI32) + .getResult(); +} + +static std::optional parseOrderImmediate(StringRef order) { + if (order.empty() || order == "ASC") + return 0; + if (order == "DESC") + return 1; + return std::nullopt; +} + +static FailureOr packLoopPair(Operation *anchor, Value low, Value high) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value lowI64 = castIntegerLikeTo(anchor, low, builder.getI64Type()); + Value highI64 = castIntegerLikeTo(anchor, high, builder.getI64Type()); + if (!lowI64 || !highI64) + return failure(); + + Value shift = getI64Constant(builder, anchor->getLoc(), 40); + Value highShifted = + builder.create(anchor->getLoc(), highI64, shift).getResult(); + return builder.create(anchor->getLoc(), highShifted, lowI64) + .getResult(); +} + +static FailureOr packLoopSize(Operation *anchor, Value loop2, Value loop1) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value loop2I64 = castIntegerLikeTo(anchor, loop2, builder.getI64Type()); + Value loop1I64 = castIntegerLikeTo(anchor, loop1, builder.getI64Type()); + if (!loop2I64 || !loop1I64) + return failure(); + + Value shift = getI64Constant(builder, anchor->getLoc(), 21); + Value loop2Shifted = + builder.create(anchor->getLoc(), loop2I64, shift).getResult(); + return builder.create(anchor->getLoc(), loop2Shifted, loop1I64) + .getResult(); +} + +static FailureOr +packCopyGmToUbConfig0(Operation *anchor, ValueRange operands) { + if (operands.size() != 11) + return failure(); + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value leftPadding = getI64Operand(5); + Value rightPadding = getI64Operand(6); + Value dataSelect = castIntegerLikeTo(anchor, operands[7], builder.getI64Type()); + Value cacheCtl = getI64Operand(8); + if (!sid || !nBurst || !lenBurst || !leftPadding || !rightPadding || + !dataSelect || !cacheCtl) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 25)); + config = bitOr(config, shl(leftPadding, 46)); + config = bitOr(config, shl(rightPadding, 52)); + config = bitOr(config, shl(dataSelect, 58)); + config = bitOr(config, shl(cacheCtl, 60)); + return config; +} + +static FailureOr +packCopyGmToUbConfig1(Operation *anchor, ValueRange operands) { + if (operands.size() != 11) + return failure(); + return packLoopPair(anchor, operands[9], operands[10]); +} + +[[maybe_unused]] static FailureOr +packCopyGmToUbConfig0(Operation *anchor, Value sid, Value nBurst, + Value lenBurst, Value leftPadding, Value rightPadding, + Value dataSelect, Value cacheCtl) { + SmallVector operands(11); + operands[2] = sid; + operands[3] = nBurst; + operands[4] = lenBurst; + operands[5] = leftPadding; + operands[6] = rightPadding; + operands[7] = dataSelect; + operands[8] = cacheCtl; + return packCopyGmToUbConfig0(anchor, operands); +} + +static FailureOr +packCopyUbToGmConfig0(Operation *anchor, ValueRange operands) { + if (operands.size() != 8) + return failure(); + + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value reserved = getI64Operand(5); + if (!sid || !nBurst || !lenBurst || !reserved) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 25)); + config = bitOr(config, shl(reserved, 60)); + return config; +} + +static FailureOr +packCopyUbToGmConfig1(Operation *anchor, ValueRange operands) { + if (operands.size() != 8) + return failure(); + return packLoopPair(anchor, operands[6], operands[7]); +} + +[[maybe_unused]] static FailureOr +packCopyUbToGmConfig0(Operation *anchor, Value sid, Value nBurst, + Value lenBurst, Value reserved) { + SmallVector operands(8); + operands[2] = sid; + operands[3] = nBurst; + operands[4] = lenBurst; + operands[5] = reserved; + return packCopyUbToGmConfig0(anchor, operands); +} + +static FailureOr +packCopyUbToUbConfig(Operation *anchor, ValueRange operands) { + if (operands.size() != 7) + return failure(); + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value srcStride = getI64Operand(5); + Value dstStride = getI64Operand(6); + if (!nBurst || !lenBurst || !srcStride || !dstStride) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = nBurst; + config = bitOr(config, shl(lenBurst, 16)); + config = bitOr(config, shl(srcStride, 32)); + config = bitOr(config, shl(dstStride, 48)); + return config; +} + +static FailureOr +packCopyCbufToUbConfig(Operation *anchor, ValueRange operands) { + if (operands.size() != 7) + return failure(); + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value srcStride = getI64Operand(5); + Value dstStride = getI64Operand(6); + if (!sid || !nBurst || !lenBurst || !srcStride || !dstStride) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 16)); + config = bitOr(config, shl(srcStride, 32)); + config = bitOr(config, shl(dstStride, 48)); + return config; +} + +static FailureOr +packCopyUbToCbufConfig(Operation *anchor, ValueRange operands) { + if (operands.size() != 7) + return failure(); + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + auto getI64Operand = [&](unsigned idx) -> Value { + return castIntegerLikeTo(anchor, operands[idx], builder.getI64Type()); + }; + + Value sid = getI64Operand(2); + Value nBurst = getI64Operand(3); + Value lenBurst = getI64Operand(4); + Value srcStride = getI64Operand(5); + Value dstStride = getI64Operand(6); + if (!sid || !nBurst || !lenBurst || !srcStride || !dstStride) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = sid; + config = bitOr(config, shl(nBurst, 4)); + config = bitOr(config, shl(lenBurst, 16)); + config = bitOr(config, shl(srcStride, 32)); + config = bitOr(config, shl(dstStride, 48)); + return config; +} + +static FailureOr +packCopyGmToCbufConfig0(Operation *anchor, Value nBurst, Value lenBurst) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value nBurstI64 = castIntegerLikeTo(anchor, nBurst, builder.getI64Type()); + Value lenBurstI64 = castIntegerLikeTo(anchor, lenBurst, builder.getI64Type()); + if (!nBurstI64 || !lenBurstI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config0 = getI64Constant(builder, loc, 0); // sid + config0 = bitOr(config0, shl(nBurstI64, 4)); // burst_num[24:4] + config0 = bitOr(config0, shl(lenBurstI64, 25)); // burst_len[45:25] + return config0; +} + +static FailureOr +packCopyGmToCbufConfig1(Operation *anchor, Value srcStride, + Value dstStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value srcStrideI64 = castIntegerLikeTo(anchor, srcStride, builder.getI64Type()); + Value dstStrideI64 = castIntegerLikeTo(anchor, dstStride, builder.getI64Type()); + if (!srcStrideI64 || !dstStrideI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + // config1 packs burst_src_stride[39:0] and burst_dst_stride[60:40]. + return bitOr(srcStrideI64, shl(dstStrideI64, 40)); +} + +static FailureOr +packCopyGmToCbufMultiConfig0(Operation *anchor, Value sid, + Value loop1SrcStride, Value l2CacheCtl, + Value nValue) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value sidI64 = castIntegerLikeTo(anchor, sid, builder.getI64Type()); + Value loop1SrcStrideI64 = + castIntegerLikeTo(anchor, loop1SrcStride, builder.getI64Type()); + Value l2CacheCtlI64 = castIntegerLikeTo(anchor, l2CacheCtl, builder.getI64Type()); + Value nValueI64 = castIntegerLikeTo(anchor, nValue, builder.getI64Type()); + if (!sidI64 || !loop1SrcStrideI64 || !l2CacheCtlI64 || !nValueI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config0 = sidI64; + config0 = bitOr(config0, shl(loop1SrcStrideI64, 4)); + config0 = bitOr(config0, shl(l2CacheCtlI64, 44)); + config0 = bitOr(config0, shl(nValueI64, 48)); + return config0; +} + +static FailureOr +packCopyGmToCbufMultiConfig1(Operation *anchor, Value dValue, + Value loop4SrcStride, Value smallC0En) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value dValueI64 = castIntegerLikeTo(anchor, dValue, builder.getI64Type()); + Value loop4SrcStrideI64 = + castIntegerLikeTo(anchor, loop4SrcStride, builder.getI64Type()); + Value smallC0EnI64 = castIntegerLikeTo(anchor, smallC0En, builder.getI64Type()); + if (!dValueI64 || !loop4SrcStrideI64 || !smallC0EnI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config1 = dValueI64; + config1 = bitOr(config1, shl(loop4SrcStrideI64, 21)); + config1 = bitOr(config1, shl(smallC0EnI64, 61)); + return config1; +} + +static FailureOr packCopyCbufToBtConfig(Operation *anchor, + Value convControl, + Value nBurst, Value lenBurst, + Value sourceGap, + Value dstGap) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value convControlI64 = + castIntegerLikeTo(anchor, convControl, builder.getI64Type()); + Value nBurstI64 = castIntegerLikeTo(anchor, nBurst, builder.getI64Type()); + Value lenBurstI64 = castIntegerLikeTo(anchor, lenBurst, builder.getI64Type()); + Value sourceGapI64 = castIntegerLikeTo(anchor, sourceGap, builder.getI64Type()); + Value dstGapI64 = castIntegerLikeTo(anchor, dstGap, builder.getI64Type()); + if (!convControlI64 || !nBurstI64 || !lenBurstI64 || !sourceGapI64 || + !dstGapI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = shl(convControlI64, 3); + config = bitOr(config, shl(nBurstI64, 4)); + config = bitOr(config, shl(lenBurstI64, 16)); + config = bitOr(config, shl(sourceGapI64, 32)); + config = bitOr(config, shl(dstGapI64, 48)); + return config; +} + +static FailureOr packCopyCbufToFbufConfig(Operation *anchor, Value nBurst, + Value lenBurst, + Value sourceGap, + Value dstGap) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value nBurstI64 = castIntegerLikeTo(anchor, nBurst, builder.getI64Type()); + Value lenBurstI64 = castIntegerLikeTo(anchor, lenBurst, builder.getI64Type()); + Value sourceGapI64 = castIntegerLikeTo(anchor, sourceGap, builder.getI64Type()); + Value dstGapI64 = castIntegerLikeTo(anchor, dstGap, builder.getI64Type()); + if (!nBurstI64 || !lenBurstI64 || !sourceGapI64 || !dstGapI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config = shl(nBurstI64, 4); + config = bitOr(config, shl(lenBurstI64, 16)); + config = bitOr(config, shl(sourceGapI64, 32)); + config = bitOr(config, shl(dstGapI64, 48)); + return config; +} + +static FailureOr +packLoadCbufToS4Config0(Operation *anchor, Value mStart, Value kStart, + Value mStep, Value kStep) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value mStartI64 = castIntegerLikeTo(anchor, mStart, builder.getI64Type()); + Value kStartI64 = castIntegerLikeTo(anchor, kStart, builder.getI64Type()); + Value mStepI64 = castIntegerLikeTo(anchor, mStep, builder.getI64Type()); + Value kStepI64 = castIntegerLikeTo(anchor, kStep, builder.getI64Type()); + if (!mStartI64 || !kStartI64 || !mStepI64 || !kStepI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config0 = mStartI64; + config0 = bitOr(config0, shl(kStartI64, 16)); + config0 = bitOr(config0, shl(mStepI64, 32)); + config0 = bitOr(config0, shl(kStepI64, 40)); + return config0; +} + +static FailureOr +packLoadCbufToS4Config1(Operation *anchor, Value srcStride, Value dstStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value srcStrideI64 = castIntegerLikeTo(anchor, srcStride, builder.getI64Type()); + Value dstStrideI64 = castIntegerLikeTo(anchor, dstStride, builder.getI64Type()); + if (!srcStrideI64 || !dstStrideI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + return builder.create(loc, srcStrideI64, shl(dstStrideI64, 16)) + .getResult(); +} + +static FailureOr +packLoadCbufToCaConfig0(Operation *anchor, Value mStart, Value kStart, + Value mStep, Value kStep) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value mStartI64 = castIntegerLikeTo(anchor, mStart, builder.getI64Type()); + Value kStartI64 = castIntegerLikeTo(anchor, kStart, builder.getI64Type()); + Value mStepI64 = castIntegerLikeTo(anchor, mStep, builder.getI64Type()); + Value kStepI64 = castIntegerLikeTo(anchor, kStep, builder.getI64Type()); + if (!mStartI64 || !kStartI64 || !mStepI64 || !kStepI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config0 = mStartI64; + config0 = bitOr(config0, shl(kStartI64, 16)); + config0 = bitOr(config0, shl(mStepI64, 32)); + config0 = bitOr(config0, shl(kStepI64, 40)); + return config0; +} + +static FailureOr +packLoadCbufToCaConfig1(Operation *anchor, Value srcStride, Value dstStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value srcStrideI64 = + castIntegerLikeTo(anchor, srcStride, builder.getI64Type()); + Value dstStrideI64 = + castIntegerLikeTo(anchor, dstStride, builder.getI64Type()); + if (!srcStrideI64 || !dstStrideI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + return builder.create(loc, srcStrideI64, shl(dstStrideI64, 16)) + .getResult(); +} + +static FailureOr +packLoadCbufToCbConfig0(Operation *anchor, Value mStart, Value kStart, + Value mStep, Value kStep) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value mStartI64 = castIntegerLikeTo(anchor, mStart, builder.getI64Type()); + Value kStartI64 = castIntegerLikeTo(anchor, kStart, builder.getI64Type()); + Value mStepI64 = castIntegerLikeTo(anchor, mStep, builder.getI64Type()); + Value kStepI64 = castIntegerLikeTo(anchor, kStep, builder.getI64Type()); + if (!mStartI64 || !kStartI64 || !mStepI64 || !kStepI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + auto bitOr = [&](Value lhs, Value rhs) -> Value { + return builder.create(loc, lhs, rhs); + }; + + Value config0 = mStartI64; + config0 = bitOr(config0, shl(kStartI64, 16)); + config0 = bitOr(config0, shl(mStepI64, 32)); + config0 = bitOr(config0, shl(kStepI64, 40)); + return config0; +} + +static FailureOr +packLoadCbufToCbConfig1(Operation *anchor, Value srcStride, Value dstStride) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value srcStrideI64 = + castIntegerLikeTo(anchor, srcStride, builder.getI64Type()); + Value dstStrideI64 = + castIntegerLikeTo(anchor, dstStride, builder.getI64Type()); + if (!srcStrideI64 || !dstStrideI64) + return failure(); + + auto shl = [&](Value value, uint64_t amount) -> Value { + return builder.create(loc, value, + getI64Constant(builder, loc, amount)); + }; + return builder.create(loc, srcStrideI64, shl(dstStrideI64, 16)) + .getResult(); +} + +static Value buildMadBiasDestination(Operation *anchor, + ConversionPatternRewriter &rewriter, + Value dst, Value bias) { + Type i64Ty = rewriter.getI64Type(); + Value dstAddr = rewriter.create(anchor->getLoc(), i64Ty, dst); + Value biasAddr = + rewriter.create(anchor->getLoc(), i64Ty, bias); + Value lowMask = getI64Constant(rewriter, anchor->getLoc(), 0xffffffffULL); + Value dstLow = rewriter.create(anchor->getLoc(), dstAddr, lowMask); + Value biasLow = rewriter.create(anchor->getLoc(), biasAddr, lowMask); + Value biasHigh = rewriter.create( + anchor->getLoc(), biasLow, getI64Constant(rewriter, anchor->getLoc(), 32)); + Value packed = rewriter.create(anchor->getLoc(), dstLow, biasHigh); + return rewriter.create(anchor->getLoc(), dst.getType(), packed); +} + +static FailureOr packVbitsortConfig(Operation *anchor, Value repeatTimes) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + + Value repeatI64 = castIntegerLikeTo(anchor, repeatTimes, builder.getI64Type()); + if (!repeatI64) + return failure(); + return builder + .create(loc, repeatI64, getI64Constant(builder, loc, 56)) + .getResult(); +} + +static FailureOr convertElementOffsetToBytes(Operation *anchor, Value offset, + Type elementType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + + Value offsetI32 = castIntegerLikeTo(anchor, offset, builder.getI32Type()); + if (!offsetI32) + return failure(); + + unsigned bitWidth = 0; + if (auto intType = dyn_cast(elementType)) + bitWidth = intType.getWidth(); + else if (auto floatType = dyn_cast(elementType)) + bitWidth = floatType.getWidth(); + if (bitWidth == 0 || bitWidth % 8 != 0) + return failure(); + + Value scale = builder.create( + anchor->getLoc(), builder.getI32IntegerAttr(bitWidth / 8)); + return builder.create(anchor->getLoc(), offsetI32, scale) + .getResult(); +} + +[[maybe_unused]] static FailureOr +materializeDynamicPltMask(ConversionPatternRewriter &rewriter, + LoweringState &state, Location loc, Value laneCount, + Type vectorElemType) { + Type i32Type = rewriter.getI32Type(); + Value laneCountI32 = laneCount; + if (laneCountI32.getType() != i32Type) { + laneCountI32 = castIntegerLikeTo(rewriter.getInsertionBlock()->getParentOp(), + laneCountI32, i32Type); + if (!laneCountI32) + return failure(); + } + + StringRef calleeName; + if (vectorElemType.isF32()) { + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + } else if (vectorElemType.isF16() || vectorElemType.isBF16()) { + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + } else if (auto intType = dyn_cast(vectorElemType)) { + if (intType.getWidth() == 32) + calleeName = StringRef("llvm.hivm.plt.b32.v300"); + else if (intType.getWidth() == 16) + calleeName = StringRef("llvm.hivm.plt.b16.v300"); + else if (intType.getWidth() == 8) + calleeName = StringRef("llvm.hivm.plt.b8.v300"); + } + if (calleeName.empty()) + return failure(); + + Type maskType = VectorType::get({256}, rewriter.getI1Type()); + auto funcType = + rewriter.getFunctionType(TypeRange{i32Type}, TypeRange{maskType, i32Type}); + auto call = rewriter.create(loc, calleeName, funcType.getResults(), + ValueRange{laneCountI32}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + return call.getResult(0); +} + +static FailureOr buildCarryBinaryCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +template +static StringRef getUnaryMaskedStem() { + if constexpr (std::is_same_v) + return "vabs"; + if constexpr (std::is_same_v) + return "vexp"; + if constexpr (std::is_same_v) + return "vln"; + if constexpr (std::is_same_v) + return "vneg"; + if constexpr (std::is_same_v) + return "vsqrt"; + if constexpr (std::is_same_v) + return "vrelu"; + if constexpr (std::is_same_v) + return "vnot"; + return {}; +} + +template +static StringRef getBinaryMaskedStem() { + if constexpr (std::is_same_v) + return "vadd"; + if constexpr (std::is_same_v) + return "vsub"; + if constexpr (std::is_same_v) + return "vmul"; + if constexpr (std::is_same_v) + return "vdiv"; + if constexpr (std::is_same_v) + return "vmax"; + if constexpr (std::is_same_v) + return "vmin"; + if constexpr (std::is_same_v) + return "vand"; + if constexpr (std::is_same_v) + return "vor"; + if constexpr (std::is_same_v) + return "vxor"; + if constexpr (std::is_same_v) + return "vshl"; + if constexpr (std::is_same_v) + return "vshr"; + if constexpr (std::is_same_v) + return "vprelu"; + return {}; +} + +template +static StringRef getCarryBinaryStem() { + if constexpr (std::is_same_v) + return "vaddc"; + if constexpr (std::is_same_v) + return "vsubc"; + if constexpr (std::is_same_v) + return "vaddcs"; + if constexpr (std::is_same_v) + return "vsubcs"; + return {}; +} + +template +static constexpr bool hasCarryInput() { + return std::is_same_v || + std::is_same_v; +} + +static FailureOr buildVselCallee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(cast(resultType).getElementType()); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vsel.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVselrCallee(MLIRContext *context, + Type resultType) { + Type elemType = getElementTypeFromVectorLike(resultType); + auto lanes = getElementCountFromVectorLike(resultType); + if (!elemType || !lanes) + return failure(); + + std::string vec = getElementTypeFragment(elemType); + if (auto floatType = dyn_cast(elemType); + floatType && floatType.isF32()) + vec = "u32"; + if (vec.empty()) + return failure(); + + return StringAttr::get(context, "llvm.hivm.vselr.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVdupCallee(MLIRContext *context, pto::VdupOp op) { + Type inputType = op.getInput().getType(); + Type resultType = op.getResult().getType(); + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + + if (isa(inputType)) { + StringRef position = op.getPosition().value_or("LOWEST"); + StringRef family = position == "HIGHEST" ? "vdupm" : "vdup"; + return StringAttr::get(context, "llvm.hivm." + family.str() + ".v" + + std::to_string(*lanes) + vec + ".z") + .getValue(); + } + + return StringAttr::get(context, "llvm.hivm.vdups.v" + std::to_string(*lanes) + + vec + ".z") + .getValue(); +} + +static FailureOr buildVbrCallee(MLIRContext *context, + Type semanticElementType) { + std::string scalar = getVbrScalarFragment(semanticElementType); + if (scalar.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.vbr." + scalar + ".v300").getValue(); +} + +static FailureOr buildPstuCallee(MLIRContext *context, pto::PstuOp op) { + if (auto maskType = dyn_cast(op.getValue().getType())) { + if (maskType.isB16()) + return StringAttr::get(context, "llvm.hivm.pstu.b16").getValue(); + if (maskType.isB32()) + return StringAttr::get(context, "llvm.hivm.pstu.b32").getValue(); + } + return failure(); +} + +static StringRef buildVstusCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstus").getValue(); +} + +static StringRef buildVsturCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstur").getValue(); +} + +static StringRef buildInitAlignCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.init.vector.align.data").getValue(); +} + +template +static StringRef buildRuntimeQueryCallee(MLIRContext *context); + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.CTRL").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.VMS4.SR").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.get.TID.X").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.get.TID.Y").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.get.TID.Z").getValue(); +} + +static StringRef buildSprclrCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.sprclr").getValue(); +} + +template +static StringRef buildUnaryConfigCallee(MLIRContext *context); + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.CTRL").getValue(); +} + +static StringRef buildStoreVfSimtInfoCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.store.vfsimt.info").getValue(); +} + +static StringRef buildVstarCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstar").getValue(); +} + +static StringRef buildVstasCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vstas").getValue(); +} + +template +static StringRef buildBinaryI64PureCallee(MLIRContext *context); + +template <> +StringRef buildBinaryI64PureCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SBITSET0").getValue(); +} + +template <> +StringRef buildBinaryI64PureCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SBITSET1").getValue(); +} + +static FailureOr buildVldsPostCallee(MLIRContext *context, + Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldsx1.post.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVstsPostCallee(MLIRContext *context, + Type valueType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); + auto lanes = getElementCountFromVectorLike(valueType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vstsx1.post.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static StringRef buildVldasCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vldas").getValue(); +} + +static FailureOr buildVldusCallee(MLIRContext *context, + Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldus.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVcmpCallee(MLIRContext *context, Type inputType, + StringRef cmpMode, + bool isScalarCompare) { + std::string elem = getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + if (elem.empty()) + return failure(); + StringRef stem = isScalarCompare ? "vcmps" : "vcmp"; + return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + + cmpMode.str() + "." + elem + ".z") + .getValue(); +} + +template +static StringRef getVecScalarMaskedStem() { + if constexpr (std::is_same_v) + return "vmuls"; + if constexpr (std::is_same_v) + return "vadds"; + if constexpr (std::is_same_v) + return "vmaxs"; + if constexpr (std::is_same_v) + return "vmins"; + if constexpr (std::is_same_v) + return "vlrelu"; + if constexpr (std::is_same_v) + return "vshls"; + if constexpr (std::is_same_v) + return "vshrs"; + return {}; +} + +template +static StringRef getReductionUnaryStem() { + if constexpr (std::is_same_v) + return "vcadd"; + if constexpr (std::is_same_v) + return "vcmax"; + if constexpr (std::is_same_v) + return "vcmin"; + if constexpr (std::is_same_v) + return "vcgadd"; + if constexpr (std::is_same_v) + return "vcgmax"; + if constexpr (std::is_same_v) + return "vcgmin"; + if constexpr (std::is_same_v) + return "vcpadd"; + return {}; +} + +static FailureOr buildCopyGmToUbCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + Type elementType = ptrType.getElementType(); + if ((isa(elementType) && + cast(elementType).getWidth() == 64) || + elementType.isF64()) { + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2.s32.DV") + .getValue(); + } + std::string elem = getCopyElementFragment(elementType); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.UB.ALIGN.V2." + elem + + ".DV") + .getValue(); +} + +static StringRef buildCopyUbToGmCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.OUT.ALIGN.V2.DV") + .getValue(); +} + +static StringRef buildCopyUbToUbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.UB.v310").getValue(); +} + +static StringRef buildCopyCbufToUbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.L1.TO.UB.v310").getValue(); +} + +static StringRef buildCopyUbToCbufCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.UB.TO.L1.v310").getValue(); +} + +static FailureOr buildOrdinaryMadCallee(MLIRContext *context, + pto::MadRawOpInterface op) { + auto lhsType = dyn_cast(op.getLhs().getType()); + auto rhsType = dyn_cast(op.getRhs().getType()); + auto dstType = dyn_cast(op.getDst().getType()); + if (!lhsType || !rhsType || !dstType) + return failure(); + + return buildMadTypedCalleeName(context, lhsType.getElementType(), + rhsType.getElementType(), + dstType.getElementType()); +} + +static FailureOr buildMxMadCallee(MLIRContext *context, + pto::MadRawOpInterface op) { + auto lhsType = dyn_cast(op.getLhs().getType()); + auto rhsType = dyn_cast(op.getRhs().getType()); + if (!lhsType || !rhsType) + return failure(); + if (isMxElementType(lhsType.getElementType()) && + isMxElementType(rhsType.getElementType())) { + return buildMadMxCalleeName(context, lhsType.getElementType(), + rhsType.getElementType()); + } + return failure(); +} + +static FailureOr buildCopyGmToCbufCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + std::string elem = getCopyElementFragment(ptrType.getElementType()); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.L1.ALIGN.V2." + elem + + ".DV") + .getValue(); +} + +static FailureOr +buildCopyGmToCbufMultiNd2NzCallee(MLIRContext *context, Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + std::string elem = getNd2NzCopyElementFragment(ptrType.getElementType()); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.MOV.OUT.TO.L1.MULTI.ND2NZ." + + elem + ".V310") + .getValue(); +} + +static std::string getDn2NzCopyElementFragment(Type type) { + auto ptrType = dyn_cast(type); + if (!ptrType) + return {}; + + Type elementType = ptrType.getElementType(); + std::string typeText; + llvm::raw_string_ostream os(typeText); + elementType.print(os); + os.flush(); + std::string lower = StringRef(typeText).lower(); + if (StringRef(lower).contains("e4m3") || StringRef(lower).contains("e5m2") || + StringRef(lower).contains("e8m0") || StringRef(lower).contains("hif8")) + return "u8"; + + if (elementType.isF16() || elementType.isBF16()) + return "u16"; + if (elementType.isF32()) + return "u32"; + + if (auto intType = dyn_cast(elementType)) { + switch (intType.getWidth()) { + case 8: + return "u8"; + case 16: + return "u16"; + case 32: + return "u32"; + default: + return {}; + } + } + return {}; +} + +static FailureOr +buildCopyGmToCbufMultiDn2NzCallee(MLIRContext *context, Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + std::string elem = getDn2NzCopyElementFragment(sourceType); + if (elem.empty()) + return failure(); + return StringAttr::get(context, + "llvm.hivm.MOV.OUT.TO.L1.MULTI.DN2NZ." + elem) + .getValue(); +} + +static FailureOr buildLoadCbufToCaCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + std::string elem = getL0LoadElementFragment(ptrType.getElementType()); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2." + elem) + .getValue(); +} + +static StringRef buildLoadCbufToCaS4Callee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0A.2Dv2.s4") + .getValue(); +} + +static FailureOr buildLoadCbufToCbCallee(MLIRContext *context, + Type sourceType) { + auto ptrType = dyn_cast(sourceType); + if (!ptrType) + return failure(); + std::string elem = getL0LoadElementFragment(ptrType.getElementType()); + if (elem.empty()) + return failure(); + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2." + elem) + .getValue(); +} + +static StringRef buildLoadCbufToCbS4Callee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0B.2Dv2.s4") + .getValue(); +} + +static StringRef buildLoadCbufToCaMxCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0A.MX.2Dv2.v") + .getValue(); +} + +[[maybe_unused]] static StringRef buildLoadCbufToCbMxCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.LOAD.L1.TO.L0B.MX.2Dv2.v") + .getValue(); +} + +static StringRef buildCopyMatrixCcToGmCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.FIX.L0C.TO.OUT.f32.EXT") + .getValue(); +} + +static StringRef buildCopyMatrixCcToCbufCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.FIX.L0C.TO.L1.f32.EXT") + .getValue(); +} + +static FailureOr buildCopyMatrixCcToUbCallee(MLIRContext *context, + Type destinationType) { + auto ptrType = dyn_cast(destinationType); + if (!ptrType) + return failure(); + Type dstElem = ptrType.getElementType(); + if (dstElem.isF16()) + return StringAttr::get(context, "llvm.hivm.FIX.L0C.TO.UB.f322f16.EXT") + .getValue(); + if (dstElem.isF32()) + return StringAttr::get(context, "llvm.hivm.FIX.L0C.TO.UB.f32.EXT") + .getValue(); + return failure(); +} + +static FailureOr buildCopyCbufToBtCallee(pto::CopyCbufToBtOp op) { + auto ptrType = dyn_cast(op.getSource().getType()); + if (!ptrType) + return failure(); + Type srcElem = ptrType.getElementType(); + if (srcElem.isF16()) + return StringAttr::get(op.getContext(), "llvm.hivm.MOV.L1.TO.BT.f16") + .getValue(); + if (srcElem.isBF16()) + return StringAttr::get(op.getContext(), "llvm.hivm.MOV.L1.TO.BT.bf16") + .getValue(); + if (srcElem.isF32()) + return StringAttr::get(op.getContext(), "llvm.hivm.MOV.L1.TO.BT.f32") + .getValue(); + if (auto intType = dyn_cast(srcElem); + intType && intType.getWidth() == 32) { + return StringAttr::get(op.getContext(), "llvm.hivm.MOV.L1.TO.BT.s32") + .getValue(); + } + return failure(); +} + +static StringRef buildCopyCbufToFbufCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.MOV.L1.TO.FB.v220").getValue(); +} + +static StringRef buildPstiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psti.b8").getValue(); +} + +static StringRef buildPstsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psts.b8").getValue(); +} + +static StringRef buildPldiCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pldi.b8").getValue(); +} + +static StringRef buildPldsCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plds.b8").getValue(); +} + +static StringRef buildPnotCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pnot.z").getValue(); +} + +static StringRef buildPselCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.psel").getValue(); +} + +static StringRef buildPandCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pand.z").getValue(); +} + +static StringRef buildPorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.por.z").getValue(); +} + +static StringRef buildPxorCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pxor.z").getValue(); +} + +static StringRef buildPpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.ppack.z").getValue(); +} + +static StringRef buildPunpackCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.punpack").getValue(); +} + +template +static StringRef buildPredicatePairReorderCallee(MLIRContext *context); + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b8").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b16").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pdintlv.b32").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b8").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b16").getValue(); +} + +template <> +StringRef buildPredicatePairReorderCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pintlv.b32").getValue(); +} + +static FailureOr buildInterleaveCallee(MLIRContext *context, + Type resultType, + StringRef stem) { + return buildLaneTypedCallee(context, resultType, stem, ""); +} + +static FailureOr buildUnpackCallee(MLIRContext *context, + Type inputType, + Type resultType, + StringRef stem) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); + return StringAttr::get(context, + "llvm.hivm." + stem.str() + "." + input + "2" + result) + .getValue(); +} + +static FailureOr buildVpackCallee(MLIRContext *context, Type inputType, + Type resultType) { + std::string input = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + std::string result = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (input.empty() || result.empty()) + return failure(); + + return StringAttr::get(context, "llvm.hivm.vpack." + input + "2" + result + ".x") + .getValue(); +} + +static FailureOr buildVsqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vsqz", ".x.v300"); +} + +static FailureOr buildVusqzCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vusqz", ".m"); +} + +static FailureOr buildVmulaCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmula", ".m"); +} + +static FailureOr buildVmullCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vmull", ""); +} + +template +static StringRef getPredicateStoreCallee(MLIRContext *context); + +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstiCallee(context); +} + +template <> +StringRef getPredicateStoreCallee(MLIRContext *context) { + return buildPstsCallee(context); +} + +template +static StringRef getPredicateLoadCallee(MLIRContext *context); + +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldiCallee(context); +} + +template <> +StringRef getPredicateLoadCallee(MLIRContext *context) { + return buildPldsCallee(context); +} + +template +static StringRef getPredicateMaskCallee(MLIRContext *context); + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPnotCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPselCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPandCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPorCallee(context); +} + +template <> +StringRef getPredicateMaskCallee(MLIRContext *context) { + return buildPxorCallee(context); +} + +template +static StringRef getPredicatePackCallee(MLIRContext *context); + +template <> +StringRef getPredicatePackCallee(MLIRContext *context) { + return buildPpackCallee(context); +} + +template <> +StringRef getPredicatePackCallee(MLIRContext *context) { + return buildPunpackCallee(context); +} + +template +static StringRef buildPltCallee(MLIRContext *context); + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b8.v300").getValue(); +} + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b16.v300").getValue(); +} + +template <> +StringRef buildPltCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.plt.b32.v300").getValue(); +} + +template +static StringRef buildPsetCallee(MLIRContext *context); + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b8").getValue(); +} + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b16").getValue(); +} + +template <> +StringRef buildPsetCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pset.b32").getValue(); +} + +template +static StringRef buildPgeCallee(MLIRContext *context); + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b8").getValue(); +} + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b16").getValue(); +} + +template <> +StringRef buildPgeCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.pge.b32").getValue(); +} + +static FailureOr buildVldsCallee(MLIRContext *context, Type resultType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vldsx1.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVldsx2Callee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vldsx2", ""); +} + +static StringRef buildVsldbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsldb").getValue(); +} + +static FailureOr buildVstsCallee(MLIRContext *context, Type valueType) { + std::string vec = getElementTypeFragment(getElementTypeFromVectorLike(valueType)); + auto lanes = getElementCountFromVectorLike(valueType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vstsx1.v" + std::to_string(*lanes) + + vec) + .getValue(); +} + +static FailureOr buildVstsx2Callee(MLIRContext *context, Type valueType) { + return buildLaneTypedCallee(context, valueType, "vstsx2", ""); +} + +static StringRef buildVsstbCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.vsstb").getValue(); +} + +static FailureOr buildVgather2Callee(MLIRContext *context, + Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vgather2.v300.v" + + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVgather2BcCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); +} + +static FailureOr buildVgatherbCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vgatherb.v310", ""); +} + +static FailureOr buildVscatterCallee(MLIRContext *context, + Type valueType) { + return buildLaneTypedCallee(context, valueType, "vscatter", ".v300"); +} + +static FailureOr buildVaxpyCallee(MLIRContext *context, + Type resultType) { + return buildLaneTypedCallee(context, resultType, "vaxpy", ".m"); +} + +static FailureOr buildVciCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + if (vec == "f16" || vec == "f32") + return StringAttr::get(context, "llvm.hivm.vci.v" + std::to_string(*lanes) + + vec + "." + vec) + .getValue(); + return StringAttr::get(context, + "llvm.hivm.vci.v" + std::to_string(*lanes) + vec) + .getValue(); +} + +static FailureOr buildVtrcCallee(MLIRContext *context, Type resultType) { + std::string vec = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + auto lanes = getElementCountFromVectorLike(resultType); + if (vec.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vtrc." + vec + ".x").getValue(); +} + +static FailureOr buildVexpdifCallee(MLIRContext *context, + Type inputType, + Type resultType) { + std::string srcVec = + getElementTypeFragment(getElementTypeFromVectorLike(inputType)); + auto srcLanes = getElementCountFromVectorLike(inputType); + std::string dstElem = + getElementTypeFragment(getElementTypeFromVectorLike(resultType)); + if (srcVec.empty() || dstElem.empty() || !srcLanes) + return failure(); + return StringAttr::get(context, "llvm.hivm.vexpdif.v" + + std::to_string(*srcLanes) + srcVec + + dstElem) + .getValue(); +} + +static FailureOr buildVbitsortCallee(MLIRContext *context, + pto::VbitsortOp op) { + Type sourceElemType = cast(op.getSource().getType()).getElementType(); + if (sourceElemType.isF16()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f16").getValue(); + if (sourceElemType.isF32()) + return StringAttr::get(context, "llvm.hivm.VBS32.V300.f32").getValue(); + return failure(); +} + +static FailureOr buildVmrgsort4Callee(MLIRContext *context, + pto::Vmrgsort4Op op) { + Type elemType = + cast(op.getDestination().getType()).getElementType(); + if (elemType.isF16()) + return StringAttr::get(context, "llvm.hivm.VMRGSORT.f16.V300").getValue(); + if (elemType.isF32()) + return StringAttr::get(context, "llvm.hivm.VMRGSORT.f32.V300").getValue(); + return failure(); +} + +static FailureOr packVmrgsort4SourceAddr(Operation *anchor, Value source0, + Value source1, Value source2, + Value source3, Type elemType) { + OpBuilder builder(anchor); + builder.setInsertionPoint(anchor); + Location loc = anchor->getLoc(); + unsigned addrShift = 0; + if (elemType.isF16()) + addrShift = 3; + else if (elemType.isF32()) + addrShift = 3; + else + return failure(); + + auto packOne = [&](Value source, uint64_t laneShift) -> FailureOr { + FailureOr ubPtr = reinterpretPointerToAddrSpace(anchor, source, 6); + if (failed(ubPtr)) + return failure(); + Value asInt = + builder.create(loc, builder.getI64Type(), *ubPtr); + Value shifted = builder.create( + loc, asInt, getI64Constant(builder, loc, addrShift)); + Value masked = builder.create( + loc, shifted, getI64Constant(builder, loc, 0xFFFFULL)); + if (laneShift == 0) + return masked; + return builder + .create(loc, masked, + getI64Constant(builder, loc, laneShift)) + .getResult(); + }; + + FailureOr low0 = packOne(source0, 0); + FailureOr low1 = packOne(source1, 16); + FailureOr low2 = packOne(source2, 32); + FailureOr low3 = packOne(source3, 48); + if (failed(low0) || failed(low1) || failed(low2) || failed(low3)) + return failure(); + + Value packed01 = builder.create(loc, *low0, *low1); + Value packed23 = builder.create(loc, *low2, *low3); + Value packed = builder.create(loc, packed01, packed23); + Type ubPtrTy = LLVM::LLVMPointerType::get(anchor->getContext(), 6); + return builder.create(loc, ubPtrTy, packed).getResult(); +} + +static FailureOr buildVcvtContract(pto::VcvtOp op) { + Type inputElemType = getElementTypeFromVectorLike(op.getInput().getType()); + Type resultElemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!inputElemType || !resultElemType) + return failure(); + auto contract = lookupVcvtContract(classifyVcvtElemType(inputElemType), + classifyVcvtElemType(resultElemType)); + if (!contract) + return failure(); + return *contract; +} + +static bool needsV300CtrlModeForVPTOFunc(func::FuncOp funcOp) { + if (!pto::isPTOEntryFunction(funcOp) || funcOp.getBlocks().empty()) + return false; + + bool needsCtrlSetup = false; + funcOp.walk([&](pto::VcvtOp vcvtOp) { + FailureOr contract = buildVcvtContract(vcvtOp); + if (succeeded(contract) && (*contract).requiresSat) { + needsCtrlSetup = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return needsCtrlSetup; +} + +template +static StringRef buildSetLoopCallee(MLIRContext *context); + +template +static StringRef buildUnaryConfigCallee(MLIRContext *context); + +template +static StringRef buildNullaryConfigCallee(MLIRContext *context); + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.OUTTOUB") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.UBTOOUT") + .getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP3.PARA").getValue(); +} + +template <> +StringRef buildSetLoopCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.CHANNEL.PARA").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.MOV.PAD.VAL").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.QUANT.PRE.v300").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.RELU.ALPHA").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FIX.CLIP.RELU").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP2.STRIDE.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP1.STRIDE.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.LOOP.SIZE.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.MTE2.NZ.PARA").getValue(); +} + +template <> +StringRef buildUnaryConfigCallee( + MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.PAD.VAL.OUTTOL1") + .getValue(); +} + +template <> +StringRef buildUnaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FPC").getValue(); +} + +template <> +StringRef buildNullaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.ATOMIC.S32").getValue(); +} + +template <> +StringRef buildNullaryConfigCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.ATOMIC.S8").getValue(); +} + +static FailureOr encodeMovPadValue(Location loc, Value value, + ConversionPatternRewriter &rewriter) { + Type type = value.getType(); + Value payload = value; + unsigned bitWidth = 0; + + if (auto intType = dyn_cast(type)) { + bitWidth = intType.getWidth(); + } else if (auto floatType = dyn_cast(type)) { + bitWidth = floatType.getWidth(); + auto intType = rewriter.getIntegerType(bitWidth); + payload = rewriter.create(loc, intType, value); + } else { + return failure(); + } + + if (bitWidth != 8 && bitWidth != 16 && bitWidth != 32) + return failure(); + + return rewriter.create(loc, rewriter.getI64Type(), payload) + .getResult(); +} + +template +static StringRef buildSyncCallee(MLIRContext *context); + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.WAIT.FLAG.IMM").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.FLAG.REG").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.WAIT.FLAG.REG").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.BARRIER").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.SET.INTRA.BLOCK.mode").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.WAIT.INTRA.BLOCK.mode").getValue(); +} + +static StringRef buildMemBarCallee(MemBarKind kind, MLIRContext *context) { + switch (kind) { + case MemBarKind::VV_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.vv.all").getValue(); + case MemBarKind::VST_VLD: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vld").getValue(); + case MemBarKind::VLD_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vld.vst").getValue(); + case MemBarKind::VST_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.vst").getValue(); + case MemBarKind::VS_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.vs.all").getValue(); + case MemBarKind::VST_LD: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.ld").getValue(); + case MemBarKind::VLD_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vld.st").getValue(); + case MemBarKind::VST_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.vst.st").getValue(); + case MemBarKind::SV_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.sv.all").getValue(); + case MemBarKind::ST_VLD: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.vld").getValue(); + case MemBarKind::LD_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.ld.vst").getValue(); + case MemBarKind::ST_VST: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.vst").getValue(); + case MemBarKind::SS_ALL: + return StringAttr::get(context, "llvm.hivm.mem.bar.ss.all").getValue(); + case MemBarKind::ST_LD: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.ld").getValue(); + case MemBarKind::LD_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.ld.st").getValue(); + case MemBarKind::ST_ST: + return StringAttr::get(context, "llvm.hivm.mem.bar.st.st").getValue(); + } + llvm_unreachable("unexpected membar kind"); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BUFI.mode").getValue(); +} + +template <> +StringRef buildSyncCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.RLS.BUFI.mode").getValue(); +} + +template +static StringRef buildRuntimeQueryCallee(MLIRContext *context); + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.IDX").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKID").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.BLOCK.NUM").getValue(); +} + +template <> +StringRef buildRuntimeQueryCallee(MLIRContext *context) { + return StringAttr::get(context, "llvm.hivm.GET.SUBBLOCKDIM").getValue(); +} + +static LogicalResult +materializeDecls(ModuleOp module, ArrayRef plannedDecls, + llvm::raw_ostream &diagOS) { + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(&module.getBodyRegion().front()); + for (const PlannedDecl &decl : plannedDecls) { + if (func::FuncOp existing = module.lookupSymbol(decl.name)) { + if (existing.getFunctionType() != decl.type) { + diagOS << "VPTO LLVM emission failed: conflicting declaration for " + << decl.name << "\n"; + return failure(); + } + continue; + } + auto func = + builder.create(module.getLoc(), decl.name, decl.type); + func.setPrivate(); + } + return success(); +} + +template +class LowerUnaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(UnaryOp op, typename UnaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getUnaryMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported unary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert unary result type"); + + Value input = adaptor.getOperands()[0]; + Value mask = adaptor.getOperands()[1]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(1).getType()); + if (!input || !mask || input.getType() != resultType || + mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted unary VPTO operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVsqzOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsqzOp op, pto::VsqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVsqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vsqz types"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vsqz operand types"); + } + + Value storeHint = + getI32Constant(rewriter, op.getLoc(), determineVsqzStoreHint(op)); + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, maskType, storeHint.getType()}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{input, mask, storeHint}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVusqzOpPattern final : public OpConversionPattern { +public: + explicit LowerVusqzOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VusqzOp op, pto::VusqzOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVusqzCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vusqz VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vusqz types"); + + Value src = adaptor.getSrc(); + Value mask = adaptor.getMask(); + if (!src || !mask || src.getType() != resultType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vusqz operand types"); + } + + auto funcType = + rewriter.getFunctionType(TypeRange{resultType, maskType}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmulaOpPattern final : public OpConversionPattern { +public: + explicit LowerVmulaOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmulaOp op, pto::VmulaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmulaCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmula VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vmula types"); + + Value acc = adaptor.getAcc(); + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!acc || !lhs || !rhs || !mask || acc.getType() != resultType || + lhs.getType() != resultType || rhs.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmula operand types"); + } + + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, resultType, resultType, maskType}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{acc, lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmullOpPattern final : public OpConversionPattern { +public: + explicit LowerVmullOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VmullOp op, pto::VmullOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVmullCallee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmull VPTO signature"); + + Type inputType = this->getTypeConverter()->convertType(op.getLhs().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + SmallVector resultTypes; + if (!inputType || !maskType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) { + return rewriter.notifyMatchFailure(op, "failed to convert vmull types"); + } + if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) + return rewriter.notifyMatchFailure(op, "unexpected converted vmull results"); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + Value mask = adaptor.getMask(); + if (!lhs || !rhs || !mask || lhs.getType() != inputType || + rhs.getType() != inputType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vmull operand types"); + } + + auto funcType = rewriter.getFunctionType(TypeRange{inputType, inputType, maskType}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBinaryMaskedOpPattern final : public OpConversionPattern { +public: + explicit LowerBinaryMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getBinaryMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert binary result type"); + + Value lhs = adaptor.getOperands()[0]; + Value rhs = adaptor.getOperands()[1]; + Value mask = adaptor.getOperands()[2]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(2).getType()); + if (!lhs || !rhs || !mask || lhs.getType() != resultType || + rhs.getType() != resultType || mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted binary VPTO operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{lhs, rhs, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCarryBinaryOpPattern final : public OpConversionPattern { +public: + explicit LowerCarryBinaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CarryOp op, typename CarryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getCarryBinaryStem(); + FailureOr calleeName = + buildCarryBinaryCallee(op.getContext(), op.getResult().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported carry VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type carryType = + this->getTypeConverter()->convertType(op->getResult(1).getType()); + if (!resultType || !carryType) + return rewriter.notifyMatchFailure(op, + "failed to convert carry result types"); + + SmallVector callArgs; + callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); + const size_t expectedArgCount = hasCarryInput() ? 4 : 3; + if (callArgs.size() != expectedArgCount || callArgs[0].getType() != resultType || + callArgs[1].getType() != resultType || callArgs.back().getType() != carryType) + return rewriter.notifyMatchFailure(op, + "unexpected converted carry operand types"); + if constexpr (hasCarryInput()) { + if (callArgs[2].getType() != carryType) + return rewriter.notifyMatchFailure( + op, "unexpected converted carry input operand type"); + } + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType, carryType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCopyOpPattern final : public OpConversionPattern { +public: + explicit LowerCopyOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = failure(); + if constexpr (std::is_same_v) + calleeName = buildCopyGmToUbCallee(op.getContext(), op.getSource().getType()); + else + calleeName = buildCopyUbToGmCallee(op.getContext()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported copy VPTO signature"); + + auto llvmSourceType = + dyn_cast(adaptor.getOperands()[0].getType()); + auto llvmDestType = + dyn_cast(adaptor.getOperands()[1].getType()); + if (!llvmSourceType || !llvmDestType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + + FailureOr config0 = failure(); + FailureOr config1 = failure(); + if constexpr (std::is_same_v) { + config0 = packCopyGmToUbConfig0(op, adaptor.getOperands()); + config1 = packCopyGmToUbConfig1(op, adaptor.getOperands()); + } else { + config0 = packCopyUbToGmConfig0(op, adaptor.getOperands()); + config1 = packCopyUbToGmConfig1(op, adaptor.getOperands()); + } + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], + *config0, *config1}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + (void)call; + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyUbufToUbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyUbufToUbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::CopyUbufToUbufOp op, + pto::CopyUbufToUbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmSourceType = + dyn_cast(adaptor.getOperands()[0].getType()); + auto llvmDestType = + dyn_cast(adaptor.getOperands()[1].getType()); + if (!llvmSourceType || !llvmDestType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer copy operands"); + + FailureOr config = packCopyUbToUbConfig(op, adaptor.getOperands()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + StringRef calleeName = buildCopyUbToUbCallee(op.getContext()); + SmallVector args{adaptor.getOperands()[1], adaptor.getOperands()[0], + *config}; + auto funcType = rewriter.getFunctionType( + TypeRange{llvmDestType, llvmSourceType, rewriter.getI64Type()}, + TypeRange{}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + (void)call; + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyCbufToUbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyCbufToUbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::CopyCbufToUbufOp op, + pto::CopyCbufToUbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned ubufAddressSpace = + static_cast(pto::AddressSpace::VEC); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, ubufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/ubuf pointer spaces"); + + FailureOr config = packCopyCbufToUbConfig(op, adaptor.getOperands()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + StringRef calleeName = buildCopyCbufToUbCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyUbufToCbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyUbufToCbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::CopyUbufToCbufOp op, + pto::CopyUbufToCbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned ubufAddressSpace = + static_cast(pto::AddressSpace::VEC); + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, ubufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, cbufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map ubuf/cbuf pointer spaces"); + + FailureOr config = packCopyUbToCbufConfig(op, adaptor.getOperands()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to materialize copy config"); + + StringRef calleeName = buildCopyUbToCbufCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +static LogicalResult lowerMadRawOp(pto::MadRawOpInterface op, + ValueRange convertedOperands, + ConversionPatternRewriter &rewriter, + LoweringState &state) { + Value lhsRaw = convertedOperands[0]; + Value rhsRaw = convertedOperands[1]; + Value dstRaw = convertedOperands[2]; + Value biasRaw = op.hasBiasOperand() ? convertedOperands[3] : Value(); + Value xt = convertedOperands[op.hasBiasOperand() ? 4 : 3]; + if (!lhsRaw || !rhsRaw || !dstRaw || !xt || + (op.hasBiasOperand() && !biasRaw)) + return rewriter.notifyMatchFailure(op, "expected converted mad raw operands"); + + if (!isa(lhsRaw.getType()) || + !isa(rhsRaw.getType()) || + !isa(dstRaw.getType()) || + (biasRaw && !isa(biasRaw.getType()))) { + return rewriter.notifyMatchFailure( + op, "expected LLVM pointer lhs/rhs/dst/bias operands"); + } + + Type i64Ty = rewriter.getI64Type(); + constexpr unsigned caAddressSpace = + static_cast(pto::AddressSpace::LEFT); + constexpr unsigned cbAddressSpace = + static_cast(pto::AddressSpace::RIGHT); + constexpr unsigned ccAddressSpace = + static_cast(pto::AddressSpace::ACC); + constexpr unsigned btAddressSpace = + static_cast(pto::AddressSpace::BIAS); + FailureOr lhs = + reinterpretPointerToAddrSpace(op, lhsRaw, caAddressSpace); + FailureOr rhs = + reinterpretPointerToAddrSpace(op, rhsRaw, cbAddressSpace); + FailureOr dst = + reinterpretPointerToAddrSpace(op, dstRaw, ccAddressSpace); + FailureOr bias; + if (biasRaw) + bias = reinterpretPointerToAddrSpace(op, biasRaw, btAddressSpace); + if (failed(lhs) || failed(rhs) || failed(dst) || + (biasRaw && failed(bias))) { + return rewriter.notifyMatchFailure(op, "failed to map cube pointer spaces"); + } + + FailureOr calleeName = + op.isMadMxFamily() ? buildMxMadCallee(op.getContext(), op) + : buildOrdinaryMadCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure( + op, "unsupported mad element types for raw dispatch"); + + Value callDst = *dst; + if (biasRaw) + callDst = buildMadBiasDestination(op, rewriter, *dst, *bias); + auto funcType = rewriter.getFunctionType( + TypeRange{dst->getType(), lhs->getType(), rhs->getType(), i64Ty}, + TypeRange{}); + auto call = rewriter.create( + op->getLoc(), *calleeName, TypeRange{}, + ValueRange{callDst, *lhs, *rhs, xt}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); +} + +template +class LowerMadRawPattern final : public OpConversionPattern { +public: + explicit LowerMadRawPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(RawOp op, typename RawOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto raw = dyn_cast(op.getOperation()); + if (!raw) + return failure(); + return lowerMadRawOp(raw, adaptor.getOperands(), rewriter, state); + } + +private: + LoweringState &state; +}; + +class LowerCopyGmToCbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyGmToCbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite( + pto::CopyGmToCbufOp op, + pto::CopyGmToCbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value nBurst = adaptor.getNBurst(); + Value lenBurst = adaptor.getLenBurst(); + Value srcStride = adaptor.getSrcStride(); + Value dstStride = adaptor.getDstStride(); + if (!sourceRaw || !destinationRaw || !nBurst || !lenBurst || !srcStride || + !dstStride) { + return rewriter.notifyMatchFailure(op, "expected converted operands"); + } + + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + } + + Type i64Ty = rewriter.getI64Type(); + if (nBurst.getType() != i64Ty || lenBurst.getType() != i64Ty || + srcStride.getType() != i64Ty || dstStride.getType() != i64Ty) { + return rewriter.notifyMatchFailure(op, "expected i64 config operands"); + } + + constexpr unsigned gmAddressSpace = + static_cast(pto::AddressSpace::GM); + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, gmAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, cbufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/gm pointer spaces"); + + FailureOr calleeName = + buildCopyGmToCbufCallee(op.getContext(), op.getSource().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported copy_gm_to_cbuf element type"); + FailureOr config0 = + packCopyGmToCbufConfig0(op, nBurst, lenBurst); + FailureOr config1 = + packCopyGmToCbufConfig1(op, srcStride, dstStride); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, + "failed to pack copy_gm_to_cbuf config"); + + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCopyGmToCbufMultiOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyGmToCbufMultiOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned gmAddressSpace = + static_cast(pto::AddressSpace::GM); + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, gmAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, cbufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/gm pointer spaces"); + + FailureOr config0 = packCopyGmToCbufMultiConfig0( + op, adaptor.getSid(), adaptor.getLoop1SrcStride(), + adaptor.getL2CacheCtrl(), adaptor.getNValue()); + FailureOr config1 = + packCopyGmToCbufMultiConfig1(op, adaptor.getDValue(), + adaptor.getLoop4SrcStride(), + adaptor.getSmallc0En()); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack multi copy config"); + + FailureOr calleeName = [&] (MLIRContext *ctx, Type sourceType) + -> FailureOr { + if constexpr (std::is_same_v) + return buildCopyGmToCbufMultiNd2NzCallee(ctx, op.getSource().getType()); + return buildCopyGmToCbufMultiDn2NzCallee(ctx, sourceType); + }(op.getContext(), op.getSource().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure( + op, "unsupported copy_gm_to_cbuf_multi element type"); + + Type i64Ty = rewriter.getI64Type(); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyCbufToBtOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyCbufToBtOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite(pto::CopyCbufToBtOp op, + pto::CopyCbufToBtOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned btAddressSpace = + static_cast(pto::AddressSpace::BIAS); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destinationPtr = + reinterpretPointerToAddrSpace(op, destinationRaw, btAddressSpace); + if (failed(source) || failed(destinationPtr)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/bt pointer spaces"); + + FailureOr config = packCopyCbufToBtConfig( + op, adaptor.getConvControl(), adaptor.getNBurst(), adaptor.getLenBurst(), + adaptor.getSourceGap(), adaptor.getDstGap()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack copy_cbuf_to_bt config"); + + Type i64Ty = rewriter.getI64Type(); + Value destination = + rewriter.create(op.getLoc(), i64Ty, *destinationPtr); + FailureOr calleeName = buildCopyCbufToBtCallee(op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported copy_cbuf_to_bt source element type"); + auto funcType = rewriter.getFunctionType( + TypeRange{i64Ty, source->getType(), i64Ty}, TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{destination, *source, *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyCbufToFbufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyCbufToFbufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite(pto::CopyCbufToFbufOp op, + pto::CopyCbufToFbufOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned fbufAddressSpace = 7; + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, fbufAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/fbuf pointer spaces"); + + FailureOr config = packCopyCbufToFbufConfig( + op, adaptor.getNBurst(), adaptor.getLenBurst(), adaptor.getSourceGap(), + adaptor.getDstGap()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack copy_cbuf_to_fbuf config"); + + Type i64Ty = rewriter.getI64Type(); + StringRef calleeName = buildCopyCbufToFbufCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerLoadCbufToCaOpPattern final + : public OpConversionPattern { +public: + explicit LowerLoadCbufToCaOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite(pto::LoadCbufToCaOp op, + pto::LoadCbufToCaOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value mStart = adaptor.getMStart(); + Value kStart = adaptor.getKStart(); + Value mStep = adaptor.getMStep(); + Value kStep = adaptor.getKStep(); + Value srcStride = adaptor.getSrcStride(); + Value dstStride = adaptor.getDstStride(); + if (!sourceRaw || !destinationRaw || !mStart || !kStart || !mStep || + !kStep || !srcStride || !dstStride) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + } + + Type i64Ty = rewriter.getI64Type(); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned caAddressSpace = + static_cast(pto::AddressSpace::LEFT); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, caAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/ca pointer spaces"); + + FailureOr config0 = + packLoadCbufToCaConfig0(op, mStart, kStart, mStep, kStep); + FailureOr config1 = + packLoadCbufToCaConfig1(op, srcStride, dstStride); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack load_cbuf_to_ca config"); + Value transpose = + getI64Constant(rewriter, op.getLoc(), op.getTranspose() ? 1 : 0); + + FailureOr calleeName = + buildLoadCbufToCaCallee(op.getContext(), op.getSource().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported load_cbuf_to_ca element type"); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty, + i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, + *config1, transpose}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerLoadCbufToS4OpPattern final : public OpConversionPattern { +public: + explicit LowerLoadCbufToS4OpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(LoadOp op, typename LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned targetAddressSpace = + std::is_same_v + ? static_cast(pto::AddressSpace::LEFT) + : static_cast(pto::AddressSpace::RIGHT); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, targetAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/cube pointer spaces"); + + FailureOr config0 = packLoadCbufToS4Config0( + op, adaptor.getMStart(), adaptor.getKStart(), adaptor.getMStep(), + adaptor.getKStep()); + FailureOr config1 = + packLoadCbufToS4Config1(op, adaptor.getSrcStride(), + adaptor.getDstStride()); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack load_cbuf_to_*_s4 config"); + + Value transpose = + castIntegerLikeTo(op, adaptor.getTranspose(), rewriter.getI64Type()); + if (!transpose) + return rewriter.notifyMatchFailure(op, "failed to cast transpose to i64"); + + StringRef calleeName = std::is_same_v + ? buildLoadCbufToCaS4Callee(op.getContext()) + : buildLoadCbufToCbS4Callee(op.getContext()); + Type i64Ty = rewriter.getI64Type(); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty, + i64Ty}, + TypeRange{}); + rewriter.create( + op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, *config1, transpose}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerLoadCbufToCbOpPattern final + : public OpConversionPattern { +public: + explicit LowerLoadCbufToCbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite(pto::LoadCbufToCbOp op, + pto::LoadCbufToCbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value mStart = adaptor.getMStart(); + Value kStart = adaptor.getKStart(); + Value mStep = adaptor.getMStep(); + Value kStep = adaptor.getKStep(); + Value srcStride = adaptor.getSrcStride(); + Value dstStride = adaptor.getDstStride(); + if (!sourceRaw || !destinationRaw || !mStart || !kStart || !mStep || + !kStep || !srcStride || !dstStride) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + } + + Type i64Ty = rewriter.getI64Type(); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned cbAddressSpace = + static_cast(pto::AddressSpace::RIGHT); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, cbufAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, cbAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/cb pointer spaces"); + + bool transpose = op.getTranspose(); + FailureOr config0 = + packLoadCbufToCbConfig0(op, mStart, kStart, mStep, kStep); + FailureOr config1 = + packLoadCbufToCbConfig1(op, srcStride, dstStride); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, "failed to pack load_cbuf_to_cb config"); + Value transposeValue = + getI64Constant(rewriter, op.getLoc(), transpose ? 1 : 0); + + FailureOr calleeName = + buildLoadCbufToCbCallee(op.getContext(), op.getSource().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported load_cbuf_to_cb element type"); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty, + i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, *config0, + *config1, transposeValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerLoadCbufToCaMxOpPattern final + : public OpConversionPattern { +public: + explicit LowerLoadCbufToCaMxOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::LoadCbufToCaMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value srcRaw = adaptor.getSource(); + Value dstRaw = adaptor.getDestination(); + if (!srcRaw || !dstRaw || !adaptor.getM() || !adaptor.getK()) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(srcRaw.getType()) || + !isa(dstRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned caAddressSpace = + static_cast(pto::AddressSpace::LEFT); + FailureOr src = reinterpretPointerToAddrSpace(op, srcRaw, cbufAddressSpace); + FailureOr dst = reinterpretPointerToAddrSpace(op, dstRaw, caAddressSpace); + if (failed(src) || failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/ca pointer spaces"); + + Type sourceElemType = cast(op.getSource().getType()).getElementType(); + unsigned elemBitWidth = sourceElemType.getIntOrFloatBitWidth(); + if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + return rewriter.notifyMatchFailure(op, + "unsupported load_cbuf_to_ca_mx element type"); + uint64_t elemBytes = elemBitWidth / 8; + Location loc = op.getLoc(); + auto constant = [&](uint64_t value) -> Value { + return rewriter.create(loc, value, 64); + }; + auto ceilDivConst = [&](Value value, uint64_t divisor) -> Value { + Value bias = constant(divisor - 1); + Value sum = rewriter.create(loc, value, bias); + return rewriter.create(loc, sum, constant(divisor)); + }; + Value zero = constant(0); + Value mStep = ceilDivConst(adaptor.getM(), 16); + Value kBytes = + rewriter.create(loc, adaptor.getK(), constant(elemBytes)); + Value kStep = ceilDivConst(kBytes, 32); + Value stride = ceilDivConst(adaptor.getM(), 16); + FailureOr config0 = + packLoadCbufToCaConfig0(op, zero, zero, mStep, kStep); + FailureOr config1 = + packLoadCbufToCaConfig1(op, stride, stride); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, + "failed to pack load_cbuf_to_ca_mx config"); + auto i64Ty = rewriter.getI64Type(); + Value dstAddr = rewriter.create(op.getLoc(), i64Ty, *dst); + + StringRef calleeName = buildLoadCbufToCaMxCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{i64Ty, src->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{dstAddr, *src, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerLoadCbufToCbMxOpPattern final + : public OpConversionPattern { +public: + explicit LowerLoadCbufToCbMxOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::LoadCbufToCbMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value srcRaw = adaptor.getSource(); + Value dstRaw = adaptor.getDestination(); + if (!srcRaw || !dstRaw || !adaptor.getXStartPosition() || + !adaptor.getYStartPosition() || !adaptor.getXStep() || + !adaptor.getYStep() || !adaptor.getSrcStride() || + !adaptor.getDstStride()) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(srcRaw.getType()) || + !isa(dstRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned cbufAddressSpace = + static_cast(pto::AddressSpace::MAT); + constexpr unsigned cbAddressSpace = + static_cast(pto::AddressSpace::RIGHT); + FailureOr src = reinterpretPointerToAddrSpace(op, srcRaw, cbufAddressSpace); + FailureOr dst = reinterpretPointerToAddrSpace(op, dstRaw, cbAddressSpace); + if (failed(src) || failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to map cbuf/cb pointer spaces"); + + Type sourceElemType = cast(op.getSource().getType()).getElementType(); + unsigned elemBitWidth = sourceElemType.getIntOrFloatBitWidth(); + if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + return rewriter.notifyMatchFailure(op, + "unsupported load_cbuf_to_cb_mx element type"); + FailureOr config0 = + packLoadCbufToCaConfig0(op, adaptor.getXStartPosition(), + adaptor.getYStartPosition(), adaptor.getXStep(), + adaptor.getYStep()); + FailureOr config1 = + packLoadCbufToCaConfig1(op, adaptor.getSrcStride(), + adaptor.getDstStride()); + if (failed(config0) || failed(config1)) + return rewriter.notifyMatchFailure(op, + "failed to pack load_cbuf_to_cb_mx config"); + auto i64Ty = rewriter.getI64Type(); + Value dstAddr = rewriter.create(op.getLoc(), i64Ty, *dst); + + StringRef calleeName = buildLoadCbufToCbMxCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{i64Ty, src->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{dstAddr, *src, *config0, *config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerCopyMatrixCcToGmOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyMatrixCcToGmOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult matchAndRewrite( + pto::CopyMatrixCcToGmOp op, pto::CopyMatrixCcToGmOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + Value xm = adaptor.getXm(); + Value xt = adaptor.getXt(); + if (!sourceRaw || !destinationRaw || !xm || !xt) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) { + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + } + + Type i64Ty = rewriter.getI64Type(); + if (xm.getType() != i64Ty || xt.getType() != i64Ty) + return rewriter.notifyMatchFailure(op, "expected i64 xm/xt operands"); + + constexpr unsigned gmAddressSpace = + static_cast(pto::AddressSpace::GM); + constexpr unsigned ccAddressSpace = + static_cast(pto::AddressSpace::ACC); + FailureOr source = reinterpretPointerToAddrSpace(op, sourceRaw, ccAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, gmAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cc/gm pointer spaces"); + + StringRef calleeName = buildCopyMatrixCcToGmCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*destination, *source, xm, xt}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCopyMatrixCcToBufOpPattern final + : public OpConversionPattern { +public: + explicit LowerCopyMatrixCcToBufOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CopyOp op, typename CopyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value sourceRaw = adaptor.getSource(); + Value destinationRaw = adaptor.getDestination(); + if (!sourceRaw || !destinationRaw) + return rewriter.notifyMatchFailure(op, "expected converted operands"); + if (!isa(sourceRaw.getType()) || + !isa(destinationRaw.getType())) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer src/dst"); + + constexpr unsigned ccAddressSpace = + static_cast(pto::AddressSpace::ACC); + constexpr unsigned targetAddressSpace = + std::is_same_v + ? static_cast(pto::AddressSpace::MAT) + : static_cast(pto::AddressSpace::VEC); + FailureOr source = + reinterpretPointerToAddrSpace(op, sourceRaw, ccAddressSpace); + FailureOr destination = + reinterpretPointerToAddrSpace(op, destinationRaw, targetAddressSpace); + if (failed(source) || failed(destination)) + return rewriter.notifyMatchFailure(op, "failed to map cc->buf pointer spaces"); + + Type i64Ty = rewriter.getI64Type(); + Value config0 = castIntegerLikeTo(op, adaptor.getConfig0(), i64Ty); + Value config1 = castIntegerLikeTo(op, adaptor.getConfig1(), i64Ty); + if (!config0 || !config1) + return rewriter.notifyMatchFailure(op, "failed to cast config operands to i64"); + + FailureOr calleeName = + std::is_same_v + ? FailureOr(buildCopyMatrixCcToCbufCallee(op.getContext())) + : buildCopyMatrixCcToUbCallee(op.getContext(), + op.getDestination().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure( + op, "unsupported copy_matrix_cc_to_{cbuf,ub} element type"); + auto funcType = rewriter.getFunctionType( + TypeRange{destination->getType(), source->getType(), i64Ty, i64Ty}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*destination, *source, config0, + config1}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerVecScalarMaskedOpPattern final + : public OpConversionPattern { +public: + explicit LowerVecScalarMaskedOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(VecScalarOp op, typename VecScalarOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getVecScalarMaskedStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported vec-scalar VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert vec-scalar result type"); + + Value input = adaptor.getOperands()[0]; + Value scalar = adaptor.getOperands()[1]; + Value mask = adaptor.getOperands()[2]; + Type expectedMaskType = + this->getTypeConverter()->convertType(op->getOperand(2).getType()); + if (!input || !scalar || !mask || input.getType() != resultType || + mask.getType() != expectedMaskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted vec-scalar VPTO operand types"); + } + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{input, scalar, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerReductionUnaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerReductionUnaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReductionOp op, typename ReductionOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = getReductionUnaryStem(); + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported reduction VPTO signature"); + + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) { + return rewriter.notifyMatchFailure( + op, "failed to convert reduction result type"); + } + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted reduction operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerWideningReductionUnaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerWideningReductionUnaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReductionOp op, typename ReductionOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildLaneTypedCalleeFromInput( + op.getContext(), op.getInput().getType(), + getReductionUnaryStem(), ".x"); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported widening reduction VPTO signature"); + + Type inputType = + this->getTypeConverter()->convertType(op.getInput().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!inputType || !resultType || !maskType) + return rewriter.notifyMatchFailure(op, + "failed to convert widening reduction types"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != inputType || + mask.getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted widening reduction operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVselOpPattern final : public OpConversionPattern { +public: + explicit LowerVselOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VselOp op, pto::VselOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVselCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsel VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vsel result type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + Value mask = adaptor.getMask(); + if (!src0 || !src1 || !mask || src0.getType() != resultType || + src1.getType() != resultType || mask.getType() != maskType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vsel operand types"); + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{src0, src1, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVdupOpPattern final : public OpConversionPattern { +public: + explicit LowerVdupOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VdupOp op, pto::VdupOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildVdupCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vdup VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, "failed to convert vdup result type"); + + Value mask = adaptor.getMask(); + if (!mask || mask.getType() != maskType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vdup mask type"); + + SmallVector callArgs; + bool vectorInput = isa(op.getInput().getType()); + if (vectorInput) { + Value input = adaptor.getInput(); + if (!input || input.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "vector-input vdup requires matching result type"); + } + callArgs.push_back(input); + } else { + Type scalarType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!scalarType || + (op.getInput().getType() != scalarType && + !isCompatibleScalarForSemanticType(scalarType, + op.getInput().getType()))) { + return rewriter.notifyMatchFailure(op, + "unexpected scalar-input vdup type"); + } + FailureOr normalizedScalar = + normalizeVdupScalarOperand(rewriter, op.getLoc(), adaptor.getInput(), + op.getResult().getType()); + if (failed(normalizedScalar)) + return rewriter.notifyMatchFailure(op, + "failed to normalize scalar vdup input"); + Value scalarForCall = normalizeByteScalarOperandForHivmCall( + rewriter, op.getLoc(), *normalizedScalar, scalarType); + callArgs.push_back(scalarForCall); + } + + callArgs.push_back(mask); + callArgs.push_back(getI32Constant(rewriter, op.getLoc(), 1)); + + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVbrOpPattern final : public OpConversionPattern { +public: + explicit LowerVbrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VbrOp op, pto::VbrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVbrCallee(op.getContext(), + cast(op.getResult().getType()).getElementType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vbr VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vbr result type"); + + Value scalar = adaptor.getValue(); + Type expectedScalarType = + this->getTypeConverter()->convertType(op.getValue().getType()); + if (!scalar || !expectedScalarType || scalar.getType() != expectedScalarType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vbr operand type"); + + scalar = normalizeByteScalarOperandForHivmCall( + rewriter, op.getLoc(), scalar, + cast(op.getResult().getType()).getElementType()); + + auto funcType = rewriter.getFunctionType(TypeRange{scalar.getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, + ValueRange{scalar}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVselrOpPattern final : public OpConversionPattern { +public: + explicit LowerVselrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VselrOp op, pto::VselrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVselrCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vselr VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + auto resultVectorType = dyn_cast(resultType); + if (!resultVectorType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr result type"); + + Type intrinsicResultType = resultType; + if (auto floatType = dyn_cast(resultVectorType.getElementType()); + floatType && floatType.isF32()) { + intrinsicResultType = VectorType::get( + resultVectorType.getShape(), rewriter.getI32Type(), + resultVectorType.getScalableDims()); + } + + Type indexType = this->getTypeConverter()->convertType(op.getSrc1().getType()); + if (!indexType) + return rewriter.notifyMatchFailure(op, + "failed to convert vselr index type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + if (!src0 || !src1 || src1.getType() != indexType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr operand types"); + + if (src0.getType() != intrinsicResultType) { + if (src0.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vselr source type"); + src0 = rewriter.create(op.getLoc(), intrinsicResultType, src0); + } + + auto funcType = rewriter.getFunctionType( + TypeRange{intrinsicResultType, indexType}, TypeRange{intrinsicResultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{intrinsicResultType}, + ValueRange{src0, src1}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + + Value result = call.getResult(0); + if (intrinsicResultType != resultType) + result = rewriter.create(op.getLoc(), resultType, result); + rewriter.replaceOp(op, result); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerPnotOpPattern final : public OpConversionPattern { +public: + explicit LowerPnotOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PnotOp op, pto::PnotOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert pnot result type"); + + Value input = adaptor.getInput(); + Value mask = adaptor.getMask(); + if (!input || !mask || input.getType() != resultType || + mask.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted pnot operand types"); + } + + StringRef calleeName = getPredicateMaskCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{input, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerInterleaveOpPattern final + : public OpConversionPattern { +public: + explicit LowerInterleaveOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(InterleaveOp op, typename InterleaveOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = std::is_same_v ? "vintlv" : "vdintlv"; + FailureOr calleeName = + buildInterleaveCallee(op.getContext(), op.getLow().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported interleave VPTO signature"); + + Type lowType = this->getTypeConverter()->convertType(op.getLow().getType()); + Type highType = this->getTypeConverter()->convertType(op.getHigh().getType()); + if (!lowType || !highType || lowType != highType) { + return rewriter.notifyMatchFailure( + op, "failed to convert interleave result types"); + } + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (!lhs || !rhs || lhs.getType() != lowType || rhs.getType() != lowType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted interleave operand types"); + } + + auto funcType = rewriter.getFunctionType(TypeRange{lowType, lowType}, + TypeRange{lowType, highType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{lowType, highType}, ValueRange{lhs, rhs}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicatePackOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicatePackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PackOp op, typename PackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-pack result type"); + + auto part = parseHiLoPartImmediate(op.getPart()); + if (!part) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-pack part immediate"); + + Value input = adaptor.getInput(); + if (!input || input.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-pack operand type"); + + Value partValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*part)); + StringRef calleeName = getPredicatePackCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{resultType, rewriter.getI32Type()}, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), calleeName, TypeRange{resultType}, ValueRange{input, partValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerUnpackOpPattern final : public OpConversionPattern { +public: + explicit LowerUnpackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(UnpackOp op, typename UnpackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef stem = std::is_same_v ? "vsunpack" + : "vzunpack"; + FailureOr calleeName = buildUnpackCallee( + op.getContext(), op.getSrc().getType(), op.getResult().getType(), stem); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported unpack VPTO signature"); + + Type srcType = this->getTypeConverter()->convertType(op.getSrc().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!srcType || !resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert unpack types"); + + Value src = adaptor.getSrc(); + if (!src || src.getType() != srcType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted unpack source type"); + } + + Value part = castIntegerLikeTo(op, adaptor.getPart(), rewriter.getI32Type()); + if (!part) + return rewriter.notifyMatchFailure(op, "failed to materialize unpack part"); + + auto funcType = rewriter.getFunctionType(TypeRange{srcType, part.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, part}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVpackOpPattern final : public OpConversionPattern { +public: + explicit LowerVpackOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VpackOp op, pto::VpackOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = + buildVpackCallee(op.getContext(), op.getSrc().getType(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vpack VPTO signature"); + + Type srcType = this->getTypeConverter()->convertType(op.getSrc().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!srcType || !resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vpack types"); + + auto partImm = parseHiLoPartImmediate(op.getPart()); + if (!partImm) + return rewriter.notifyMatchFailure(op, "unsupported vpack part immediate"); + + Value src = adaptor.getSrc(); + if (!src || src.getType() != srcType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted vpack source type"); + } + + Value part = getI32Constant(rewriter, op.getLoc(), *partImm); + auto funcType = rewriter.getFunctionType(TypeRange{srcType, part.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, ValueRange{src, part}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateMaskBinaryOpPattern final + : public OpConversionPattern { +public: + explicit LowerPredicateMaskBinaryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(PredicateMaskOp op, typename PredicateMaskOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-mask result type"); + + Value src0 = adaptor.getSrc0(); + Value src1 = adaptor.getSrc1(); + Value mask = adaptor.getMask(); + if (!src0 || !src1 || !mask || src0.getType() != resultType || + src1.getType() != resultType || mask.getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-mask operand types"); + } + + StringRef calleeName = getPredicateMaskCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{src0, src1, mask}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicatePairReorderOpPattern final + : public OpConversionPattern { +public: + explicit LowerPredicatePairReorderOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ReorderOp op, typename ReorderOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-pair-reorder result types"); + if (resultTypes.size() != 2 || resultTypes[0] != resultTypes[1]) + return rewriter.notifyMatchFailure( + op, "unexpected predicate-pair-reorder converted result types"); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + if (!lhs || !rhs || lhs.getType() != resultTypes[0] || + rhs.getType() != resultTypes[0]) { + return rewriter.notifyMatchFailure( + op, "unexpected converted predicate-pair-reorder operand types"); + } + + StringRef calleeName = + buildPredicatePairReorderCallee(op.getContext()); + auto call = rewriter.create(op.getLoc(), calleeName, resultTypes, + ValueRange{lhs, rhs}); + state.plannedDecls.push_back( + PlannedDecl{calleeName.str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerCmpOpPattern final : public OpConversionPattern { +public: + explicit LowerCmpOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(CmpOp op, typename CmpOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isScalarCompare = std::is_same_v; + Type inputType = Type(); + if constexpr (isScalarCompare) + inputType = op.getSrc().getType(); + else + inputType = op.getSrc0().getType(); + FailureOr calleeName = + buildVcmpCallee(op.getContext(), inputType, op.getCmpMode(), + isScalarCompare); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, + "unsupported compare VPTO signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type maskType = + this->getTypeConverter()->convertType(op.getMask().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure(op, + "failed to convert compare result type"); + if (resultType != maskType) + return rewriter.notifyMatchFailure(op, + "unexpected compare mask conversion"); + + SmallVector callArgs; + callArgs.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); + if constexpr (isScalarCompare) { + if (callArgs.size() != 3 || !callArgs[0] || !callArgs[1] || !callArgs[2] || + callArgs[2].getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted scalar-compare operand types"); + } + callArgs[1] = normalizeByteScalarOperandForHivmCall( + rewriter, op.getLoc(), callArgs[1], + cast(op.getSrc().getType()).getElementType()); + } else { + if (callArgs.size() != 3 || !callArgs[0] || !callArgs[1] || !callArgs[2] || + callArgs[0].getType() != callArgs[1].getType() || + callArgs[2].getType() != maskType) { + return rewriter.notifyMatchFailure( + op, "unexpected converted compare operand types"); + } + } + + auto call = rewriter.create(op.getLoc(), *calleeName, + TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{calleeName->str(), call.getCalleeType()}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPltOpPattern final : public OpConversionPattern { +public: + explicit LowerPltOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PltOp op, typename PltOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value laneCount = castIntegerLikeTo(op, adaptor.getScalar(), rewriter.getI32Type()); + if (!laneCount) + return rewriter.notifyMatchFailure(op, "failed to materialize plt lane count"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert plt result types"); + + StringRef calleeName = buildPltCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), calleeName, + resultTypes, ValueRange{laneCount}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPsetOpPattern final : public OpConversionPattern { +public: + explicit LowerPsetOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PsetOp op, typename PsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pattern = parsePredicatePatternImmediate(op.getPattern()); + if (!pattern) + return rewriter.notifyMatchFailure(op, "unsupported pset pattern"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pset result types"); + + StringRef calleeName = buildPsetCallee(op.getContext()); + Value patternValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*pattern)); + auto funcType = rewriter.getFunctionType(TypeRange{rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), calleeName, + resultTypes, ValueRange{patternValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPgeOpPattern final : public OpConversionPattern { +public: + explicit LowerPgeOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(PgeOp op, typename PgeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pattern = parsePredicatePatternImmediate(op.getPattern()); + if (!pattern) + return rewriter.notifyMatchFailure(op, "unsupported pge pattern"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pge result types"); + + StringRef calleeName = buildPgeCallee(op.getContext()); + Value patternValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*pattern)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI32Type(), rewriter.getI32Type()}, resultTypes); + auto call = + rewriter.create(op.getLoc(), calleeName, resultTypes, + ValueRange{patternValue, zero}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsOpPattern final : public OpConversionPattern { +public: + explicit LowerVldsOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vlds element type"); + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = + parseLoadDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vlds operands"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert vlds result types"); + + FailureOr calleeName = buildVldsCallee(op.getContext(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vlds signature"); + + Value distValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, zero}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, + resultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsPostOpPattern final + : public OpConversionPattern { +public: + explicit LowerVldsPostOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VldsPostOp op, pto::VldsPostOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vlds_post element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = + parseLoadDistImmediate(op.getDist().value_or("NORM"), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vlds_post operands"); + } + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + Type updatedSourceType = + this->getTypeConverter()->convertType(op.getUpdatedSource().getType()); + if (!resultType || !updatedSourceType || updatedSourceType != adaptor.getSource().getType()) { + return rewriter.notifyMatchFailure(op, + "failed to convert vlds_post result types"); + } + + FailureOr calleeName = + buildVldsPostCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vlds_post signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value postValue = getI32Constant(rewriter, op.getLoc(), 1); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, postValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), (*offsetBytes).getType(), + distValue.getType(), postValue.getType()}, + TypeRange{resultType, updatedSourceType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType, updatedSourceType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldsx2OpPattern final : public OpConversionPattern { +public: + explicit LowerVldsx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vldsx2Op op, pto::Vldsx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vldsx2 element type"); + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + auto dist = parseLoadX2DistImmediate(op.getDist(), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vldsx2 operands"); + } + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes)) || + resultTypes.size() != 2) { + return rewriter.notifyMatchFailure(op, + "failed to convert vldsx2 result types"); + } + + FailureOr calleeName = + buildVldsx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vldsx2 signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), *offsetBytes, distValue, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), (*offsetBytes).getType(), + distValue.getType(), zeroValue.getType()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, + resultTypes, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsldbOpPattern final : public OpConversionPattern { +public: + explicit LowerVsldbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsldbOp op, pto::VsldbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Value packedStride = + packBlockRepeatStride(op, adaptor.getBlockStride(), adaptor.getRepeatStride()); + if (!basePtr || !packedStride) + return rewriter.notifyMatchFailure(op, "failed to materialize vsldb operands"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vsldb result type"); + + StringRef calleeName = buildVsldbCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getSource(), packedStride, zeroValue, + adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), packedStride.getType(), + zeroValue.getType(), adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerInitAlignOpPattern final + : public OpConversionPattern { +public: + explicit LowerInitAlignOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::InitAlignOp op, pto::InitAlignOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert init_align result type"); + + StringRef calleeName = buildInitAlignCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldasOpPattern final : public OpConversionPattern { +public: + explicit LowerVldasOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldasOp op, pto::VldasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!sourceType || !resultType) + return rewriter.notifyMatchFailure(op, + "expected converted vldas operand/result types"); + + StringRef calleeName = buildVldasCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{adaptor.getSource().getType()}, + TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, + ValueRange{adaptor.getSource()}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVldusOpPattern final : public OpConversionPattern { +public: + explicit LowerVldusOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VldusOp op, pto::VldusOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(adaptor.getSource().getType()); + SmallVector resultTypes; + if (!sourceType || + failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes)) || + resultTypes.size() != 2 || adaptor.getAlign().getType() != resultTypes[1]) { + return rewriter.notifyMatchFailure(op, + "expected converted vldus operand/result types"); + } + + FailureOr calleeName = + buildVldusCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vldus signature"); + + SmallVector intrinsicResultTypes(resultTypes.begin(), resultTypes.end()); + // The installed no-post A5 vldus intrinsic returns an extra hidden base ptr. + intrinsicResultTypes.push_back(adaptor.getSource().getType()); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getAlign().getType()}, + intrinsicResultTypes); + auto call = rewriter.create( + op.getLoc(), *calleeName, intrinsicResultTypes, + ValueRange{adaptor.getSource(), adaptor.getAlign()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults().take_front(resultTypes.size())); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerSprclrOpPattern final : public OpConversionPattern { +public: + explicit LowerSprclrOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::SprclrOp op, pto::SprclrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto spr = parseSprImmediate(op.getSpr()); + if (!spr) + return rewriter.notifyMatchFailure(op, "unsupported sprclr target"); + + StringRef calleeName = buildSprclrCallee(op.getContext()); + Value sprValue = rewriter.create( + op.getLoc(), rewriter.getI16IntegerAttr(*spr)); + auto funcType = rewriter.getFunctionType(TypeRange{sprValue.getType()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, ValueRange{sprValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsOpPattern final : public OpConversionPattern { +public: + explicit LowerVstsOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vsts element type"); + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = dyn_cast(adaptor.getDestination().getType()); + auto dist = + parseStoreDistImmediate(op.getDist().value_or(""), elementType); + if (failed(offsetBytes) || !basePtr || !dist) + return rewriter.notifyMatchFailure(op, "failed to materialize vsts operands"); + + FailureOr calleeName = + buildVstsCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsts signature"); + + Value distValue = rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist)); + Value zero = rewriter.create(op.getLoc(), + rewriter.getI32IntegerAttr(0)); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + *offsetBytes, distValue, zero, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type(), adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsstbOpPattern final : public OpConversionPattern { +public: + explicit LowerVsstbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsstbOp op, pto::VsstbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + Value packedStride = + packBlockRepeatStride(op, adaptor.getBlockStride(), adaptor.getRepeatStride()); + if (!basePtr || !packedStride) + return rewriter.notifyMatchFailure(op, "failed to materialize vsstb operands"); + + StringRef calleeName = buildVsstbCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), + packedStride, zeroValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + packedStride.getType(), zeroValue.getType(), + adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsPostOpPattern final + : public OpConversionPattern { +public: + explicit LowerVstsPostOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VstsPostOp op, pto::VstsPostOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vsts_post element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + auto dist = + parseStoreDistImmediate(op.getDist().value_or(""), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vsts_post operands"); + } + + Type updatedDestinationType = + this->getTypeConverter()->convertType(op.getUpdatedDestination().getType()); + if (!updatedDestinationType || updatedDestinationType != adaptor.getDestination().getType()) { + return rewriter.notifyMatchFailure(op, + "failed to convert vsts_post result type"); + } + + FailureOr calleeName = + buildVstsPostCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vsts_post signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value postValue = getI32Constant(rewriter, op.getLoc(), 1); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), *offsetBytes, + distValue, postValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + (*offsetBytes).getType(), distValue.getType(), postValue.getType(), + adaptor.getMask().getType()}, + TypeRange{updatedDestinationType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{updatedDestinationType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstsx2OpPattern final : public OpConversionPattern { +public: + explicit LowerVstsx2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vstsx2Op op, pto::Vstsx2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getLow().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vstsx2 element type"); + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + auto dist = parseStoreX2DistImmediate(op.getDist(), elementType); + if (failed(offsetBytes) || !basePtr || !dist) { + return rewriter.notifyMatchFailure(op, + "failed to materialize vstsx2 operands"); + } + + FailureOr calleeName = + buildVstsx2Callee(op.getContext(), op.getLow().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vstsx2 signature"); + + Value distValue = getI32Constant(rewriter, op.getLoc(), *dist); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getLow(), adaptor.getHigh(), + adaptor.getDestination(), *offsetBytes, distValue, + zeroValue, adaptor.getMask()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getLow().getType(), adaptor.getHigh().getType(), + adaptor.getDestination().getType(), (*offsetBytes).getType(), + distValue.getType(), zeroValue.getType(), + adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), *calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerPstuOpPattern final : public OpConversionPattern { +public: + explicit LowerPstuOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::PstuOp op, pto::PstuOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr calleeName = buildPstuCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported pstu signature"); + + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), resultTypes))) + return rewriter.notifyMatchFailure(op, "failed to convert pstu result types"); + if (resultTypes.size() != 2) + return rewriter.notifyMatchFailure(op, "unexpected converted pstu result arity"); + + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!baseType || adaptor.getAlignIn().getType() != resultTypes[0] || + adaptor.getBase().getType() != resultTypes[1]) { + return rewriter.notifyMatchFailure(op, + "unexpected converted pstu operand/result types"); + } + + SmallVector args{adaptor.getValue(), adaptor.getBase(), adaptor.getAlignIn()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + adaptor.getAlignIn().getType()}, + resultTypes); + auto call = rewriter.create(op.getLoc(), *calleeName, resultTypes, + args); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstusOpPattern final : public OpConversionPattern { +public: + explicit LowerVstusOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstusOp op, pto::VstusOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getElementTypeFromVectorLike(op.getValue().getType()); + if (!elementType) + return rewriter.notifyMatchFailure(op, "unsupported vstus element type"); + + auto offsetBytes = convertElementOffsetToBytes(op, adaptor.getOffset(), elementType); + if (failed(offsetBytes)) + return rewriter.notifyMatchFailure(op, "failed to convert vstus offset"); + + Type resultType = this->getTypeConverter()->convertType(op.getAlignOut().getType()); + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!resultType || !baseType || adaptor.getAlignIn().getType() != resultType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstus operand/result types"); + } + + StringRef calleeName = buildVstusCallee(op.getContext()); + SmallVector args{adaptor.getValue(), adaptor.getBase(), *offsetBytes, + adaptor.getAlignIn()}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + (*offsetBytes).getType(), adaptor.getAlignIn().getType()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVsturOpPattern final : public OpConversionPattern { +public: + explicit LowerVsturOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VsturOp op, pto::VsturOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto postMode = parsePostModeImmediate(op.getMode()); + if (!postMode) + return rewriter.notifyMatchFailure(op, "unsupported vstur mode immediate"); + + Type resultType = this->getTypeConverter()->convertType(op.getAlignOut().getType()); + auto baseType = dyn_cast(adaptor.getBase().getType()); + if (!resultType || !baseType || adaptor.getAlignIn().getType() != resultType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstur operand/result types"); + } + + StringRef calleeName = buildVsturCallee(op.getContext()); + Value modeValue = getI32Constant(rewriter, op.getLoc(), *postMode); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getBase(), adaptor.getAlignIn(), + modeValue, zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getBase().getType(), + adaptor.getAlignIn().getType(), modeValue.getType(), + zeroValue.getType()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, TypeRange{resultType}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstarOpPattern final : public OpConversionPattern { +public: + explicit LowerVstarOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstarOp op, pto::VstarOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto baseType = dyn_cast(adaptor.getDestination().getType()); + Type alignType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!baseType || !alignType || adaptor.getValue().getType() != alignType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstar operand types"); + } + + StringRef calleeName = buildVstarCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVstasOpPattern final : public OpConversionPattern { +public: + explicit LowerVstasOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VstasOp op, pto::VstasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto baseType = dyn_cast(adaptor.getDestination().getType()); + Type alignType = this->getTypeConverter()->convertType(op.getValue().getType()); + auto dstType = dyn_cast(op.getDestination().getType()); + if (!baseType || !alignType || adaptor.getValue().getType() != alignType || !dstType) { + return rewriter.notifyMatchFailure(op, + "unexpected converted vstas operand types"); + } + + auto offsetBytes = + convertElementOffsetToBytes(op, adaptor.getOffset(), dstType.getElementType()); + if (failed(offsetBytes)) + return rewriter.notifyMatchFailure(op, "failed to convert vstas offset"); + + StringRef calleeName = buildVstasCallee(op.getContext()); + Value zeroValue = getI32Constant(rewriter, op.getLoc(), 0); + SmallVector args{adaptor.getValue(), adaptor.getDestination(), *offsetBytes, + zeroValue}; + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + (*offsetBytes).getType(), zeroValue.getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgather2OpPattern final + : public OpConversionPattern { +public: + explicit LowerVgather2OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vgather2Op op, pto::Vgather2Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + auto basePtr = dyn_cast(adaptor.getSource().getType()); + if (!elemType || !basePtr) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgather2 operand types"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vgather2 result type"); + + FailureOr calleeName = + buildVgather2Callee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2 signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgather2BcOpPattern final + : public OpConversionPattern { +public: + explicit LowerVgather2BcOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vgather2BcOp op, pto::Vgather2BcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!basePtr || !resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgather2_bc operand/result types"); + + FailureOr calleeName = + buildVgather2BcCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2_bc signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVgatherbOpPattern final + : public OpConversionPattern { +public: + explicit LowerVgatherbOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VgatherbOp op, pto::VgatherbOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto basePtr = dyn_cast(adaptor.getSource().getType()); + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!basePtr || !resultType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vgatherb operand/result types"); + + FailureOr calleeName = + buildVgatherbCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vgatherb signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVscatterOpPattern final + : public OpConversionPattern { +public: + explicit LowerVscatterOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VscatterOp op, pto::VscatterOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elemType = getElementTypeFromVectorLike(op.getValue().getType()); + auto basePtr = + dyn_cast(adaptor.getDestination().getType()); + if (!elemType || !basePtr) + return rewriter.notifyMatchFailure(op, + "unexpected converted vscatter operand types"); + + FailureOr calleeName = + buildVscatterCallee(op.getContext(), op.getValue().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vscatter signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getValue().getType(), adaptor.getDestination().getType(), + adaptor.getOffsets().getType(), adaptor.getMask().getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{adaptor.getValue(), adaptor.getDestination(), + adaptor.getOffsets(), adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVaxpyOpPattern final : public OpConversionPattern { +public: + explicit LowerVaxpyOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VaxpyOp op, pto::VaxpyOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type elemType = getElementTypeFromVectorLike(op.getResult().getType()); + if (!elemType) + return rewriter.notifyMatchFailure(op, "unsupported vaxpy signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vaxpy result type"); + + FailureOr calleeName = + buildVaxpyCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vaxpy callee"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getSrc1().getType(), adaptor.getSrc0().getType(), + adaptor.getAlpha().getType(), adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getSrc1(), adaptor.getSrc0(), adaptor.getAlpha(), + adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVciOpPattern final : public OpConversionPattern { +public: + explicit LowerVciOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VciOp op, pto::VciOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto order = parseOrderImmediate(op.getOrder().value_or("ASC")); + if (!order) + return rewriter.notifyMatchFailure(op, "unsupported vci order"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vci result type"); + + FailureOr calleeName = + buildVciCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vci callee"); + + Value indexValue = adaptor.getIndex(); + Type resultElemType = + cast(op.getResult().getType()).getElementType(); + if (auto intType = dyn_cast(resultElemType)) { + if (intType.getWidth() == 8) { + Type loweredIndexType = rewriter.getI16Type(); + if (intType.isUnsigned()) + indexValue = rewriter.create(op.getLoc(), + loweredIndexType, + indexValue); + else + indexValue = rewriter.create(op.getLoc(), + loweredIndexType, + indexValue); + } + } + + Value orderValue = getI32Constant(rewriter, op.getLoc(), *order); + auto funcType = rewriter.getFunctionType( + TypeRange{indexValue.getType(), orderValue.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{indexValue, orderValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVexpdifOpPattern final + : public OpConversionPattern { +public: + explicit LowerVexpdifOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VexpdifOp op, pto::VexpdifOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto part = parsePartImmediate(op.getPart()); + if (!part) + return rewriter.notifyMatchFailure(op, "unsupported vexpdif signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vexpdif result type"); + + FailureOr calleeName = + buildVexpdifCallee(op.getContext(), op.getInput().getType(), + op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vexpdif callee"); + + Value partValue = getI32Constant(rewriter, op.getLoc(), *part); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getInput().getType(), adaptor.getMax().getType(), + adaptor.getMask().getType(), partValue.getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getInput(), adaptor.getMax(), adaptor.getMask(), + partValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVbitsortOpPattern final + : public OpConversionPattern { +public: + explicit LowerVbitsortOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::VbitsortOp op, pto::VbitsortOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = + dyn_cast(adaptor.getDestination().getType()); + auto srcType = dyn_cast(adaptor.getSource().getType()); + auto idxType = + dyn_cast(adaptor.getIndices().getType()); + if (!dstType || !srcType || !idxType) + return rewriter.notifyMatchFailure(op, + "unexpected converted vbitsort operand types"); + + FailureOr config = packVbitsortConfig(op, adaptor.getRepeatTimes()); + if (failed(config)) + return rewriter.notifyMatchFailure(op, "failed to pack vbitsort config"); + + FailureOr calleeName = buildVbitsortCallee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vbitsort signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getDestination().getType(), adaptor.getSource().getType(), + adaptor.getIndices().getType(), (*config).getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{adaptor.getDestination(), adaptor.getSource(), + adaptor.getIndices(), *config}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVmrgsort4OpPattern final + : public OpConversionPattern { +public: + explicit LowerVmrgsort4OpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::Vmrgsort4Op op, pto::Vmrgsort4Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto dstType = + dyn_cast(adaptor.getDestination().getType()); + auto src0Type = + dyn_cast(adaptor.getSource0().getType()); + auto src1Type = + dyn_cast(adaptor.getSource1().getType()); + auto src2Type = + dyn_cast(adaptor.getSource2().getType()); + auto src3Type = + dyn_cast(adaptor.getSource3().getType()); + if (!dstType || !src0Type || !src1Type || !src2Type || !src3Type) + return rewriter.notifyMatchFailure( + op, "unexpected converted vmrgsort4 operand types"); + + Type elemType = + cast(op.getDestination().getType()).getElementType(); + FailureOr packedSrc = packVmrgsort4SourceAddr( + op, adaptor.getSource0(), adaptor.getSource1(), adaptor.getSource2(), + adaptor.getSource3(), elemType); + if (failed(packedSrc)) + return rewriter.notifyMatchFailure( + op, "failed to pack vmrgsort4 source addresses"); + + FailureOr dst = reinterpretPointerToAddrSpace(op, adaptor.getDestination(), 6); + if (failed(dst)) + return rewriter.notifyMatchFailure(op, "failed to normalize vmrgsort4 destination"); + + FailureOr calleeName = buildVmrgsort4Callee(op.getContext(), op); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vmrgsort4 signature"); + + auto funcType = rewriter.getFunctionType( + TypeRange{(*dst).getType(), (*packedSrc).getType(), + adaptor.getCount().getType(), adaptor.getConfig().getType()}, + TypeRange{}); + rewriter.create( + op.getLoc(), *calleeName, TypeRange{}, + ValueRange{*dst, *packedSrc, adaptor.getCount(), adaptor.getConfig()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVcvtOpPattern final : public OpConversionPattern { +public: + explicit LowerVcvtOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VcvtOp op, pto::VcvtOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr contract = buildVcvtContract(op); + if (failed(contract)) + return rewriter.notifyMatchFailure(op, "unsupported vcvt type pair"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vcvt result type"); + + SmallVector callArgs; + SmallVector argTypes; + callArgs.push_back(adaptor.getInput()); + argTypes.push_back(adaptor.getInput().getType()); + callArgs.push_back(adaptor.getMask()); + argTypes.push_back(adaptor.getMask().getType()); + + auto appendRndArg = [&]() -> LogicalResult { + auto roundMode = + op.getRndAttr() ? parseRoundModeImmediate(*op.getRnd()) : std::nullopt; + if (!roundMode) + return rewriter.notifyMatchFailure(op, "vcvt requires valid rnd attr"); + Value roundValue = getI32Constant(rewriter, op.getLoc(), *roundMode); + callArgs.push_back(roundValue); + argTypes.push_back(roundValue.getType()); + return success(); + }; + + auto appendSatArg = [&]() -> LogicalResult { + auto saturation = + op.getSatAttr() ? parseSaturationImmediate(*op.getSat()) : std::nullopt; + if (!saturation) + return rewriter.notifyMatchFailure(op, "vcvt requires valid sat attr"); + Value satValue = getI32Constant(rewriter, op.getLoc(), *saturation); + callArgs.push_back(satValue); + argTypes.push_back(satValue.getType()); + return success(); + }; + + if ((*contract).satBeforeRnd) { + if ((*contract).requiresSat && failed(appendSatArg())) + return failure(); + if ((*contract).requiresRnd && failed(appendRndArg())) + return failure(); + } else { + if ((*contract).requiresRnd && failed(appendRndArg())) + return failure(); + if ((*contract).requiresSat && failed(appendSatArg())) + return failure(); + } + + if ((*contract).requiresPart) { + auto part = + op.getPartAttr() ? parseVcvtPartImmediate(*op.getPart()) : std::nullopt; + if (!part) + return rewriter.notifyMatchFailure(op, "vcvt requires valid part attr"); + Value partValue = getI32Constant(rewriter, op.getLoc(), *part); + callArgs.push_back(partValue); + argTypes.push_back(partValue.getType()); + } + + auto funcType = rewriter.getFunctionType(argTypes, TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), StringRef((*contract).intrinsic), TypeRange{resultType}, callArgs); + state.plannedDecls.push_back( + PlannedDecl{std::string((*contract).intrinsic), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerVbitcastOpPattern final + : public OpConversionPattern { +public: + explicit LowerVbitcastOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult + matchAndRewrite(pto::VbitcastOp op, pto::VbitcastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert vbitcast result type"); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getInput()); + return success(); + } +}; + +class LowerPbitcastOpPattern final + : public OpConversionPattern { +public: + explicit LowerPbitcastOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult + matchAndRewrite(pto::PbitcastOp op, pto::PbitcastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert pbitcast result type"); + if (adaptor.getInput().getType() != resultType) { + return rewriter.notifyMatchFailure( + op, "pbitcast expects identical lowered input/result types"); + } + rewriter.replaceOp(op, adaptor.getInput()); + return success(); + } +}; + +class LowerVtrcOpPattern final : public OpConversionPattern { +public: + explicit LowerVtrcOpPattern(TypeConverter &typeConverter, MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(pto::VtrcOp op, pto::VtrcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto roundMode = parseRoundModeImmediate(op.getRoundMode()); + if (!roundMode) + return rewriter.notifyMatchFailure(op, "unsupported vtrc signature"); + + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert vtrc result type"); + + FailureOr calleeName = + buildVtrcCallee(op.getContext(), op.getResult().getType()); + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported vtrc callee"); + + Value roundValue = getI32Constant(rewriter, op.getLoc(), *roundMode); + auto funcType = rewriter.getFunctionType( + TypeRange{adaptor.getInput().getType(), roundValue.getType(), + adaptor.getMask().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), *calleeName, TypeRange{resultType}, + ValueRange{adaptor.getInput(), roundValue, adaptor.getMask()}); + state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateStoreOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicateStoreOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(StoreOp op, typename StoreOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmDestType = + dyn_cast(adaptor.getDestination().getType()); + Type valueType = this->getTypeConverter()->convertType(op.getValue().getType()); + if (!llvmDestType || !valueType) + return rewriter.notifyMatchFailure( + op, "expected converted predicate-store operand types"); + + auto dist = parsePredicateStoreDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-store dist immediate"); + + Value offset = castIntegerLikeTo(op, adaptor.getOffset(), rewriter.getI32Type()); + if (!offset) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-store offset to i32"); + + StringRef calleeName = getPredicateStoreCallee(op.getContext()); + SmallVector args; + args.push_back(adaptor.getValue()); + args.push_back(adaptor.getDestination()); + args.push_back(offset); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist))); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(0))); + auto funcType = rewriter.getFunctionType( + TypeRange{valueType, llvmDestType, rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI32Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPredicateLoadOpPattern final : public OpConversionPattern { +public: + explicit LowerPredicateLoadOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(LoadOp op, typename LoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmSourceType = + dyn_cast(adaptor.getSource().getType()); + Type resultType = + this->getTypeConverter()->convertType(op.getResult().getType()); + if (!llvmSourceType || !resultType) + return rewriter.notifyMatchFailure( + op, "expected converted predicate-load operand/result types"); + + auto dist = parsePredicateLoadDistImmediate(op.getDist()); + if (!dist) + return rewriter.notifyMatchFailure( + op, "unsupported predicate-load dist immediate"); + + Value offset = castIntegerLikeTo(op, adaptor.getOffset(), rewriter.getI32Type()); + if (!offset) + return rewriter.notifyMatchFailure( + op, "failed to convert predicate-load offset to i32"); + + StringRef calleeName = getPredicateLoadCallee(op.getContext()); + SmallVector args; + args.push_back(adaptor.getSource()); + args.push_back(offset); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(*dist))); + args.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(0))); + auto funcType = rewriter.getFunctionType( + TypeRange{llvmSourceType, rewriter.getI32Type(), rewriter.getI32Type(), + rewriter.getI32Type()}, + TypeRange{resultType}); + auto call = + rewriter.create(op.getLoc(), calleeName, resultType, args); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerSetLoopConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerSetLoopConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(LoopOp op, typename LoopOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr packed = failure(); + if constexpr (std::is_same_v || + std::is_same_v) { + packed = packLoopSize(op, adaptor.getFirst(), adaptor.getSecond()); + } else { + packed = packLoopPair(op, adaptor.getFirst(), adaptor.getSecond()); + } + if (failed(packed)) + return rewriter.notifyMatchFailure(op, + "failed to pack loop configuration"); + + StringRef calleeName = buildSetLoopCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*packed}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerUnaryConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ConfigOp op, typename ConfigOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr encoded = + encodeMovPadValue(op.getLoc(), adaptor.getValue(), rewriter); + if (failed(encoded)) + return rewriter.notifyMatchFailure( + op, "expected 8/16/32-bit integer or float mov-pad payload"); + + StringRef calleeName = buildUnaryConfigCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{*encoded}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerUnaryI64ConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerUnaryI64ConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ConfigOp op, typename ConfigOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StringRef calleeName = buildUnaryConfigCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{adaptor.getValue().getType()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{adaptor.getValue()}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerStoreVfSimtInfoOpPattern final + : public OpConversionPattern { +public: + explicit LowerStoreVfSimtInfoOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::StoreVfSimtInfoOp op, + pto::StoreVfSimtInfoOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value dimZ = adaptor.getDimZ(); + Value dimY = adaptor.getDimY(); + Value dimX = adaptor.getDimX(); + if (!dimZ || !dimY || !dimX) + return rewriter.notifyMatchFailure(op, "missing converted SIMT dims"); + + auto i64Type = rewriter.getI64Type(); + auto castToI64 = [&](Value value) -> Value { + if (value.getType().isInteger(64)) + return value; + return rewriter.create(loc, i64Type, value).getResult(); + }; + + Value dimZI64 = castToI64(dimZ); + Value dimYI64 = castToI64(dimY); + Value dimXI64 = castToI64(dimX); + Value dimYShift = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(16)); + Value dimZShift = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(32)); + Value packedDimY = + rewriter.create(loc, dimYI64, dimYShift).getResult(); + Value packedDimZ = + rewriter.create(loc, dimZI64, dimZShift).getResult(); + Value payload = + rewriter.create(loc, dimXI64, packedDimY).getResult(); + payload = + rewriter.create(loc, payload, packedDimZ).getResult(); + + StringRef calleeName = buildStoreVfSimtInfoCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{i64Type}, TypeRange{}); + rewriter.create(loc, calleeName, TypeRange{}, + ValueRange{payload}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerNullaryConfigOpPattern final : public OpConversionPattern { +public: + explicit LowerNullaryConfigOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(ConfigOp op, typename ConfigOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + StringRef calleeName = buildNullaryConfigCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPipeEventSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerPipeEventSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(SyncOp op, typename SyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto src = parsePipeImmediate(stringifyPIPE(op.getSrcPipe().getPipe())); + auto dst = parsePipeImmediate(stringifyPIPE(op.getDstPipe().getPipe())); + auto event = parseEventImmediate(stringifyEVENT(op.getEventId().getEvent())); + if (!src || !dst || !event) + return rewriter.notifyMatchFailure(op, "unsupported sync immediate"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value srcValue = getI64Constant(rewriter, op.getLoc(), *src); + Value dstValue = getI64Constant(rewriter, op.getLoc(), *dst); + Value eventValue = getI64Constant(rewriter, op.getLoc(), *event); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{srcValue, dstValue, eventValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerPipeEventDynSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerPipeEventDynSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(SyncOp op, typename SyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto src = parsePipeImmediate(stringifyPIPE(op.getSrcPipe().getPipe())); + auto dst = parsePipeImmediate(stringifyPIPE(op.getDstPipe().getPipe())); + if (!src || !dst) + return rewriter.notifyMatchFailure(op, "unsupported sync pipe"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value srcValue = getI64Constant(rewriter, op.getLoc(), *src); + Value dstValue = getI64Constant(rewriter, op.getLoc(), *dst); + + Value eventIdValue = adaptor.getEventId(); + if (!eventIdValue) + return rewriter.notifyMatchFailure(op, "missing event_id operand"); + + Value eventValue = eventIdValue; + + while (eventValue.getDefiningOp()) { + auto unrealizedCast = dyn_cast(eventValue.getDefiningOp()); + if (!unrealizedCast || unrealizedCast.getInputs().size() != 1) + break; + eventValue = unrealizedCast.getInputs()[0]; + } + + if (eventValue.getType().isIndex()) { + eventValue = rewriter.create(op.getLoc(), + rewriter.getI64Type(), + eventValue); + } else if (auto intType = dyn_cast(eventValue.getType())) { + if (intType.getWidth() < 64) { + eventValue = rewriter.create(op.getLoc(), + rewriter.getI64Type(), + eventValue); + } + } else { + return rewriter.notifyMatchFailure(op, "unexpected event_id type"); + } + + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{srcValue, dstValue, eventValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerInterCoreSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerInterCoreSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(SyncOp op, typename SyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto pipe = parsePipeImmediate(stringifyPIPE(op.getPipe().getPipe())); + if (!pipe) + return rewriter.notifyMatchFailure(op, "unsupported inter-core sync pipe"); + + Value pipeValue = getI64Constant(rewriter, op.getLoc(), *pipe); + Value eventValue; + if (IntegerAttr eventIdAttr = op.getEventIdAttr()) { + eventValue = getI64Constant(rewriter, op.getLoc(), eventIdAttr.getInt()); + } else { + Value eventIdDyn = adaptor.getEventIdDyn(); + if (!eventIdDyn) + return rewriter.notifyMatchFailure( + op, "expected static or dynamic event-id operand"); + + eventValue = castIntegerLikeTo(op, eventIdDyn, rewriter.getI64Type()); + if (!eventValue) { + return rewriter.notifyMatchFailure( + op, "failed to cast dynamic event-id to i64"); + } + } + + StringRef calleeName = buildSyncCallee(op.getContext()); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{pipeValue, eventValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerBarrierOpPattern final : public OpConversionPattern { +public: + explicit LowerBarrierOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + auto pipe = parsePipeImmediate(stringifyPIPE(op.getPipe().getPipe())); + if (!pipe) + return rewriter.notifyMatchFailure(op, "unsupported barrier pipe"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value pipeValue = getI64Constant(rewriter, op.getLoc(), *pipe); + auto funcType = + rewriter.getFunctionType(TypeRange{rewriter.getI64Type()}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{pipeValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerMemBarOpPattern final : public OpConversionPattern { +public: + explicit LowerMemBarOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::MemBarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + StringRef calleeName = buildMemBarCallee(op.getKind().getKind(), op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBufSyncOpPattern final : public OpConversionPattern { +public: + explicit LowerBufSyncOpPattern(TypeConverter &typeConverter, + MLIRContext *context, LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BufSyncOp op, typename BufSyncOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + PIPE pipe = PIPE::PIPE_UNASSIGNED; + if (auto pipeAttr = dyn_cast(op.getOpTypeAttr())) { + pipe = pipeAttr.getPipe(); + } else { + auto opTypeOr = parseSyncOpTypeLikeAttr(op.getOpTypeAttr()); + if (failed(opTypeOr)) + return rewriter.notifyMatchFailure( + op, "buffer sync expects pipe/sync_op_type/pipe_event_type attr"); + pipe = mapSyncOpTypeToPipe(*opTypeOr); + } + if (!isConcreteSyncPipe(pipe)) + return rewriter.notifyMatchFailure(op, + "buffer sync op_type cannot map to concrete pipe"); + + auto pipeImm = parsePipeImmediate(stringifyPIPE(pipe)); + if (!pipeImm) + return rewriter.notifyMatchFailure(op, "unsupported buffer sync pipe"); + + StringRef calleeName = buildSyncCallee(op.getContext()); + Value pipeValue = getI64Constant(rewriter, op.getLoc(), *pipeImm); + Value bufIdValue = + getI64Constant(rewriter, op.getLoc(), op.getBufIdAttr().getInt()); + Value modeValue = + getI64Constant(rewriter, op.getLoc(), op.getModeAttr().getInt()); + auto funcType = rewriter.getFunctionType( + TypeRange{rewriter.getI64Type(), rewriter.getI64Type(), + rewriter.getI64Type()}, + TypeRange{}); + rewriter.create(op.getLoc(), calleeName, TypeRange{}, + ValueRange{pipeValue, bufIdValue, modeValue}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.eraseOp(op); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerRuntimeQueryOpPattern final : public OpConversionPattern { +public: + explicit LowerRuntimeQueryOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(QueryOp op, typename QueryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, + "failed to convert runtime-query result type"); + + StringRef calleeName = buildRuntimeQueryCallee(op.getContext()); + auto funcType = rewriter.getFunctionType(TypeRange{}, TypeRange{resultType}); + auto call = rewriter.create(op.getLoc(), calleeName, + TypeRange{resultType}, ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class LowerGetVms4SrOpPattern final + : public OpConversionPattern { +public: + explicit LowerGetVms4SrOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), + state(state) {} + + LogicalResult + matchAndRewrite(pto::GetVms4SrOp op, pto::GetVms4SrOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + (void)adaptor; + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes)) || + resultTypes.size() != 4) + return rewriter.notifyMatchFailure( + op, "failed to convert get_vms4_sr result types"); + + StringRef calleeName = buildRuntimeQueryCallee( + op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{}, TypeRange{rewriter.getI64Type()}); + auto call = rewriter.create( + op.getLoc(), calleeName, TypeRange{rewriter.getI64Type()}, + ValueRange{}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + + SmallVector counts; + counts.reserve(4); + Value raw = call.getResult(0); + for (unsigned i = 0; i < 4; ++i) { + Value shifted = raw; + if (i != 0) + shifted = rewriter.create( + op.getLoc(), raw, getI64Constant(rewriter, op.getLoc(), i * 16)); + counts.push_back(rewriter.create( + op.getLoc(), resultTypes[i], shifted)); + } + rewriter.replaceOp(op, counts); + return success(); + } + +private: + LoweringState &state; +}; + +template +class LowerBinaryI64PureOpPattern final : public OpConversionPattern { +public: + explicit LowerBinaryI64PureOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + LoweringState &state) + : OpConversionPattern(typeConverter, context), state(state) {} + + LogicalResult + matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + StringRef calleeName = buildBinaryI64PureCallee(op.getContext()); + auto funcType = + rewriter.getFunctionType(TypeRange{adaptor.getFirst().getType(), + adaptor.getSecond().getType()}, + TypeRange{resultType}); + auto call = rewriter.create( + op.getLoc(), calleeName, TypeRange{resultType}, + ValueRange{adaptor.getFirst(), adaptor.getSecond()}); + state.plannedDecls.push_back(PlannedDecl{calleeName.str(), funcType}); + rewriter.replaceOp(op, call.getResults()); + return success(); + } + +private: + LoweringState &state; +}; + +class ConvertVPTOUnrealizedCastOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() != convertedResultType) + return failure(); + + rewriter.replaceOp(op, input); + return success(); + } +}; + +class ConvertPtoAddPtrOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getResult().getType()); + auto llvmPtrType = dyn_cast(convertedResultType); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer result type"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + auto gep = rewriter.create( + op.getLoc(), llvmPtrType, cast(op.getPtr().getType()).getElementType(), + adaptor.getPtr(), ValueRange{offset}); + rewriter.replaceOp(op, gep.getResult()); + return success(); + } +}; + +class ConvertPtoCastPtrOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::CastPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, + "could not convert castptr result type"); + + Value input = adaptor.getInput(); + Type inputType = input.getType(); + if (inputType == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + + if (auto llvmPtrType = dyn_cast(convertedResultType)) { + if (isa(inputType)) { + rewriter.replaceOpWithNewOp(op, llvmPtrType, input); + return success(); + } + auto sourcePtrType = dyn_cast(inputType); + if (!sourcePtrType) + return rewriter.notifyMatchFailure(op, + "expected integer or LLVM pointer input"); + if (sourcePtrType.getAddressSpace() == llvmPtrType.getAddressSpace()) { + rewriter.replaceOpWithNewOp(op, llvmPtrType, input); + return success(); + } + return rewriter.notifyMatchFailure( + op, "cross-address-space ptr casts are unsupported"); + } + + if (auto resultIntType = dyn_cast(convertedResultType)) { + if (isa(inputType)) { + rewriter.replaceOpWithNewOp(op, resultIntType, input); + return success(); + } + } + + return rewriter.notifyMatchFailure(op, "unsupported castptr conversion"); + } +}; + +class ConvertPtoLoadScalarOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Type convertedValueType = + getTypeConverter()->convertType(op.getValue().getType()); + if (!convertedValueType) + return rewriter.notifyMatchFailure(op, + "could not convert load_scalar result type"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + convertedValueType, adaptor.getPtr(), + ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.replaceOpWithNewOp( + op, convertedValueType, elemPtr, + getNaturalAlignment(convertedValueType)); + return success(); + } +}; + +class ConvertPtoStoreScalarOp final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + adaptor.getValue().getType(), + adaptor.getPtr(), ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.create(op.getLoc(), adaptor.getValue(), elemPtr, + getNaturalAlignment(adaptor.getValue().getType())); + rewriter.eraseOp(op); + return success(); + } +}; + +class ConvertPtoLoadOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::PTOLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Type convertedValueType = + getTypeConverter()->convertType(op.getValue().getType()); + if (!convertedValueType) + return rewriter.notifyMatchFailure(op, "could not convert load result type"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + convertedValueType, adaptor.getPtr(), + ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.replaceOpWithNewOp( + op, convertedValueType, elemPtr, + getNaturalAlignment(convertedValueType)); + return success(); + } +}; + +class ConvertPtoStoreOp final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::PTOStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto llvmPtrType = dyn_cast(adaptor.getPtr().getType()); + if (!llvmPtrType) + return rewriter.notifyMatchFailure(op, "expected LLVM pointer operand"); + + Value offset = adaptor.getOffset(); + if (offset.getType().isIndex()) + offset = rewriter.create(op.getLoc(), + rewriter.getI64Type(), offset); + + Value elemPtr = adaptor.getPtr(); + if (!matchPattern(offset, m_Zero())) { + elemPtr = rewriter.create(op.getLoc(), llvmPtrType, + adaptor.getValue().getType(), + adaptor.getPtr(), ValueRange{offset}); + } + + auto getNaturalAlignment = [&](Type type) -> unsigned { + unsigned alignBytes = 0; + if (auto intType = dyn_cast(type)) + alignBytes = llvm::divideCeil(unsigned(intType.getWidth()), 8u); + else if (type.isF16() || type.isBF16()) + alignBytes = 2; + else if (type.isF32()) + alignBytes = 4; + else if (type.isF64()) + alignBytes = 8; + return alignBytes; + }; + + rewriter.replaceOpWithNewOp( + op, adaptor.getValue(), elemPtr, + getNaturalAlignment(adaptor.getValue().getType())); + return success(); + } +}; + +class ConvertVPTOTypedCarrierOp final : public ConversionPattern { +public: + ConvertVPTOTypedCarrierOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (isa(op)) + return failure(); + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure( + op, "region ops with VPTO types are handled structurally"); + + FailureOr converted = + convertOpResultTypes(op, operands, *typeConverter, rewriter); + if (failed(converted)) + return failure(); + return success(); + } +}; + +static void populateVPTOOpLoweringPatterns(VPTOTypeConverter &typeConverter, + RewritePatternSet &patterns, + LoweringState &state) { + patterns.add, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerUnaryMaskedOpPattern, + LowerVsqzOpPattern, LowerVusqzOpPattern, + LowerVmulaOpPattern, LowerVmullOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerCarryBinaryOpPattern, + LowerBinaryMaskedOpPattern, + LowerBinaryMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerVecScalarMaskedOpPattern, + LowerWideningReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerReductionUnaryOpPattern, + LowerVdupOpPattern, + LowerVbrOpPattern, + LowerPredicatePackOpPattern, + LowerPredicatePackOpPattern, + LowerVselOpPattern, LowerVselrOpPattern, LowerPnotOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicateMaskBinaryOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerPredicatePairReorderOpPattern, + LowerUnpackOpPattern, + LowerUnpackOpPattern, + LowerVpackOpPattern, + LowerInterleaveOpPattern, + LowerInterleaveOpPattern, + LowerCmpOpPattern, + LowerCmpOpPattern, + LowerPltOpPattern, + LowerPltOpPattern, + LowerPltOpPattern, + LowerPsetOpPattern, + LowerPsetOpPattern, + LowerPsetOpPattern, + LowerPgeOpPattern, + LowerPgeOpPattern, + LowerPgeOpPattern, + LowerRuntimeQueryOpPattern, + LowerGetVms4SrOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerBinaryI64PureOpPattern, + LowerBinaryI64PureOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerSetLoopConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerStoreVfSimtInfoOpPattern, + LowerUnaryConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerUnaryI64ConfigOpPattern, + LowerNullaryConfigOpPattern, + LowerNullaryConfigOpPattern, + LowerPipeEventSyncOpPattern, + LowerPipeEventSyncOpPattern, + LowerPipeEventDynSyncOpPattern, + LowerPipeEventDynSyncOpPattern, + LowerBarrierOpPattern, LowerMemBarOpPattern, + LowerBufSyncOpPattern, + LowerBufSyncOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerRuntimeQueryOpPattern, + LowerVldsOpPattern, LowerVldsPostOpPattern, + LowerVldsx2OpPattern, LowerVsldbOpPattern, + LowerVldasOpPattern, LowerInitAlignOpPattern, + LowerVldusOpPattern, LowerSprclrOpPattern, + LowerVstsOpPattern, LowerVsstbOpPattern, + LowerVstsPostOpPattern, LowerVstsx2OpPattern, + LowerVstarOpPattern, LowerVstasOpPattern, + LowerVgather2OpPattern, LowerVgather2BcOpPattern, + LowerVgatherbOpPattern, LowerVscatterOpPattern, + LowerVaxpyOpPattern, + LowerVciOpPattern, LowerVexpdifOpPattern, + LowerVbitsortOpPattern, LowerVmrgsort4OpPattern, + LowerVtrcOpPattern, LowerVcvtOpPattern, + LowerVbitcastOpPattern, LowerPbitcastOpPattern, + LowerPredicateLoadOpPattern, + LowerPredicateLoadOpPattern, + LowerPredicateStoreOpPattern, + LowerPredicateStoreOpPattern, + LowerPstuOpPattern, LowerVstusOpPattern, LowerVsturOpPattern, + LowerInterCoreSyncOpPattern, + LowerInterCoreSyncOpPattern, + LowerCopyGmToCbufOpPattern, LowerLoadCbufToCaOpPattern, + LowerLoadCbufToCbOpPattern, + LowerLoadCbufToS4OpPattern, + LowerLoadCbufToS4OpPattern, + LowerLoadCbufToCaMxOpPattern, + LowerLoadCbufToCbMxOpPattern, LowerCopyMatrixCcToGmOpPattern, + LowerCopyMatrixCcToBufOpPattern, + LowerCopyMatrixCcToBufOpPattern, + LowerCopyCbufToBtOpPattern, LowerCopyCbufToFbufOpPattern, + LowerCopyGmToCbufMultiOpPattern, + LowerCopyGmToCbufMultiOpPattern, + LowerMadRawPattern, + LowerMadRawPattern, + LowerMadRawPattern, + LowerMadRawPattern, + LowerCopyOpPattern, + LowerCopyOpPattern, + LowerCopyUbufToUbufOpPattern, + LowerCopyCbufToUbufOpPattern, + LowerCopyUbufToCbufOpPattern>( + typeConverter, patterns.getContext(), state); +} + +static void configureVPTOOpLoweringTarget(ConversionTarget &target, + VPTOTypeConverter &typeConverter) { + (void)typeConverter; + target.addLegalOp(); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); +} + +static void populateVPTOStructuralTypePatterns( + VPTOTypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target) { + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); +} + +static void foldVPTOTypeCasts(ModuleOp module, TypeConverter &typeConverter) { + SmallVector castsToFold; + module.walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->getNumOperands() != 1 || castOp->getNumResults() != 1) + return; + if (!hasVPTOConvertibleType(castOp->getOperandTypes()) && + !hasVPTOConvertibleType(castOp->getResultTypes())) + return; + Type convertedResultType = + typeConverter.convertType(castOp.getResult(0).getType()); + if (convertedResultType && + convertedResultType == castOp.getOperand(0).getType()) + castsToFold.push_back(castOp); + }); + for (UnrealizedConversionCastOp castOp : castsToFold) { + castOp.getResult(0).replaceAllUsesWith(castOp.getOperand(0)); + castOp.erase(); + } +} + +static LogicalResult lowerVPTOOps(ModuleOp module, llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + VPTOTypeConverter typeConverter(context); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + LoweringState state; + + configureVPTOOpLoweringTarget(target, typeConverter); + populateVPTOOpLoweringPatterns(typeConverter, patterns, state); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: VPTO op lowering failed\n"; + return failure(); + } + if (failed(materializeDecls(module, state.plannedDecls, diagOS))) + return failure(); + return success(); +} + +static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) { + MLIRContext *context = module.getContext(); + VPTOTypeConverter typeConverter(context); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter); + }); + target.addIllegalOp(); + target.addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + return !hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes()); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + populateVPTOStructuralTypePatterns(typeConverter, patterns, target); + patterns.add( + typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + diagOS << "VPTO LLVM emission failed: VPTO type lowering failed\n"; + return failure(); + } + foldVPTOTypeCasts(module, typeConverter); + return success(); +} + +static Type normalizeTypeForOfficialLLVMLowering(Type type, Builder &builder) { + type = convertVPTOType(type, builder); + return type; +} + +static void normalizeFuncSignaturesForOfficialLLVMLowering(ModuleOp module) { + Builder builder(module.getContext()); + + for (func::FuncOp funcOp : module.getOps()) { + FunctionType oldType = funcOp.getFunctionType(); + SmallVector newInputs; + SmallVector newResults; + bool changed = false; + + for (Type input : oldType.getInputs()) { + Type normalized = normalizeTypeForOfficialLLVMLowering(input, builder); + changed |= (normalized != input); + newInputs.push_back(normalized); + } + for (Type result : oldType.getResults()) { + Type normalized = normalizeTypeForOfficialLLVMLowering(result, builder); + changed |= (normalized != result); + newResults.push_back(normalized); + } + + if (!changed) + continue; + + auto newType = builder.getFunctionType(newInputs, newResults); + funcOp.setFunctionTypeAttr(TypeAttr::get(newType)); + + if (funcOp.isExternal()) + continue; + Block &entry = funcOp.getBody().front(); + for (auto [arg, newType] : llvm::zip(entry.getArguments(), newInputs)) + if (arg.getType() != newType) + arg.setType(newType); + } +} + +static void forceV300CtrlModeForVPTOFuncs(ModuleOp module) { + OpBuilder builder(module.getContext()); + + for (func::FuncOp funcOp : module.getOps()) { + if (!needsV300CtrlModeForVPTOFunc(funcOp)) + continue; + + Block &entry = funcOp.getBody().front(); + builder.setInsertionPointToStart(&entry); + auto i64Type = builder.getI64Type(); + auto bit60 = builder.create( + funcOp.getLoc(), i64Type, builder.getI64IntegerAttr(60)); + Value ctrl = + builder.create(funcOp.getLoc(), i64Type).getResult(); + Value ctrlV300 = builder + .create(funcOp.getLoc(), i64Type, + ctrl, bit60.getResult()) + .getResult(); + builder.create(funcOp.getLoc(), ctrlV300); + } +} + +static std::optional getKernelKind(ModuleOp module) { + auto kernelKind = module->getAttrOfType( + FunctionKernelKindAttr::name); + if (!kernelKind) + return std::nullopt; + return kernelKind.getKernelKind(); +} + +static VPTOEmissionOptions +makeDeviceEmissionOptions(const VPTOEmissionOptions &baseOptions, + FunctionKernelKind kind) { + VPTOEmissionOptions options = baseOptions; + constexpr llvm::StringLiteral kVecTargetFeatures = + "+ATOMIC,+ArchV130,+AregRedefinable,+ArithmeticBf16,+AtomicForB8 ," + "+F8e4m3,+F8e5m2,+F8e8m0,+FFTSBlk,+Fp4e1m2x2,+Fp4e2m1x2,+LDExtRefine," + "+MOVX8,+SPR7bits,+SyncV,+dav-c310-vec"; + constexpr llvm::StringLiteral kCubeTargetFeatures = + "+ATOMIC,+ArchV130,+AregRedefinable,+ArithmeticBf16,+AtomicForB8 ," + "+F8e4m3,+F8e5m2,+F8e8m0,+FFTSBlk,+Fp4e1m2x2,+Fp4e2m1x2,+LDExtRefine," + "+MOVX8,+SPR7bits,+SyncV,+dav-c310-cube"; + if (kind == FunctionKernelKind::Vector) { + options.march = "dav-c310-vec"; + options.aicoreArch = "dav-c310-vec"; + options.defaultTargetCPU = "dav-c310-vec"; + options.defaultTargetFeatures = kVecTargetFeatures.str(); + } else if (kind == FunctionKernelKind::Cube) { + options.march = "dav-c310-cube"; + options.aicoreArch = "dav-c310-cube"; + options.defaultTargetCPU = "dav-c310-cube"; + options.defaultTargetFeatures = kCubeTargetFeatures.str(); + } + return options; +} + +static FailureOr +getUniqueDeviceModuleByKernelKind(ModuleOp module, FunctionKernelKind kind, + llvm::raw_ostream &diagOS) { + ModuleOp matched; + for (ModuleOp child : module.getOps()) { + auto kernelKind = getKernelKind(child); + if (!kernelKind) + continue; + if (*kernelKind != kind) + continue; + if (matched) { + diagOS << "VPTO LLVM emission failed: duplicate device module with " + << FunctionKernelKindAttr::name << "\n"; + return failure(); + } + matched = child; + } + return matched; +} + +static LogicalResult renameKernelFunctionsForKernelKind(ModuleOp module, + llvm::raw_ostream &diagOS) { + auto kernelKind = getKernelKind(module); + if (!kernelKind) { + diagOS << "VPTO LLVM emission failed: device module missing " + << FunctionKernelKindAttr::name << "\n"; + return failure(); + } + + StringRef suffix; + if (*kernelKind == FunctionKernelKind::Vector) + suffix = kVectorSuffix; + else if (*kernelKind == FunctionKernelKind::Cube) + suffix = kCubeSuffix; + else { + diagOS << "VPTO LLVM emission failed: unsupported " + << FunctionKernelKindAttr::name << "\n"; + return failure(); + } + + for (func::FuncOp funcOp : module.getOps()) { + if (!hasVPTOKernelAttr(funcOp)) + continue; + if (funcOp.getSymName().ends_with(suffix)) + continue; + funcOp.setSymName((funcOp.getSymName() + suffix).str()); + } + return success(); +} + +struct LowerVPTOOpsPass final + : public PassWrapper> { + void runOnOperation() override { + materializeVecScopeCarrierLoops(getOperation()); + if (failed(lowerVPTOOps(getOperation(), llvm::errs()))) + signalPassFailure(); + } +}; + +struct LowerVPTOTypesPass final + : public PassWrapper> { + void runOnOperation() override { + if (failed(lowerVPTOTypes(getOperation(), llvm::errs()))) + signalPassFailure(); + } +}; + +struct NormalizeFuncSignaturesForLLVMLoweringPass final + : public PassWrapper> { + void runOnOperation() override { + normalizeFuncSignaturesForOfficialLLVMLowering(getOperation()); + } +}; + +struct PrepareVPTOLLVMLoweringPass final + : public PassWrapper> { + void runOnOperation() override { + ModuleOp module = getOperation(); + pto::annotatePTOEntryFunctions(module); + forceV300CtrlModeForVPTOFuncs(module); + if (failed(renameKernelFunctionsForKernelKind(module, llvm::errs()))) + signalPassFailure(); + } +}; + +llvm::StringSet +collectSimtEntryFunctionNames(ModuleOp module) { + llvm::StringSet simtEntries; + module.walk([&](func::FuncOp funcOp) { + if (funcOp->hasAttr(pto::kPTOSimtEntryAttrName)) + simtEntries.insert(funcOp.getSymName()); + }); + return simtEntries; +} + +static void applySimtEntryCallingConvention( + llvm::Module &llvmModule, + const llvm::StringSet &simtEntryNames) { + constexpr unsigned kSimtEntryCallingConv = 109; + + for (llvm::Function &function : llvmModule) { + if (simtEntryNames.contains(function.getName())) { + function.setCallingConv(kSimtEntryCallingConv); + function.addFnAttr(llvm::Attribute::NoInline); + } + } + + for (llvm::Function &function : llvmModule) { + for (llvm::BasicBlock &block : function) { + for (llvm::Instruction &inst : block) { + auto *call = llvm::dyn_cast(&inst); + if (!call) + continue; + auto *callee = call->getCalledFunction(); + if (!callee || !simtEntryNames.contains(callee->getName())) + continue; + call->setCallingConv(kSimtEntryCallingConv); + } + } + } +} + +static FailureOr +emitDeviceLLVMModule(ModuleOp deviceModule, StringRef kernelKind, + const VPTOEmissionOptions &options, + const llvm::StringSet &simtEntryNames, + llvm::raw_ostream &diagOS) { + if (!deviceModule) + return EmittedLLVMModule{}; + if (failed(applyQueriedTargetAttrs(deviceModule, options, diagOS))) + return failure(); + + auto llvmContext = std::make_unique(); + registerBuiltinDialectTranslation(*deviceModule.getContext()); + registerLLVMDialectTranslation(*deviceModule.getContext()); + std::unique_ptr llvmModule = + translateModuleToLLVMIR(deviceModule.getOperation(), *llvmContext); + if (!llvmModule) { + diagOS << "VPTO LLVM emission failed: LLVM IR export failed for " + << kernelKind << " module\n"; + return failure(); + } + + applySimtEntryCallingConvention(*llvmModule, simtEntryNames); + if (failed(attachAIVectorScopeMetadata(*llvmModule, diagOS))) + return failure(); + attachHIVMKernelAnnotations(*llvmModule); + llvmModule->setModuleIdentifier(("ptoas.hivm.official." + kernelKind).str()); + llvmModule->setSourceFileName(("ptoas.hivm.official." + kernelKind).str()); + return EmittedLLVMModule{std::move(llvmContext), std::move(llvmModule)}; +} + +template +static LogicalResult runPipeline(ModuleOp module, llvm::raw_ostream &diagOS, + const llvm::StringSet &simtEntryNames, + EmitFn &&emit) { + OwningOpRef clonedOp(module->clone()); + ModuleOp clonedModule = cast(*clonedOp); + + if (failed(validateVPTOAuthoringIR(clonedModule, &diagOS))) { + diagOS << "VPTO LLVM emission failed: authoring-stage VPTO legality " + "validation failed\n"; + return failure(); + } + + PassManager pm(clonedModule.getContext()); + pm.enableVerifier(); + auto &kernelModulePM = pm.nest(); + kernelModulePM.addPass(std::make_unique()); + kernelModulePM.addPass(std::make_unique()); + kernelModulePM.addPass(std::make_unique()); + kernelModulePM.addPass( + std::make_unique()); + kernelModulePM.addPass(createConvertSCFToCFPass()); + kernelModulePM.addPass(createArithToLLVMConversionPass()); + kernelModulePM.addPass(createConvertIndexToLLVMPass()); + kernelModulePM.addPass(createFinalizeMemRefToLLVMConversionPass()); + kernelModulePM.addPass(createConvertFuncToLLVMPass()); + kernelModulePM.addPass(createConvertControlFlowToLLVMPass()); + kernelModulePM.addPass(createReconcileUnrealizedCastsPass()); + if (failed(mlir::applyPassManagerCLOptions(pm))) { + diagOS << "VPTO LLVM emission failed: unable to apply MLIR pass manager " + "command-line options\n"; + return failure(); + } + if (failed(pm.run(clonedModule))) { + diagOS << "VPTO LLVM emission failed: official lowering pipeline failed\n"; + return failure(); + } + return emit(clonedModule); +} + +} // namespace + +LogicalResult lowerVPTOModuleToLLVMModules( + ModuleOp module, const VPTOEmissionOptions &options, + EmittedLLVMModule &cubeModule, EmittedLLVMModule &vectorModule, + llvm::raw_ostream &diagOS) { + llvm::StringSet simtEntryNames = + collectSimtEntryFunctionNames(module); + cubeModule.context.reset(); + cubeModule.module.reset(); + vectorModule.context.reset(); + vectorModule.module.reset(); + return runPipeline(module, diagOS, simtEntryNames, + [&](ModuleOp loweredModule) { + auto vectorDeviceModule = + getUniqueDeviceModuleByKernelKind( + loweredModule, FunctionKernelKind::Vector, diagOS); + if (failed(vectorDeviceModule)) + return failure(); + auto cubeDeviceModule = + getUniqueDeviceModuleByKernelKind( + loweredModule, FunctionKernelKind::Cube, diagOS); + if (failed(cubeDeviceModule)) + return failure(); + + if (*vectorDeviceModule) { + auto vectorOptions = + makeDeviceEmissionOptions(options, FunctionKernelKind::Vector); + auto emitted = + emitDeviceLLVMModule(*vectorDeviceModule, "vector", vectorOptions, + simtEntryNames, diagOS); + if (failed(emitted)) + return failure(); + vectorModule.context = std::move(emitted->context); + vectorModule.module = std::move(emitted->module); + } + if (*cubeDeviceModule) { + auto cubeOptions = + makeDeviceEmissionOptions(options, FunctionKernelKind::Cube); + auto emitted = + emitDeviceLLVMModule(*cubeDeviceModule, "cube", cubeOptions, + simtEntryNames, diagOS); + if (failed(emitted)) + return failure(); + cubeModule.context = std::move(emitted->context); + cubeModule.module = std::move(emitted->module); + } + return success(); + }); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp new file mode 100644 index 000000000..aeb1e3a9e --- /dev/null +++ b/lib/PTO/Transforms/VPTOLLVMEmitterHelper.cpp @@ -0,0 +1,605 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VPTOLLVMEmitterHelper.cpp - VPTO LLVM emission helpers ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 +#pragma GCC diagnostic ignored "-Woverloaded-virtual" + +#include "PTO/Transforms/VPTOLLVMEmitterHelper.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace mlir; + +namespace mlir::pto { +namespace { + +constexpr StringLiteral kAIVScopeDummyCallee = "aivscope_dummy"; + +struct QueriedTargetAttrs { + std::string targetCPU; + std::string targetFeatures; +}; + +static bool hasPtoMemRefMemorySpace(Type type) { + if (auto memRefType = dyn_cast(type)) + return isa(memRefType.getMemorySpace()); + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), hasPtoMemRefMemorySpace) || + llvm::any_of(functionType.getResults(), hasPtoMemRefMemorySpace); + return false; +} + +static bool hasPtoMemRefMemorySpace(TypeRange types) { + return llvm::any_of(types, [](Type type) { + return hasPtoMemRefMemorySpace(type); + }); +} + +struct ConvertPtoMemRefSpaceCarrierOp final : ConversionPattern { + ConvertPtoMemRefSpaceCarrierOp(TypeConverter &typeConverter, + MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + if (op->getNumRegions() != 0) + return rewriter.notifyMatchFailure( + op, "region ops with PTO memref spaces are handled structurally"); + + FailureOr converted = + convertOpResultTypes(op, operands, *typeConverter, rewriter); + if (failed(converted)) + return failure(); + return success(); + } +}; + +struct ConvertMemRefReinterpretCastSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), adaptor.getOffsets(), + adaptor.getSizes(), adaptor.getStrides(), op.getStaticOffsets(), + op.getStaticSizes(), op.getStaticStrides()); + return success(); + } +}; + +struct ConvertMemRefSubViewSpaceOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = getTypeConverter()->convertType(op.getType()); + auto memRefResultType = dyn_cast_or_null(convertedResultType); + if (!memRefResultType) + return rewriter.notifyMatchFailure(op, "expected memref result type"); + + rewriter.replaceOpWithNewOp( + op, memRefResultType, adaptor.getSource(), op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides()); + return success(); + } +}; + +struct ConvertMemRefSpaceUnrealizedCastOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasPtoMemRefMemorySpace(op->getOperandTypes()) && + !hasPtoMemRefMemorySpace(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + return failure(); + } +}; + +static void ensureAIVScopeDummyDecl(ModuleOp module) { + SymbolTable symbolTable(module); + if (symbolTable.lookup(kAIVScopeDummyCallee)) + return; + + OpBuilder builder(module.getBodyRegion()); + builder.setInsertionPointToStart(module.getBody()); + auto funcType = builder.getFunctionType(TypeRange{}, TypeRange{}); + auto dummy = builder.create(module.getLoc(), + kAIVScopeDummyCallee, funcType); + dummy.setPrivate(); +} + +static bool satisfiesAIVectorScopeLatchPostcondition(llvm::Loop *loop) { + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) + return false; + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.size() != 1) + return false; + + auto *predTerm = preds.front()->getTerminator(); + return predTerm && predTerm->getNumSuccessors() == 1 && + predTerm->getSuccessor(0) == latch; +} + +static LogicalResult ensureDummyPredForAIVectorScopeLatch( + llvm::Loop *loop, llvm::raw_ostream &diagOS) { + if (satisfiesAIVectorScopeLatchPostcondition(loop)) + return success(); + + llvm::BasicBlock *latch = loop->getLoopLatch(); + if (!latch) { + diagOS << "VPTO LLVM emission failed: aivscope loop is missing a latch\n"; + return failure(); + } + + llvm::SmallVector preds(llvm::predecessors(latch)); + if (preds.empty()) { + diagOS << "VPTO LLVM emission failed: aivscope latch has no predecessor\n"; + return failure(); + } + + auto *dummy = llvm::SplitBlockPredecessors( + latch, preds, "aivscope.dummy", static_cast(nullptr), + static_cast(nullptr), nullptr, /*PreserveLCSSA=*/false); + if (!dummy) { + diagOS << "VPTO LLVM emission failed: failed to normalize aivscope latch " + "predecessors\n"; + return failure(); + } + + if (!satisfiesAIVectorScopeLatchPostcondition(loop)) { + diagOS << "VPTO LLVM emission failed: normalized aivscope latch still does " + "not satisfy the single-predecessor/single-successor contract\n"; + return failure(); + } + return success(); +} + +static FailureOr extractQuotedLLVMFnAttr(llvm::StringRef ir, + llvm::StringRef key) { + std::string pattern = "\""; + pattern += key.str(); + pattern += "\"=\""; + size_t start = ir.find(pattern); + if (start == llvm::StringRef::npos) + return failure(); + start += pattern.size(); + size_t end = ir.find('"', start); + if (end == llvm::StringRef::npos || end <= start) + return failure(); + return ir.slice(start, end).str(); +} + +static FailureOr +queryDefaultTargetAttrs(const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + static llvm::StringMap cache; + + if (options.targetTriple.empty() || options.march.empty() || + options.aicoreArch.empty()) { + diagOS << "VPTO LLVM emission failed: missing target query options\n"; + return failure(); + } + + std::string cacheKey = + options.targetTriple + "|" + options.march + "|" + options.aicoreArch; + if (auto it = cache.find(cacheKey); it != cache.end()) + return it->second; + + auto bisheng = llvm::sys::findProgramByName("bisheng"); + if (!bisheng) { + diagOS << "VPTO LLVM emission failed: unable to find 'bisheng' in PATH\n"; + return failure(); + } + const std::string &bishengPath = *bisheng; + + llvm::SmallString<64> inputPath; + llvm::SmallString<64> outputPath; + int inputFD = -1; + int outputFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "c", inputFD, inputPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query input: " + << ec.message() << "\n"; + return failure(); + } + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "ll", outputFD, outputPath)) { + llvm::sys::fs::remove(inputPath); + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + diagOS << "VPTO LLVM emission failed: cannot create bisheng query output: " + << ec.message() << "\n"; + return failure(); + } + + auto cleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(inputPath); + llvm::sys::fs::remove(outputPath); + }); + + { + llvm::raw_fd_ostream inputOS(inputFD, /*shouldClose=*/false); + inputOS << "void f(void) {}\n"; + } + llvm::sys::Process::SafelyCloseFileDescriptor(inputFD); + llvm::sys::Process::SafelyCloseFileDescriptor(outputFD); + + llvm::SmallString<128> stderrPath; + int stderrFD = -1; + if (auto ec = llvm::sys::fs::createTemporaryFile("ptoas-vpto-target-query", + "stderr", stderrFD, + stderrPath)) { + diagOS << "VPTO LLVM emission failed: cannot create bisheng query stderr: " + << ec.message() << "\n"; + return failure(); + } + auto stderrCleanup = llvm::make_scope_exit([&]() { + llvm::sys::fs::remove(stderrPath); + }); + llvm::sys::Process::SafelyCloseFileDescriptor(stderrFD); + + llvm::SmallVector argStorage = { + bishengPath, + ("--target=" + options.targetTriple), + ("-march=" + options.march), + ("--cce-aicore-arch=" + options.aicoreArch), + "--cce-aicore-only", + "-x", + "c", + inputPath.str().str(), + "-S", + "-emit-llvm", + "-o", + outputPath.str().str(), + }; + llvm::SmallVector args; + args.reserve(argStorage.size()); + for (const std::string &arg : argStorage) + args.push_back(arg); + + std::string execErr; + bool execFailed = false; + int rc = llvm::sys::ExecuteAndWait( + bishengPath, args, std::nullopt, + {std::nullopt, std::nullopt, llvm::StringRef(stderrPath)}, 0, 0, + &execErr, &execFailed); + + auto stderrBuffer = llvm::MemoryBuffer::getFile(stderrPath); + llvm::StringRef stderrText = + stderrBuffer ? stderrBuffer.get()->getBuffer() : llvm::StringRef(); + + if (execFailed || rc != 0) { + diagOS << "VPTO LLVM emission failed: bisheng target query failed\n"; + diagOS << "Command:"; + for (llvm::StringRef arg : args) + diagOS << " " << arg; + diagOS << "\n"; + if (!execErr.empty()) + diagOS << execErr << "\n"; + if (!stderrText.empty()) + diagOS << stderrText << "\n"; + return failure(); + } + + auto outputBuffer = llvm::MemoryBuffer::getFile(outputPath); + if (!outputBuffer) { + diagOS << "VPTO LLVM emission failed: cannot read bisheng query output\n"; + return failure(); + } + + FailureOr targetCPU = + extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-cpu"); + FailureOr targetFeatures = + extractQuotedLLVMFnAttr(outputBuffer.get()->getBuffer(), "target-features"); + if (failed(targetCPU) || failed(targetFeatures)) { + diagOS << "VPTO LLVM emission failed: cannot parse bisheng target attrs\n"; + diagOS << outputBuffer.get()->getBuffer() << "\n"; + return failure(); + } + + QueriedTargetAttrs attrs{*targetCPU, *targetFeatures}; + cache[cacheKey] = attrs; + return attrs; +} + +} // namespace + +void materializeVecScopeCarrierLoops(ModuleOp module) { + MLIRContext *ctx = module.getContext(); + (void)ctx->getOrLoadDialect(); + (void)ctx->getOrLoadDialect(); + ensureAIVScopeDummyDecl(module); + + SmallVector scopes; + module.walk([&](pto::VecScopeOp vecScope) { scopes.push_back(vecScope); }); + + IRRewriter rewriter(module.getContext()); + for (pto::VecScopeOp vecScope : llvm::reverse(scopes)) { + if (!vecScope || vecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(vecScope); + auto loc = vecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &vecScopeBody = vecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + carrierBody->getOperations().splice(Block::iterator(yield), + vecScopeBody.getOperations(), + vecScopeBody.begin(), + vecScopeBody.end()); + rewriter.setInsertionPoint(yield); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + rewriter.eraseOp(vecScope); + } + + SmallVector strictScopes; + module.walk([&](pto::StrictVecScopeOp strictVecScope) { + strictScopes.push_back(strictVecScope); + }); + + for (pto::StrictVecScopeOp strictVecScope : llvm::reverse(strictScopes)) { + if (!strictVecScope || strictVecScope.getBody().empty()) + continue; + + rewriter.setInsertionPoint(strictVecScope); + auto loc = strictVecScope.getLoc(); + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + scf::ForOp carrier = rewriter.create(loc, c0, c1, c1); + + Block &strictBody = strictVecScope.getBody().front(); + Block *carrierBody = carrier.getBody(); + Operation *yield = carrierBody->getTerminator(); + + IRMapping mapping; + for (auto [blockArg, capture] : + llvm::zip(strictBody.getArguments(), strictVecScope.getCaptures())) + mapping.map(blockArg, capture); + + rewriter.setInsertionPoint(yield); + for (Operation &nested : strictBody.getOperations()) + rewriter.clone(nested, mapping); + rewriter.create(loc, kAIVScopeDummyCallee, TypeRange{}, + ValueRange{}); + + rewriter.eraseOp(strictVecScope); + } +} + +LogicalResult attachAIVectorScopeMetadata(llvm::Module &llvmModule, + llvm::raw_ostream &diagOS) { + llvm::Function *dummyCallee = llvmModule.getFunction(kAIVScopeDummyCallee); + if (!dummyCallee) + return success(); + + for (llvm::Function &function : llvmModule) { + if (function.isDeclaration()) + continue; + llvm::DominatorTree dt(function); + llvm::LoopInfo loopInfo(dt); + + llvm::SmallVector dummyCalls; + for (llvm::BasicBlock &block : function) { + for (llvm::Instruction &inst : block) { + auto *call = dyn_cast(&inst); + if (call && call->getCalledFunction() == dummyCallee) + dummyCalls.push_back(call); + } + } + + for (llvm::CallInst *dummyCall : dummyCalls) { + llvm::BasicBlock *markedBlock = dummyCall->getParent(); + llvm::Loop *loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() << " does not belong to an LLVM loop\n"; + return failure(); + } + + if (markedBlock == loop->getLoopLatch() && + dummyCall != markedBlock->getTerminator()) { + markedBlock->splitBasicBlock(dummyCall->getIterator(), "aivscope.latch"); + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + markedBlock = dummyCall->getParent(); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: split aivscope latch in " + << function.getName() + << " no longer belongs to an LLVM loop\n"; + return failure(); + } + } + + if (failed(ensureDummyPredForAIVectorScopeLatch(loop, diagOS))) + return failure(); + + dt.recalculate(function); + loopInfo.releaseMemory(); + loopInfo.analyze(dt); + loop = loopInfo.getLoopFor(markedBlock); + if (!loop) { + diagOS << "VPTO LLVM emission failed: aivscope_dummy in function " + << function.getName() + << " lost its loop after latch normalization\n"; + return failure(); + } + + llvm::BasicBlock *latch = loop->getLoopLatch(); + auto *branch = dyn_cast_or_null( + latch ? latch->getTerminator() : nullptr); + if (!branch || branch->isConditional()) { + diagOS << "VPTO LLVM emission failed: normalized aivscope loop in " + << function.getName() + << " does not have an unconditional latch backedge\n"; + return failure(); + } + + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Metadata *ops[] = { + nullptr, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, "llvm.loop.aivector_scope"))}; + auto *loopID = llvm::MDNode::getDistinct(ctx, ops); + loopID->replaceOperandWith(0, loopID); + branch->setMetadata(llvm::LLVMContext::MD_loop, loopID); + dummyCall->eraseFromParent(); + } + } + + if (dummyCallee->use_empty()) + dummyCallee->eraseFromParent(); + return success(); +} + +void attachHIVMKernelAnnotations(llvm::Module &llvmModule) { + llvm::NamedMDNode *annotations = + llvmModule.getOrInsertNamedMetadata("hivm.annotations"); + llvm::LLVMContext &ctx = llvmModule.getContext(); + llvm::Type *i32Ty = llvm::Type::getInt32Ty(ctx); + llvm::Constant *one = llvm::ConstantInt::get(i32Ty, 1); + + auto hasInModuleCaller = [](llvm::Function &function) { + for (llvm::User *user : function.users()) { + auto *call = llvm::dyn_cast(user); + if (!call) + continue; + if (call->getCalledFunction() != &function) + continue; + return true; + } + return false; + }; + + auto addAnnotation = [&](llvm::Function &function, llvm::StringRef kind) { + llvm::Metadata *ops[] = { + llvm::ValueAsMetadata::get(&function), + llvm::MDString::get(ctx, kind), + llvm::ConstantAsMetadata::get(one)}; + annotations->addOperand(llvm::MDNode::get(ctx, ops)); + }; + + for (llvm::Function &function : llvmModule) { + if (function.isDeclaration()) + continue; + if (function.getLinkage() != llvm::GlobalValue::ExternalLinkage) + continue; + + llvm::StringRef name = function.getName(); + if (name.contains(".extracted") || name.contains(".vector.thread")) + continue; + if (hasInModuleCaller(function)) + continue; + + addAnnotation(function, "kernel"); + addAnnotation(function, "kernel_with_simd"); + } +} + +LogicalResult +applyQueriedTargetAttrs(ModuleOp module, const VPTOEmissionOptions &options, + llvm::raw_ostream &diagOS) { + FailureOr attrs = queryDefaultTargetAttrs(options, diagOS); + if (failed(attrs)) { + if (options.defaultTargetCPU.empty() || + options.defaultTargetFeatures.empty()) + return failure(); + diagOS << "VPTO LLVM emission: falling back to configured default target " + "attributes\n"; + attrs = QueriedTargetAttrs{options.defaultTargetCPU, + options.defaultTargetFeatures}; + } + + MLIRContext *ctx = module.getContext(); + StringAttr cpuAttr = StringAttr::get(ctx, attrs->targetCPU); + LLVM::TargetFeaturesAttr featureAttr = + LLVM::TargetFeaturesAttr::get(ctx, attrs->targetFeatures); + module.walk([&](LLVM::LLVMFuncOp funcOp) { + funcOp.setTargetCpuAttr(cpuAttr); + funcOp.setTargetFeaturesAttr(featureAttr); + }); + return success(); +} + +} // namespace mlir::pto diff --git a/lib/PTO/Transforms/VPTONormalizeContainer.cpp b/lib/PTO/Transforms/VPTONormalizeContainer.cpp new file mode 100644 index 000000000..d5b22af3d --- /dev/null +++ b/lib/PTO/Transforms/VPTONormalizeContainer.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTONORMALIZECONTAINER +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isVPTOKernelSubmodule(ModuleOp module) { + return module->hasAttr(FunctionKernelKindAttr::name); +} + +static LogicalResult verifyNormalizedVPTOContainer(ModuleOp module) { + bool hasChildModules = false; + for (Operation &op : module.getBodyRegion().front().getOperations()) { + auto child = dyn_cast(op); + if (!child) { + return op.emitError() + << "expected VPTO container top level to contain only kernel " + "submodules"; + } + hasChildModules = true; + if (!isVPTOKernelSubmodule(child)) { + return child.emitError() + << "expected VPTO kernel submodule to carry 'pto.kernel_kind'"; + } + } + + if (hasChildModules) + return success(); + + return module.emitError() + << "expected VPTO input to be a kernel submodule with " + "'pto.kernel_kind' or a container of kernel submodules"; +} + +struct VPTONormalizeContainerPass + : public mlir::pto::impl::VPTONormalizeContainerBase< + VPTONormalizeContainerPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + if (isVPTOKernelSubmodule(module)) { + MLIRContext *context = module.getContext(); + SmallVector outerAttrs; + for (NamedAttribute attr : module->getAttrs()) + if (attr.getName() != SymbolTable::getSymbolAttrName() && + attr.getName() != FunctionKernelKindAttr::name) + outerAttrs.push_back(attr); + + auto child = ModuleOp::create(module.getLoc()); + child->setAttrs(module->getAttrDictionary()); + child.getBodyRegion().takeBody(module.getBodyRegion()); + + module->setAttrs(DictionaryAttr::get(context, outerAttrs)); + module.getBodyRegion().push_back(new Block); + module.getBodyRegion().front().push_back(child.getOperation()); + } + + if (failed(verifyNormalizedVPTOContainer(module))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTONormalizeContainerPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp new file mode 100644 index 000000000..dbaa9a780 --- /dev/null +++ b/lib/PTO/Transforms/VPTOPtrCastCleanup.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTOPTRCASTCLEANUP +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +struct CollapsePtrMemRefPtrBridgePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + + auto resultPtrType = dyn_cast(op.getResult(0).getType()); + if (!resultPtrType) + return failure(); + + auto castOp = op.getOperand(0).getDefiningOp(); + if (!castOp || castOp->getNumOperands() != 1) + return failure(); + + auto innerCast = + castOp.getSource().getDefiningOp(); + if (!innerCast || innerCast->getNumOperands() != 1 || + innerCast->getNumResults() != 1) + return failure(); + + Value basePtr = innerCast.getOperand(0); + if (basePtr.getType() != resultPtrType) + return failure(); + + rewriter.replaceOp(op, basePtr); + if (castOp->use_empty()) + rewriter.eraseOp(castOp); + if (innerCast->use_empty()) + rewriter.eraseOp(innerCast); + return success(); + } +}; + +struct VPTOPtrCastCleanupPass + : public pto::impl::VPTOPtrCastCleanupBase { + using pto::impl::VPTOPtrCastCleanupBase< + VPTOPtrCastCleanupPass>::VPTOPtrCastCleanupBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTOPtrCastCleanupPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VPTOPtrNormalize.cpp b/lib/PTO/Transforms/VPTOPtrNormalize.cpp new file mode 100644 index 000000000..04f4e8f59 --- /dev/null +++ b/lib/PTO/Transforms/VPTOPtrNormalize.cpp @@ -0,0 +1,871 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 +#pragma GCC diagnostic ignored "-Woverloaded-virtual" + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Twine.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTOPTRNORMALIZE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static pto::AddressSpaceAttr getPointerMemorySpace(Attribute memorySpace, + MLIRContext *ctx) { + if (auto addrSpace = dyn_cast_or_null(memorySpace)) + return addrSpace; + if (auto intAttr = dyn_cast_or_null(memorySpace)) + return pto::AddressSpaceAttr::get( + ctx, static_cast(intAttr.getInt())); + return {}; +} + +static Value buildIndexValue(OpBuilder &builder, Location loc, + OpFoldResult ofr) { + if (auto value = dyn_cast(ofr)) + return value; + auto attr = cast(cast(ofr)); + return builder.create(loc, attr.getInt()); +} + +static bool needsSubviewPtrConversion(memref::SubViewOp op) { + auto resultType = dyn_cast(op.getType()); + if (!resultType) + return false; + return static_cast( + getPointerMemorySpace(resultType.getMemorySpace(), op.getContext())); +} + +static Type convertSubviewResultType(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType) + return type; + + auto memorySpace = + getPointerMemorySpace(memrefType.getMemorySpace(), type.getContext()); + if (!memorySpace) + return type; + + return pto::PtrType::get(type.getContext(), memrefType.getElementType(), + memorySpace); +} + +static bool hasPtrNormalizeConvertibleType(Type type) { + if (isa(type)) + return true; + auto memrefType = dyn_cast(type); + return memrefType && static_cast(getPointerMemorySpace( + memrefType.getMemorySpace(), type.getContext())); +} + +static bool hasPtrNormalizeConvertibleType(TypeRange types) { + return llvm::any_of( + types, [](Type type) { return hasPtrNormalizeConvertibleType(type); }); +} + +static bool isMemRefType(Type type) { return isa(type); } + +static Value materializeUnrealizedCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + if (inputs.size() != 1) + return {}; + return builder + .create(loc, TypeRange{resultType}, inputs) + .getResult(0); +} + +static LogicalResult computeSubviewElementOffset(memref::SubViewOp op, + PatternRewriter &rewriter, + Value &offset) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return failure(); + + SmallVector strides; + int64_t baseOffset = 0; + if (failed(getStridesAndOffset(sourceType, strides, baseOffset))) + return failure(); + // The SSA source already names the base address after bind_tile/pointer_cast + // normalization. A dynamic memref layout offset here is metadata we can + // ignore for ptr normalization and model as zero. + if (baseOffset == ShapedType::kDynamic) + baseOffset = 0; + + Location loc = op.getLoc(); + Value total = rewriter.create(loc, baseOffset); + ArrayRef mixedOffsets = op.getMixedOffsets(); + if (mixedOffsets.size() != strides.size()) + return failure(); + + for (auto [ofr, stride] : llvm::zip(mixedOffsets, strides)) { + if (stride == 0) + continue; + if (stride == ShapedType::kDynamic) + return failure(); + + Value idx = buildIndexValue(rewriter, loc, ofr); + if (!idx.getType().isIndex()) + return failure(); + + if (stride != 1) { + Value strideValue = + rewriter.create(loc, stride); + idx = rewriter.create(loc, idx, strideValue); + } + total = rewriter.create(loc, total, idx); + } + + offset = total; + return success(); +} + +static Value materializeSubviewInputPtr(Value source, PatternRewriter &rewriter, + Location loc) { + if (!source) + return {}; + if (isa(source.getType())) + return source; + + auto memrefType = dyn_cast(source.getType()); + if (!memrefType) + return {}; + + auto memorySpace = + getPointerMemorySpace(memrefType.getMemorySpace(), rewriter.getContext()); + if (!memorySpace) + return {}; + + auto ptrType = pto::PtrType::get(rewriter.getContext(), + memrefType.getElementType(), memorySpace); + return rewriter.create(loc, ptrType, source); +} + +static Value materializeScalarAccessPtr(Value source, PatternRewriter &rewriter, + Location loc) { + if (!source) + return {}; + if (isa(source.getType())) + return source; + + if (auto cast = source.getDefiningOp()) { + if (cast->getNumOperands() != 1 || cast->getNumResults() != 1) + return {}; + Value input = cast.getOperands().front(); + if (isa(input.getType())) + return input; + return materializeScalarAccessPtr(input, rewriter, loc); + } + + if (auto cast = source.getDefiningOp()) + return materializeScalarAccessPtr(cast.getSource(), rewriter, loc); + + if (auto subview = source.getDefiningOp()) { + if (!needsSubviewPtrConversion(subview)) + return {}; + + Value basePtr = + materializeScalarAccessPtr(subview.getSource(), rewriter, loc); + if (!basePtr) + return {}; + + Value offset; + if (failed(computeSubviewElementOffset(subview, rewriter, offset))) + return {}; + + auto ptrType = dyn_cast(convertSubviewResultType(source.getType())); + if (!ptrType) + return {}; + if (basePtr.getType() != ptrType) + basePtr = rewriter.create(loc, ptrType, basePtr); + return rewriter.create(loc, ptrType, basePtr, offset); + } + + if (auto bind = source.getDefiningOp()) + return materializeScalarAccessPtr(bind.getSource(), rewriter, loc); + + if (auto pointerCast = source.getDefiningOp()) { + if (pointerCast.getAddrs().empty()) + return {}; + Value addr = pointerCast.getAddrs().front(); + if (isa(addr.getType())) + return addr; + return materializeScalarAccessPtr(addr, rewriter, loc); + } + + // Restrict normalization to memref views that already sit on top of a ptr-like + // boundary bridge. Materializing fresh memref -> ptr casts here would leave + // illegal pto.castptr(memref) behind in this pass. + return {}; +} + +static Value materializeBoundaryOperandPtr(Value source, + PatternRewriter &rewriter, + Location loc) { + if (!source) + return {}; + if (isa(source.getType())) + return source; + return materializeScalarAccessPtr(source, rewriter, loc); +} + +template +static LogicalResult rewriteBufferLikeBoundaryOp( + OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + StringRef sourceRole, StringRef destinationRole) { + Value source = + materializeBoundaryOperandPtr(adaptor.getOperands()[0], rewriter, op.getLoc()); + if (!source) + return rewriter.notifyMatchFailure( + op, (Twine("failed to materialize ") + sourceRole + " ptr").str()); + if (!isa(source.getType())) + return rewriter.notifyMatchFailure( + op, (Twine("expected ptr-form ") + sourceRole).str()); + + Value destination = materializeBoundaryOperandPtr(adaptor.getOperands()[1], + rewriter, op.getLoc()); + if (!destination) + return rewriter.notifyMatchFailure( + op, (Twine("failed to materialize ") + destinationRole + " ptr").str()); + if (!isa(destination.getType())) + return rewriter.notifyMatchFailure( + op, (Twine("expected ptr-form ") + destinationRole).str()); + + SmallVector operands(adaptor.getOperands().begin(), + adaptor.getOperands().end()); + operands[0] = source; + operands[1] = destination; + + OperationState state(op.getLoc(), op->getName().getStringRef()); + state.addOperands(operands); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + state.propertiesAttr = op->getPropertiesAsAttribute(); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} + +struct ConvertTileBufAddrToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::TileBufAddrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedType = getTypeConverter()->convertType(op.getDst().getType()); + if (!isa(convertedType)) + return failure(); + + rewriter.replaceOpWithNewOp(op, convertedType, + adaptor.getSrc()); + return success(); + } +}; + +struct ConvertPointerCastToCastPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::PointerCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + auto ptrType = dyn_cast(convertedType); + if (!ptrType) + return failure(); + + if (adaptor.getAddrs().empty()) + return rewriter.notifyMatchFailure(op, "expected at least one address"); + + rewriter.replaceOpWithNewOp(op, ptrType, + adaptor.getAddrs().front()); + return success(); + } +}; + +struct ConvertCastPtrPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::CastPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getInput(); + Type inputType = input.getType(); + if (isMemRefType(inputType) || isMemRefType(convertedResultType)) + return rewriter.notifyMatchFailure(op, + "memref castptr must be eliminated"); + + if (!isa(inputType) || + !isa(convertedResultType)) + return rewriter.notifyMatchFailure(op, + "expected ptr/int castptr operands"); + + if (inputType == convertedResultType) { + rewriter.replaceOp(op, input); + return success(); + } + + rewriter.replaceOpWithNewOp(op, convertedResultType, input); + return success(); + } +}; + +struct ConvertBindTileToPtrPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::BindTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type convertedType = getTypeConverter()->convertType(op.getResult().getType()); + auto ptrType = dyn_cast(convertedType); + if (!ptrType) + return failure(); + + Value ptr = + materializeSubviewInputPtr(adaptor.getSource(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, + "failed to materialize bind_tile input ptr"); + + if (ptr.getType() != ptrType) + ptr = rewriter.create(op.getLoc(), ptrType, ptr); + + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +struct ConvertSubviewToAddPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!needsSubviewPtrConversion(op)) + return failure(); + + auto ptrType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!ptrType) + return rewriter.notifyMatchFailure(op, "expected ptr result type"); + + Value basePtr = + materializeSubviewInputPtr(adaptor.getSource(), rewriter, op.getLoc()); + if (!basePtr) + return rewriter.notifyMatchFailure(op, + "failed to materialize subview input ptr"); + + Value offset; + if (failed(computeSubviewElementOffset(op, rewriter, offset))) + return rewriter.notifyMatchFailure(op, + "failed to compute subview element offset"); + + rewriter.replaceOpWithNewOp(op, ptrType, basePtr, offset); + return success(); + } +}; + +struct ConvertVldsSubviewOperandPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::VldsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(adaptor.getSource().getType())) + return failure(); + + OperationState state(op.getLoc(), op->getName().getStringRef()); + state.addOperands({adaptor.getSource(), adaptor.getOffset()}); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +struct ConvertVstsSubviewOperandPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::VstsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(adaptor.getDestination().getType())) + return failure(); + + OperationState state(op.getLoc(), op->getName().getStringRef()); + state.addOperands( + {adaptor.getValue(), adaptor.getDestination(), adaptor.getOffset(), + adaptor.getMask()}); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +struct ConvertLoadScalarOperandToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::LoadScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = materializeScalarAccessPtr(adaptor.getPtr(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, + "failed to materialize load_scalar ptr"); + if (!isa(ptr.getType())) + return rewriter.notifyMatchFailure(op, "expected ptr-form load_scalar input"); + + rewriter.replaceOpWithNewOp(op, op.getValue().getType(), + ptr, adaptor.getOffset()); + return success(); + } +}; + +struct ConvertStoreScalarOperandToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::StoreScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = materializeScalarAccessPtr(adaptor.getPtr(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, + "failed to materialize store_scalar ptr"); + if (!isa(ptr.getType())) + return rewriter.notifyMatchFailure(op, "expected ptr-form store_scalar input"); + + rewriter.replaceOpWithNewOp(op, ptr, + adaptor.getOffset(), + adaptor.getValue()); + return success(); + } +}; + +struct ConvertMteUbUbOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteUbUbOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, "mte_ub_ub source", + "mte_ub_ub destination"); + } +}; + +struct ConvertMteUbL1OperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteUbL1Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, "mte_ub_l1 source", + "mte_ub_l1 destination"); + } +}; + +struct ConvertCubeLoadOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteGmL1Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, "mte_gm_l1 source", + "mte_gm_l1 destination"); + } +}; + +struct ConvertCubeStoreOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL1UbOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, + "mte_l1_ub source", + "mte_l1_ub destination"); + } +}; + +struct ConvertBiasLoadOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL1BtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, "mte_l1_bt source", + "mte_l1_bt destination"); + } +}; + +struct ConvertCubeLoadFracOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteGmL1FracOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, "mte_gm_l1_frac source", + "mte_gm_l1_frac destination"); + } +}; + +struct ConvertLeftLoadOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL1L0aOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, "mte_l1_l0a source", + "mte_l1_l0a destination"); + } +}; + +struct ConvertRightLoadOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL1L0bOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, + "mte_l1_l0b source", + "mte_l1_l0b destination"); + } +}; + +struct ConvertLeftLoadMxOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL1L0aMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, + "mte_l1_l0a_mx source", + "mte_l1_l0a_mx destination"); + } +}; + +struct ConvertRightLoadMxOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL1L0bMxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, + "mte_l1_l0b_mx source", + "mte_l1_l0b_mx destination"); + } +}; + +struct ConvertAccStoreOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL0cL1Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, "mte_l0c_l1 source", + "mte_l0c_l1 destination"); + } +}; + +struct ConvertAccStoreGmOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL0cGmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, + "mte_l0c_gm source", + "mte_l0c_gm destination"); + } +}; + +struct ConvertAccStoreUbOperandPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::MteL0cUbOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriteBufferLikeBoundaryOp(op, adaptor, rewriter, + "mte_l0c_ub source", + "mte_l0c_ub destination"); + } +}; + +struct ConvertLoadOperandToPtrPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::PTOLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = materializeScalarAccessPtr(adaptor.getPtr(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, "failed to materialize load ptr"); + if (!isa(ptr.getType())) + return rewriter.notifyMatchFailure(op, "expected ptr-form load input"); + + rewriter.replaceOpWithNewOp(op, op.getValue().getType(), + ptr, adaptor.getOffset()); + return success(); + } +}; + +struct ConvertStoreOperandToPtrPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(pto::PTOStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = materializeScalarAccessPtr(adaptor.getPtr(), rewriter, op.getLoc()); + if (!ptr) + return rewriter.notifyMatchFailure(op, "failed to materialize store ptr"); + if (!isa(ptr.getType())) + return rewriter.notifyMatchFailure(op, "expected ptr-form store input"); + + rewriter.replaceOpWithNewOp(op, ptr, adaptor.getOffset(), + adaptor.getValue()); + return success(); + } +}; + +struct ConvertPtrNormalizeUnrealizedCastOp final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + if (!hasPtrNormalizeConvertibleType(op->getOperandTypes()) && + !hasPtrNormalizeConvertibleType(op->getResultTypes())) + return failure(); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult(0).getType()); + if (!convertedResultType) + return failure(); + + Value input = adaptor.getOperands().front(); + if (input.getType() != convertedResultType) + return failure(); + + rewriter.replaceOp(op, input); + return success(); + } +}; + +struct VPTOPtrNormalizePass + : public pto::impl::VPTOPtrNormalizeBase { + using pto::impl::VPTOPtrNormalizeBase< + VPTOPtrNormalizePass>::VPTOPtrNormalizeBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext *context = module.getContext(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [](Type type) { return convertSubviewResultType(type); }); + typeConverter.addTargetMaterialization(materializeUnrealizedCast); + typeConverter.addSourceMaterialization(materializeUnrealizedCast); + typeConverter.addArgumentMaterialization(materializeUnrealizedCast); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addDynamicallyLegalDialect([](Operation *op) { + return !isa(op); + }); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + return !hasPtrNormalizeConvertibleType(op->getOperandTypes()) && + !hasPtrNormalizeConvertibleType(op->getResultTypes()); + }); + target.addDynamicallyLegalOp([&](pto::TileBufAddrOp op) { + return op.getDst().getType() == + typeConverter.convertType(op.getDst().getType()); + }); + target.addDynamicallyLegalOp( + [&](pto::PointerCastOp op) { + return op.getResult().getType() == + typeConverter.convertType(op.getResult().getType()); + }); + target.addDynamicallyLegalOp([&](pto::CastPtrOp op) { + return !isMemRefType(op.getInput().getType()) && + !isMemRefType(op.getResult().getType()); + }); + target.addDynamicallyLegalOp([&](pto::BindTileOp op) { + return op.getResult().getType() == + typeConverter.convertType(op.getResult().getType()); + }); + target.addDynamicallyLegalOp( + [](pto::VldsOp op) { return isa(op.getSource().getType()); }); + target.addDynamicallyLegalOp([](pto::VstsOp op) { + return isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp( + [](pto::LoadScalarOp op) { return isa(op.getPtr().getType()); }); + target.addDynamicallyLegalOp( + [](pto::StoreScalarOp op) { return isa(op.getPtr().getType()); }); + target.addDynamicallyLegalOp([](pto::MteUbUbOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteUbL1Op op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteGmL1Op op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL1UbOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL1BtOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteGmL1FracOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL1L0aOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL1L0bOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL1L0aMxOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL1L0bMxOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL0cL1Op op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL0cGmOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp([](pto::MteL0cUbOp op) { + return isa(op.getSource().getType()) && + isa(op.getDestination().getType()); + }); + target.addDynamicallyLegalOp( + [](pto::PTOLoadOp op) { return isa(op.getPtr().getType()); }); + target.addDynamicallyLegalOp( + [](pto::PTOStoreOp op) { return isa(op.getPtr().getType()); }); + target.addDynamicallyLegalOp( + [](memref::SubViewOp op) { return !needsSubviewPtrConversion(op); }); + + RewritePatternSet patterns(context); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + patterns.add( + typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTOPtrNormalizePass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VPTOSplitCVModule.cpp b/lib/PTO/Transforms/VPTOSplitCVModule.cpp new file mode 100644 index 000000000..6cbc622b5 --- /dev/null +++ b/lib/PTO/Transforms/VPTOSplitCVModule.cpp @@ -0,0 +1,270 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTOSPLITCVMODULE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool hasVPTOKernelAttr(Operation *op) { + return op->hasAttr("pto.kernel") || op->hasAttr("pto.aicore"); +} + +static bool hasKernelKind(ModuleOp module) { + return module->hasAttr(FunctionKernelKindAttr::name); +} + +static bool hasKernelKindChildModule(ModuleOp module) { + return llvm::any_of(module.getOps(), + [](ModuleOp child) { return hasKernelKind(child); }); +} + +static bool hasCVSections(ModuleOp module) { + bool found = false; + module.walk([&](func::FuncOp funcOp) { + if (found || !hasVPTOKernelAttr(funcOp)) + return WalkResult::advance(); + WalkResult result = funcOp.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result.wasInterrupted() ? WalkResult::interrupt() + : WalkResult::advance(); + }); + return found; +} + +static bool hasSectionKind(ModuleOp module, FunctionKernelKind kind) { + bool found = false; + module.walk([&](func::FuncOp funcOp) { + if (found || !hasVPTOKernelAttr(funcOp)) + return WalkResult::advance(); + WalkResult result = funcOp.walk([&](Operation *op) { + bool matches = kind == FunctionKernelKind::Cube + ? isa(op) + : isa(op); + if (matches) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result.wasInterrupted() ? WalkResult::interrupt() + : WalkResult::advance(); + }); + return found; +} + +static bool hasSectionKind(func::FuncOp funcOp, FunctionKernelKind kind) { + bool found = false; + funcOp.walk([&](Operation *op) { + bool matches = kind == FunctionKernelKind::Cube ? isa(op) + : isa(op); + if (matches) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +static bool hasAnySection(func::FuncOp funcOp) { + bool found = false; + funcOp.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +static LogicalResult verifyNoNestedSections(ModuleOp module) { + LogicalResult status = success(); + module.walk([&](Operation *op) { + if (failed(status) || !isa(op)) + return WalkResult::advance(); + Operation *parent = op->getParentOp(); + while (parent) { + if (isa(parent)) { + status = op->emitError("nested pto.section.cube/vector is not allowed"); + return WalkResult::interrupt(); + } + parent = parent->getParentOp(); + } + return WalkResult::advance(); + }); + return status; +} + +static LogicalResult verifyKernelFunctionsUseSections(ModuleOp module) { + LogicalResult status = success(); + module.walk([&](func::FuncOp funcOp) { + if (failed(status) || !hasVPTOKernelAttr(funcOp)) + return WalkResult::advance(); + if (!hasAnySection(funcOp)) { + status = funcOp.emitOpError( + "must contain pto.section.cube or pto.section.vector in section " + "input split by vpto-split-cv-module"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return status; +} + +static LogicalResult verifyUniqueSectionKindsPerFunction(ModuleOp module) { + LogicalResult status = success(); + module.walk([&](func::FuncOp funcOp) { + if (failed(status) || !hasVPTOKernelAttr(funcOp)) + return WalkResult::advance(); + unsigned cubeCount = 0; + unsigned vectorCount = 0; + funcOp.walk([&](Operation *op) { + if (isa(op)) + ++cubeCount; + if (isa(op)) + ++vectorCount; + }); + if (cubeCount > 1) { + status = funcOp.emitOpError("contains more than one pto.section.cube"); + return WalkResult::interrupt(); + } + if (vectorCount > 1) { + status = funcOp.emitOpError("contains more than one pto.section.vector"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return status; +} + +static void eraseKernelFunctionsWithoutSectionKind(ModuleOp module, + FunctionKernelKind kind) { + SmallVector eraseFuncs; + module.walk([&](func::FuncOp funcOp) { + if (hasVPTOKernelAttr(funcOp) && !hasSectionKind(funcOp, kind)) + eraseFuncs.push_back(funcOp); + }); + + for (func::FuncOp funcOp : eraseFuncs) + funcOp.erase(); +} + +static void replaceSectionWithBody(Operation *sectionOp) { + Region ®ion = sectionOp->getRegion(0); + Block &body = region.front(); + Block *parentBlock = sectionOp->getBlock(); + parentBlock->getOperations().splice(Block::iterator(sectionOp), + body.getOperations()); + sectionOp->erase(); +} + +static void rewriteSectionsForKind(ModuleOp module, FunctionKernelKind kind) { + SmallVector eraseSections; + SmallVector inlineSections; + module.walk([&](Operation *op) { + if (kind == FunctionKernelKind::Cube) { + if (isa(op)) + eraseSections.push_back(op); + else if (isa(op)) + inlineSections.push_back(op); + } else { + if (isa(op)) + eraseSections.push_back(op); + else if (isa(op)) + inlineSections.push_back(op); + } + }); + + for (Operation *op : eraseSections) + op->erase(); + for (Operation *op : inlineSections) + replaceSectionWithBody(op); +} + +static ModuleOp cloneModuleForKind(ModuleOp source, FunctionKernelKind kind, + OpBuilder &builder) { + auto cloned = cast(source->clone()); + cloned->setAttr(FunctionKernelKindAttr::name, + FunctionKernelKindAttr::get(cloned.getContext(), kind)); + eraseKernelFunctionsWithoutSectionKind(cloned, kind); + rewriteSectionsForKind(cloned, kind); + builder.insert(cloned); + return cloned; +} + +static LogicalResult splitCVModule(ModuleOp module) { + if (hasKernelKind(module) || hasKernelKindChildModule(module)) + return success(); + if (!hasCVSections(module)) + return success(); + if (failed(verifyNoNestedSections(module))) + return failure(); + if (failed(verifyKernelFunctionsUseSections(module))) + return failure(); + if (failed(verifyUniqueSectionKindsPerFunction(module))) + return failure(); + + bool needVector = hasSectionKind(module, FunctionKernelKind::Vector); + bool needCube = hasSectionKind(module, FunctionKernelKind::Cube); + if (!needVector && !needCube) + return success(); + + SmallVector outerAttrs; + outerAttrs.reserve(module->getAttrs().size()); + for (NamedAttribute attr : module->getAttrs()) + if (attr.getName() != SymbolTable::getSymbolAttrName()) + outerAttrs.push_back(attr); + + auto outer = ModuleOp::create(module.getLoc()); + outer->setAttrs(DictionaryAttr::get(module.getContext(), outerAttrs)); + OpBuilder builder(outer.getBody(), outer.getBody()->end()); + if (needVector) + cloneModuleForKind(module, FunctionKernelKind::Vector, builder); + if (needCube) + cloneModuleForKind(module, FunctionKernelKind::Cube, builder); + + module.getBodyRegion().takeBody(outer.getBodyRegion()); + module->setAttrs(outer->getAttrs()); + return success(); +} + +struct VPTOSplitCVModulePass + : public mlir::pto::impl::VPTOSplitCVModuleBase { + void runOnOperation() override { + if (failed(splitCVModule(getOperation()))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTOSplitCVModulePass() { + return std::make_unique(); +} diff --git a/lib/TileOps/__init__.py b/lib/TileOps/__init__.py new file mode 100644 index 000000000..34437fbee --- /dev/null +++ b/lib/TileOps/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Shared TileOps template helpers.""" diff --git a/lib/TileOps/div_hp.py b/lib/TileOps/div_hp.py new file mode 100644 index 000000000..66c7c7643 --- /dev/null +++ b/lib/TileOps/div_hp.py @@ -0,0 +1,455 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Shared IEEE 754 high-precision division algorithms for pto.tdiv and pto.tdivs + +This module provides inline_proc functions that implement IEEE 754 compliant +division with improved accuracy for: +- Precision-sensitive values (1/7, 7/3, etc.) +- Subnormal numbers (denormals) +- Overflow/underflow boundary cases +- NaN propagation + +Reference: pto-isa/include/pto/npu/a5/custom/Div754.hpp +""" + +import tilelang_dsl as pto + + +@pto.inline_proc +def _div_three_candidate_search_f32(lhs, rhs, mask): + """Three-candidate search core algorithm for IEEE 754 division accuracy improvement. + + Corresponds to DivPrecisionImpl in pto-isa/include/pto/npu/a5/custom/Div754.hpp:16-62 + + Algorithm: Computes three candidates (z, z-1, z+1) and selects the one with smallest + residual |lhs - z*rhs|, improving accuracy for values like 1/7 that have infinite + binary representation. + """ + + # IEEE 754 Float32 bit patterns (corresponds to Div754.hpp:18-19) + inf_bound_u32 = pto.ui32(0x7f800000) # Infinity bound: sign=0, exp=255, mant=0 + sign_bit_u32 = pto.ui32(0x80000000) # Sign bit mask: bit31=1, others=0 + zero_f32 = pto.f32(0.0) + one_f32 = pto.f32(1.0) + neg_one_f32 = pto.f32(-1.0) + + z = pto.vdiv(lhs, rhs, mask) + z_init = z + + z_u32 = pto.vbitcast(z, pto.ui32) + z_or_sign = pto.vor(z_u32, pto.vbr(sign_bit_u32), mask) + is_inf_nan = pto.vcmp(z_or_sign, pto.vbr(inf_bound_u32), mask, pto.CmpMode.GE) + + is_zero = pto.vcmp(z, pto.vbr(zero_f32), mask, pto.CmpMode.EQ) + + special_mask = pto.por(is_inf_nan, is_zero, mask) + + y = pto.vmuls(rhs, neg_one_f32, mask) + r = pto.vmula(lhs, z, y, mask) + + z_pre = pto.vadds(z, neg_one_f32, mask) + z_next = pto.vadds(z, one_f32, mask) + + r_pre = pto.vmula(lhs, z_pre, y, mask) + r_next = pto.vmula(lhs, z_next, y, mask) + + r_abs = pto.vabs(r, mask) + r_pre_abs = pto.vabs(r_pre, mask) + r_next_abs = pto.vabs(r_next, mask) + + better_pre = pto.vcmp(r_pre_abs, r_abs, mask, pto.CmpMode.LT) + z_best = pto.vsel(z_pre, z, better_pre) + r_best_abs = pto.vsel(r_pre_abs, r_abs, better_pre) + + better_next = pto.vcmp(r_next_abs, r_best_abs, mask, pto.CmpMode.LT) + z_best = pto.vsel(z_next, z_best, better_next) + + divided = pto.vsel(z_init, z_best, special_mask) + + return divided + + +@pto.inline_proc +def _div_ieee754_f32_impl(src0, src1, mask): + """Complete IEEE 754 float32 high-precision division with subnormal and overflow handling. + + Corresponds to DivIEEE754FloatImpl in pto-isa/include/pto/npu/a5/custom/Div754.hpp:65-288 + + Key improvements over pto-isa: + - Subnormal detection uses LT (line 94) instead of EQ (Div754.hpp:159) + Rationale: Covers entire subnormal range [2^-149, 2^-126), not just max subnormal + """ + + # IEEE 754 Float32 bit masks and constants (corresponds to Div754.hpp:69-81) + F32_INF = pto.ui32(0x7f800000) # +Infinity: sign=0, exp=255, mant=0 + sign_extractor = pto.ui32(0x80000000) # Sign bit mask (bit31) + exponent_extractor = pto.ui32(0x807FFFFF) # Clear exponent bits [30:23] + exponent_normalizer = pto.ui32(0x3F800000) # Bias 127: 1.0f reference + subnormal_threshold = pto.ui32(0x007FFFFF) # Max subnormal: (1-2^-23)*2^-126 + nan_value = pto.ui32(0x7fc00000) # Quiet NaN: exp=255, mant=0x400000 + min_denormal = pto.ui32(0x1) # Smallest positive: 2^-149 + + # Subnormal normalization factors (corresponds to Div754.hpp:86-89) + normalize_scale_enlarge = pto.f32(8388608.0) # 2^23: shifts subnormals to normal range + normalize_scale_reduce = pto.f32(1.1920928955078125e-07) # 2^-23: inverse for result compensation + + src0_abs = pto.vabs(src0, mask) + src1_abs = pto.vabs(src1, mask) + + src0_abs_u32 = pto.vbitcast(src0_abs, pto.ui32) + src1_abs_u32 = pto.vbitcast(src1_abs, pto.ui32) + + mask_inf_src0 = pto.vcmp(src0_abs_u32, pto.vbr(F32_INF), mask, pto.CmpMode.EQ) + mask_inf_src1 = pto.vcmp(src1_abs_u32, pto.vbr(F32_INF), mask, pto.CmpMode.EQ) + mask_invalid = pto.por(mask_inf_src0, mask_inf_src1, mask) + + mask_zero_src0 = pto.vcmp(src0_abs_u32, pto.vbr(pto.ui32(0)), mask, pto.CmpMode.EQ) + mask_invalid = pto.por(mask_invalid, mask_zero_src0, mask) + mask_zero_src1 = pto.vcmp(src1_abs_u32, pto.vbr(pto.ui32(0)), mask, pto.CmpMode.EQ) + mask_invalid = pto.por(mask_invalid, mask_zero_src1, mask) + + mask_valid = pto.pnot(mask_invalid, mask) + + # Detect subnormal numbers (denormals) + # NOTE: Uses EQ/LT comparison matching pto-isa Div754.hpp asymmetry: + # - src0: EQ comparison (Div754.hpp:159) - detects exact max subnormal + # - src1: LT comparison (Div754.hpp:166) - covers entire subnormal range + mask_src0_subnormal = pto.vcmp(src0_abs_u32, pto.vbr(subnormal_threshold), mask, pto.CmpMode.EQ) + mask_src0_normal = pto.pnot(mask_src0_subnormal, mask) + src0_subnormal = pto.vmuls(src0, normalize_scale_enlarge, mask_src0_subnormal) + + mask_src1_subnormal = pto.vcmp(src1_abs_u32, pto.vbr(subnormal_threshold), mask, pto.CmpMode.LT) + mask_src1_normal = pto.pnot(mask_src1_subnormal, mask) + src1_subnormal = pto.vmuls(src1, normalize_scale_enlarge, mask_src1_subnormal) + + src0_all = pto.vsel(src0, src0_subnormal, mask_src0_normal) + src1_all = pto.vsel(src1, src1_subnormal, mask_src1_normal) + + src0_all_u32 = pto.vbitcast(src0_all, pto.ui32) + src1_all_u32 = pto.vbitcast(src1_all, pto.ui32) + + src0_norm_u32 = pto.vand(src0_all_u32, pto.vbr(exponent_extractor), mask_valid) + src1_norm_u32 = pto.vand(src1_all_u32, pto.vbr(exponent_extractor), mask_valid) + + src0_norm_u32 = pto.vadd(src0_norm_u32, pto.vbr(exponent_normalizer), mask_valid) + src1_norm_u32 = pto.vadd(src1_norm_u32, pto.vbr(exponent_normalizer), mask_valid) + + src0_norm_f32 = pto.vbitcast(src0_norm_u32, pto.f32) + src1_norm_f32 = pto.vbitcast(src1_norm_u32, pto.f32) + src0_norm = pto.vsel(src0_norm_f32, src0_all, mask_valid) + src1_norm = pto.vsel(src1_norm_f32, src1_all, mask_valid) + + dst = _div_three_candidate_search_f32(src0_norm, src1_norm, mask_valid) + + mask0 = pto.pand(mask_src0_subnormal, mask_src1_normal, mask) + z1 = pto.vmuls(dst, normalize_scale_reduce, mask0) + dst = pto.vsel(z1, dst, mask0) + + mask0 = pto.pand(mask_src0_normal, mask_src1_subnormal, mask) + z1 = pto.vmuls(dst, normalize_scale_enlarge, mask0) + dst = pto.vsel(z1, dst, mask0) + + dst_u32 = pto.vbitcast(dst, pto.ui32) + dst_sign = pto.vand(dst_u32, pto.vbr(sign_extractor), mask) + + src0_exponent = pto.vand(src0_all_u32, pto.vbr(F32_INF), mask) + src1_exponent = pto.vand(src1_all_u32, pto.vbr(F32_INF), mask) + + src0_exp_shifted = pto.vshrs(src0_exponent, pto.i16(23), mask) + src1_exp_shifted = pto.vshrs(src1_exponent, pto.i16(23), mask) + + src0_exp_i32 = pto.vbitcast(src0_exp_shifted, pto.si32) + src1_exp_i32 = pto.vbitcast(src1_exp_shifted, pto.si32) + + scale = pto.vsub(src0_exp_i32, src1_exp_i32, mask) + scale = pto.vadds(scale, pto.si32(127), mask) + + neg23 = pto.si32(-23) + mask_underflow1 = pto.vcmp(scale, pto.vbr(neg23), mask, pto.CmpMode.EQ) + mask_underflow1 = pto.pand(mask_underflow1, mask_valid, mask) + + z1_u32 = pto.vadd(dst_sign, pto.vbr(min_denormal), mask_underflow1) + z2_u32 = pto.vadd(dst_sign, pto.vbr(pto.ui32(0)), mask_underflow1) + + src0_norm_abs = pto.vabs(src0_norm, mask_valid) + src1_norm_abs = pto.vabs(src1_norm, mask_valid) + mask_norm = pto.vcmp(src0_norm_abs, src1_norm_abs, mask_valid, pto.CmpMode.LE) + + z1_sel = pto.vsel(z2_u32, z1_u32, mask_norm) + dst_u32_temp = pto.vsel(z1_sel, dst_u32, mask_underflow1) + + mask_underflow1_not = pto.pnot(mask_underflow1, mask) + mask_valid_temp = pto.pand(mask_underflow1_not, mask_valid, mask) + + mask_underflow2 = pto.vcmp(scale, pto.vbr(neg23), mask, pto.CmpMode.LT) + mask_underflow2 = pto.pand(mask_underflow2, mask_valid_temp, mask) + + z1_u32 = pto.vadd(dst_sign, pto.vbr(pto.ui32(0)), mask_underflow2) + dst_u32_temp = pto.vsel(z1_u32, dst_u32_temp, mask_underflow2) + + mask_underflow2_not = pto.pnot(mask_underflow2, mask) + mask_valid_temp = pto.pand(mask_underflow2_not, mask_valid_temp, mask) + + max_exp = pto.si32(255) + mask_overflow1 = pto.vcmp(scale, pto.vbr(max_exp), mask, pto.CmpMode.EQ) + mask_overflow1 = pto.pand(mask_overflow1, mask_valid_temp, mask) + + scale_adj = pto.vadds(scale, pto.si32(-1), mask_overflow1) + scale = pto.vsel(scale_adj, scale, mask_overflow1) + + dst_f32_temp = pto.vbitcast(dst_u32_temp, pto.f32) + z1_f32 = pto.vmuls(dst_f32_temp, pto.f32(2.0), mask_overflow1) + dst_f32_temp = pto.vsel(z1_f32, dst_f32_temp, mask_overflow1) + + mask_overflow2 = pto.vcmp(scale, pto.vbr(max_exp), mask, pto.CmpMode.GT) + mask_overflow2 = pto.pand(mask_overflow2, mask_valid_temp, mask) + + z1_u32 = pto.vadd(dst_sign, pto.vbr(F32_INF), mask_overflow2) + dst_u32_temp = pto.vbitcast(dst_f32_temp, pto.ui32) + dst_u32_temp = pto.vsel(z1_u32, dst_u32_temp, mask_overflow2) + + mask_overflow2_not = pto.pnot(mask_overflow2, mask) + mask_valid_final = pto.pand(mask_overflow2_not, mask_valid_temp, mask) + + zero_exp = pto.si32(0) + mask_pos_exp = pto.vcmp(scale, pto.vbr(zero_exp), mask_valid_final, pto.CmpMode.GT) + + scale_u32 = pto.vbitcast(scale, pto.ui32) + exp_shifted = pto.vshls(scale_u32, pto.i16(23), mask_pos_exp) + exp_factor_f32 = pto.vbitcast(exp_shifted, pto.f32) + + dst_f32_temp = pto.vbitcast(dst_u32_temp, pto.f32) + z1_f32 = pto.vmul(dst_f32_temp, exp_factor_f32, mask_pos_exp) + dst_f32_temp = pto.vsel(z1_f32, dst_f32_temp, mask_pos_exp) + + mask_pos_exp_not = pto.pnot(mask_pos_exp, mask_valid_final) + + # Handle negative exponent (underflow scenarios) + # Corresponds to Div754.hpp:275 + # Value 0x00400000 = Float32 with exp=0, mantissa bit22=1 (used for shift calculation) + four_million = pto.ui32(4194304) # Normal float 1.0 in bit representation for exponent manipulation + scale_abs = pto.vabs(scale, mask_pos_exp_not) + + shr_base_vec = pto.vdup(four_million, mask_pos_exp_not) + shr_base_i32 = pto.vbitcast(shr_base_vec, pto.si32) + shr_factor_i32 = pto.vshr(shr_base_i32, scale_abs, mask_pos_exp_not) + shr_factor_f32 = pto.vbitcast(shr_factor_i32, pto.f32) + + z1_f32 = pto.vmul(dst_f32_temp, shr_factor_f32, mask_pos_exp_not) + dst_f32_temp = pto.vsel(z1_f32, dst_f32_temp, mask_pos_exp_not) + + mask_nan_src0 = pto.vcmp(src0_abs, src0_abs, mask, pto.CmpMode.NE) + mask_nan_src1 = pto.vcmp(src1_abs, src1_abs, mask, pto.CmpMode.NE) + mask_nan = pto.por(mask_nan_src0, mask_nan_src1, mask) + + nan_vec = pto.vbr(nan_value) + nan_f32_vec = pto.vbitcast(nan_vec, pto.f32) + dst_final = pto.vsel(nan_f32_vec, dst_f32_temp, mask_nan) + + return dst_final + + +@pto.inline_proc +def _div_ieee754_f16_impl(src0, src1, mask): + """Complete IEEE 754 float16 high-precision division with subnormal handling. + + Follows pto-isa Div754.hpp:291-502 (DivIEEE754HalfImpl). + + Key differences from F32 implementation: + - Uses LT for both src0/src1 subnormal detection (symmetric, not EQ/LT like F32) + - Normalization factor: 2^10 (not 2^23 for F32) + - Exponent bias: 15 (not 127 for F32) + - Exponent shift: 10 bits (not 23 for F32) + - Direct vdiv call (no three-candidate search) + """ + + # IEEE 754 Float16 bit masks and constants (corresponds to Div754.hpp:293-309) + F16_INF = pto.ui16(0x7C00) # +Infinity: sign=0, exp=31, mant=0 + exponent_extractor = pto.ui16(0x83FF) # Clear exponent bits [14:10] + exponent_normalizer = pto.ui16(0x3C00) # 1.0f16 reference (bias=15) + sign_extractor = pto.ui16(0x8000) # Sign bit mask (bit15) + subnormal_threshold = pto.ui16(0x03FF) # Max subnormal: (1-2^-10)*2^-14 + nan_value = pto.ui16(0x7E00) # Quiet NaN: exp=31, mant=0x200 + min_denormal = pto.ui16(0x1) # Smallest positive: 2^-24 + + # Subnormal normalization factors (corresponds to Div754.hpp:306-309) + normalize_scale_enlarge = pto.f16(1024.0) # 2^10: shifts subnormals to normal range + normalize_scale_reduce = pto.f16(0.0009765625) # 2^-10: inverse for result compensation + + src0_abs = pto.vabs(src0, mask) + src1_abs = pto.vabs(src1, mask) + + src0_abs_u16 = pto.vbitcast(src0_abs, pto.ui16) + src1_abs_u16 = pto.vbitcast(src1_abs, pto.ui16) + + # Detect Infinity values + mask_inf_src0 = pto.vcmp(src0_abs_u16, pto.vbr(F16_INF), mask, pto.CmpMode.EQ) + mask_inf_src1 = pto.vcmp(src1_abs_u16, pto.vbr(F16_INF), mask, pto.CmpMode.EQ) + mask_invalid = pto.por(mask_inf_src0, mask_inf_src1, mask) + + # Detect Zero values + mask_zero_src0 = pto.vcmp(src0_abs_u16, pto.vbr(pto.ui16(0)), mask, pto.CmpMode.EQ) + mask_invalid = pto.por(mask_invalid, mask_zero_src0, mask) + mask_zero_src1 = pto.vcmp(src1_abs_u16, pto.vbr(pto.ui16(0)), mask, pto.CmpMode.EQ) + mask_invalid = pto.por(mask_invalid, mask_zero_src1, mask) + + mask_valid = pto.pnot(mask_invalid, mask) + + # Detect subnormal numbers (denormals) + # NOTE: F16 uses LT for BOTH src0 and src1 (symmetric detection) + # Different from F32's asymmetric EQ/LT pattern + mask_src0_subnormal = pto.vcmp(src0_abs_u16, pto.vbr(subnormal_threshold), mask, pto.CmpMode.LT) + mask_src0_normal = pto.pnot(mask_src0_subnormal, mask) + src0_subnormal = pto.vmuls(src0, normalize_scale_enlarge, mask_src0_subnormal) + + mask_src1_subnormal = pto.vcmp(src1_abs_u16, pto.vbr(subnormal_threshold), mask, pto.CmpMode.LT) + mask_src1_normal = pto.pnot(mask_src1_subnormal, mask) + src1_subnormal = pto.vmuls(src1, normalize_scale_enlarge, mask_src1_subnormal) + + # Merge normalized subnormals with normal values + src0_all = pto.vsel(src0, src0_subnormal, mask_src0_normal) + src1_all = pto.vsel(src1, src1_subnormal, mask_src1_normal) + + src0_all_u16 = pto.vbitcast(src0_all, pto.ui16) + src1_all_u16 = pto.vbitcast(src1_all, pto.ui16) + + # Standardize exponent bits (corresponds to Div754.hpp:391-401) + src0_norm_u16 = pto.vand(src0_all_u16, pto.vbr(exponent_extractor), mask_valid) + src1_norm_u16 = pto.vand(src1_all_u16, pto.vbr(exponent_extractor), mask_valid) + + src0_norm_u16 = pto.vadd(src0_norm_u16, pto.vbr(exponent_normalizer), mask_valid) + src1_norm_u16 = pto.vadd(src1_norm_u16, pto.vbr(exponent_normalizer), mask_valid) + + src0_norm_f16 = pto.vbitcast(src0_norm_u16, pto.f16) + src1_norm_f16 = pto.vbitcast(src1_norm_u16, pto.f16) + src0_norm = pto.vsel(src0_norm_f16, src0_all, mask_valid) + src1_norm = pto.vsel(src1_norm_f16, src1_all, mask_valid) + + src0_norm_abs = pto.vabs(src0_norm, mask_valid) + src1_norm_abs = pto.vabs(src1_norm, mask_valid) + mask_norm = pto.vcmp(src0_norm_abs, src1_norm_abs, mask_valid, pto.CmpMode.LE) + + # Execute division directly (no three-candidate search for F16) + # Corresponds to Div754.hpp:406 + dst = pto.vdiv(src0_norm, src1_norm, mask) + + # Subnormal dividend, normal divisor: scale down result + # Corresponds to Div754.hpp:408-412 + mask0 = pto.pand(mask_src0_subnormal, mask_src1_normal, mask) + z1 = pto.vmuls(dst, normalize_scale_reduce, mask0) + dst = pto.vsel(z1, dst, mask0) + + # Normal dividend, subnormal divisor: scale up result + # Corresponds to Div754.hpp:414-419 + mask0 = pto.pand(mask_src0_normal, mask_src1_subnormal, mask) + z1 = pto.vmuls(dst, normalize_scale_enlarge, mask0) + dst = pto.vsel(z1, dst, mask0) + + # Preserve sign for overflow/underflow handling + dst_u16 = pto.vbitcast(dst, pto.ui16) + dst_sign = pto.vand(dst_u16, pto.vbr(sign_extractor), mask) + + # Extract exponent bits (corresponds to Div754.hpp:428-439) + src0_exponent = pto.vand(src0_all_u16, pto.vbr(F16_INF), mask) + src1_exponent = pto.vand(src1_all_u16, pto.vbr(F16_INF), mask) + + src0_exp_shifted = pto.vshrs(src0_exponent, pto.i16(10), mask) + src1_exp_shifted = pto.vshrs(src1_exponent, pto.i16(10), mask) + + src0_exp_i16 = pto.vbitcast(src0_exp_shifted, pto.si16) + src1_exp_i16 = pto.vbitcast(src1_exp_shifted, pto.si16) + + # Scale = src0_exp - src1_exp + bias(15) + scale = pto.vsub(src0_exp_i16, src1_exp_i16, mask) + scale = pto.vadds(scale, pto.si16(15), mask) + + # Underflow handling: scale == -9 (corresponds to Div754.hpp:443-453) + neg9 = pto.si16(-9) + mask_underflow1 = pto.vcmp(scale, pto.vbr(neg9), mask, pto.CmpMode.EQ) + mask_underflow1 = pto.pand(mask_underflow1, mask_valid, mask) + + z1_u16 = pto.vadd(dst_sign, pto.vbr(min_denormal), mask_underflow1) + z2_u16 = pto.vadd(dst_sign, pto.vbr(pto.ui16(0)), mask_underflow1) + + z1_sel = pto.vsel(z2_u16, z1_u16, mask_norm) + dst_u16_temp = pto.vsel(z1_sel, dst_u16, mask_underflow1) + + mask_underflow1_not = pto.pnot(mask_underflow1, mask) + mask_valid_temp = pto.pand(mask_underflow1_not, mask_valid, mask) + + # Underflow handling: scale < -9 (corresponds to Div754.hpp:456-463) + mask_underflow2 = pto.vcmp(scale, pto.vbr(neg9), mask, pto.CmpMode.LT) + mask_underflow2 = pto.pand(mask_underflow2, mask_valid_temp, mask) + + z1_u16 = pto.vadd(dst_sign, pto.vbr(pto.ui16(0)), mask_underflow2) + dst_u16_temp = pto.vsel(z1_u16, dst_u16_temp, mask_underflow2) + + mask_underflow2_not = pto.pnot(mask_underflow2, mask) + mask_valid_temp = pto.pand(mask_underflow2_not, mask_valid_temp, mask) + + # Overflow handling: scale == 31 (corresponds to Div754.hpp:465-472) + max_exp = pto.si16(31) + mask_overflow1 = pto.vcmp(scale, pto.vbr(max_exp), mask, pto.CmpMode.EQ) + mask_overflow1 = pto.pand(mask_overflow1, mask_valid_temp, mask) + + scale_adj = pto.vadds(scale, pto.si16(-1), mask_overflow1) + scale = pto.vsel(scale_adj, scale, mask_overflow1) + + dst_f16_temp = pto.vbitcast(dst_u16_temp, pto.f16) + z1_f16 = pto.vmuls(dst_f16_temp, pto.f16(2.0), mask_overflow1) + dst_f16_temp = pto.vsel(z1_f16, dst_f16_temp, mask_overflow1) + + # Overflow handling: scale > 31 (corresponds to Div754.hpp:474-480) + mask_overflow2 = pto.vcmp(scale, pto.vbr(max_exp), mask, pto.CmpMode.GT) + mask_overflow2 = pto.pand(mask_overflow2, mask_valid_temp, mask) + + z1_u16 = pto.vadd(dst_sign, pto.vbr(F16_INF), mask_overflow2) + dst_u16_temp = pto.vbitcast(dst_f16_temp, pto.ui16) + dst_u16_temp = pto.vsel(z1_u16, dst_u16_temp, mask_overflow2) + + mask_overflow2_not = pto.pnot(mask_overflow2, mask) + mask_valid_final = pto.pand(mask_overflow2_not, mask_valid_temp, mask) + + # Positive exponent handling (corresponds to Div754.hpp:482-486) + zero_exp = pto.si16(0) + mask_pos_exp = pto.vcmp(scale, pto.vbr(zero_exp), mask_valid_final, pto.CmpMode.GT) + + scale_u16 = pto.vbitcast(scale, pto.ui16) + exp_shifted = pto.vshls(scale_u16, pto.i16(10), mask_pos_exp) + exp_factor_f16 = pto.vbitcast(exp_shifted, pto.f16) + + dst_f16_temp = pto.vbitcast(dst_u16_temp, pto.f16) + z1_f16 = pto.vmul(dst_f16_temp, exp_factor_f16, mask_pos_exp) + dst_f16_temp = pto.vsel(z1_f16, dst_f16_temp, mask_pos_exp) + + # Negative exponent handling (corresponds to Div754.hpp:488-493) + mask_pos_exp_not = pto.pnot(mask_pos_exp, mask_valid_final) + + # Value 0x0200 = Float16 with exp=0, mantissa bit9=1 (used for shift calculation) + shr_base = pto.ui16(512) # 0x0200 + scale_abs = pto.vabs(scale, mask_pos_exp_not) + + shr_base_vec = pto.vdup(shr_base, mask_pos_exp_not) + shr_base_i16 = pto.vbitcast(shr_base_vec, pto.si16) + shr_factor_i16 = pto.vshr(shr_base_i16, scale_abs, mask_pos_exp_not) + shr_factor_f16 = pto.vbitcast(shr_factor_i16, pto.f16) + + z1_f16 = pto.vmul(dst_f16_temp, shr_factor_f16, mask_pos_exp_not) + dst_f16_temp = pto.vsel(z1_f16, dst_f16_temp, mask_pos_exp_not) + + # NaN propagation (corresponds to Div754.hpp:495-501) + mask_nan_src0 = pto.vcmp(src0_abs, src0_abs, mask, pto.CmpMode.NE) + mask_nan_src1 = pto.vcmp(src1_abs, src1_abs, mask, pto.CmpMode.NE) + mask_nan = pto.por(mask_nan_src0, mask_nan_src1, mask) + + nan_vec = pto.vbr(nan_value) + nan_f16_vec = pto.vbitcast(nan_vec, pto.f16) + dst_final = pto.vsel(nan_f16_vec, dst_f16_temp, mask_nan) + + return dst_final \ No newline at end of file diff --git a/lib/TileOps/exp_hp.py b/lib/TileOps/exp_hp.py new file mode 100644 index 000000000..f967c1249 --- /dev/null +++ b/lib/TileOps/exp_hp.py @@ -0,0 +1,30 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import tilelang_dsl as pto + +@pto.inline_proc +def _tl_exp_precision(src, mask, dtype): + if pto.constexpr(dtype == pto.f16): + subnormal_threshold = pto.f16("0x03ff") + two_val = pto.f16(2.0) + else: + subnormal_threshold = pto.f32("0x007FFFFF") + two_val = pto.f32(2.0) + + dst = pto.vexp(src, mask) + + subnormal_mask = pto.vcmps(dst, subnormal_threshold, mask, pto.CmpMode.LE) + + reg_two = pto.vbr(two_val) + tmp = pto.vdiv(src, reg_two, subnormal_mask) + tmp = pto.vexp(tmp, subnormal_mask) + tmp = pto.vmul(tmp, tmp, subnormal_mask) + + result = pto.vsel(tmp, dst, subnormal_mask) + return result diff --git a/lib/TileOps/math.py b/lib/TileOps/math.py new file mode 100644 index 000000000..83c4634f6 --- /dev/null +++ b/lib/TileOps/math.py @@ -0,0 +1,584 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import tilelang_dsl as pto + + +@pto.inline_proc +def _tl_soft_vdiv_u8(vec, scalar_vec, mask): + zero = pto.ui8(0) + zero_q = pto.ui8(0xFF) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vzunpack(vec, 0) + vec_high = pto.vzunpack(vec, 1) + scalar_low = pto.vzunpack(scalar_vec, 0) + scalar_high = pto.vzunpack(scalar_vec, 1) + + q_low = _tl_soft_vdiv_u16(vec_low, scalar_low, active_low) + q_high = _tl_soft_vdiv_u16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(q_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(q_high, pto.PredicatePart.HIGHER) + q = pto.vor(packed_low, packed_high, full_mask_b8) + return pto.vsel(pto.vbr(zero_q), q, zero_mask) + + +@pto.inline_proc +def _tl_soft_vdiv_i8(vec, scalar_vec, mask): + zero = pto.i8(0) + neg_one = pto.i8(-1) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vsunpack(vec, 0) + vec_high = pto.vsunpack(vec, 1) + scalar_low = pto.vsunpack(scalar_vec, 0) + scalar_high = pto.vsunpack(scalar_vec, 1) + + q_low = _tl_soft_vdiv_i16(vec_low, scalar_low, active_low) + q_high = _tl_soft_vdiv_i16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(q_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(q_high, pto.PredicatePart.HIGHER) + q = pto.vbitcast(pto.vor(packed_low, packed_high, full_mask_b8), pto.i8) + return pto.vsel(pto.vbr(neg_one), q, zero_mask) + + +@pto.inline_proc +def _tl_soft_vdiv_u16(vec, scalar_vec, mask): + zero = pto.ui16(0) + one = pto.ui16(1) + fp32_one = pto.f32(1.0) + full_mask_b16 = pto.pset_b16(pto.PAT.ALL) + full_mask_b32 = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u16 = pto.vbr(zero) + vy_lower_u16, vy_higher_u16 = pto.vintlv(scalar_vec, zero_u16) + vy_lower_u32 = pto.vcvt(vy_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vy_higher_u32 = pto.vcvt(vy_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + active_low = pto.vcmps(vy_lower_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + active_high = pto.vcmps(vy_higher_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i32), pto.f32, active_low, rnd=pto.VcvtRoundMode.F) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i32), pto.f32, active_high, rnd=pto.VcvtRoundMode.F) + + vy_rec_lower = pto.vdiv(pto.vbr(fp32_one), vy_lower_f32, active_low) + vy_rec_higher = pto.vdiv(pto.vbr(fp32_one), vy_higher_f32, active_high) + vy_scale_lower = pto.vmul(vy_rec_lower, pto.vbr(pto.f32(65536.0)), active_low) + vy_scale_higher = pto.vmul(vy_rec_higher, pto.vbr(pto.f32(65536.0)), active_high) + + v_lower_i32 = pto.vcvt( + vy_scale_lower, + pto.i32, + active_low, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_higher_i32 = pto.vcvt( + vy_scale_higher, + pto.i32, + active_high, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_lower_u32 = pto.vbitcast(v_lower_i32, pto.ui32) + v_higher_u32 = pto.vbitcast(v_higher_i32, pto.ui32) + + vx_lower_u16, vx_higher_u16 = pto.vintlv(vec, zero_u16) + vx_lower_u32 = pto.vcvt(vx_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vx_higher_u32 = pto.vcvt(vx_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + q_tmp_lower = pto.vmul(v_lower_u32, vx_lower_u32, active_low) + q_tmp_higher = pto.vmul(v_higher_u32, vx_higher_u32, active_high) + _q_lower, q_tmp = pto.vdintlv(pto.vbitcast(q_tmp_lower, pto.ui16), pto.vbitcast(q_tmp_higher, pto.ui16)) + + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + zero_q = pto.vbr(pto.ui16(0xFFFF)) + return pto.vsel(zero_q, q_tmp, zero_mask) + + +@pto.inline_proc +def _tl_soft_vdiv_i16(vec, scalar_vec, mask): + zero = pto.i16(0) + neg_one = pto.i16(-1) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui16) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui16) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + q_abs = _tl_soft_vdiv_u16(abs_x, abs_y, active_mask) + neg_q = pto.vneg(pto.vbitcast(q_abs, pto.i16), active_mask) + q = pto.vsel(pto.vbitcast(q_abs, pto.i16), neg_q, p_pos) + return pto.vsel(pto.vbr(neg_one), q, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_u8(vec, scalar_vec, mask): + zero = pto.ui8(0) + zero_r = pto.ui8(0xFF) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vzunpack(vec, 0) + vec_high = pto.vzunpack(vec, 1) + scalar_low = pto.vzunpack(scalar_vec, 0) + scalar_high = pto.vzunpack(scalar_vec, 1) + + r_low = _tl_soft_vmod_u16(vec_low, scalar_low, active_low) + r_high = _tl_soft_vmod_u16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(r_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(r_high, pto.PredicatePart.HIGHER) + r = pto.vor(packed_low, packed_high, full_mask_b8) + return pto.vsel(pto.vbr(zero_r), r, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_i8(vec, scalar_vec, mask): + zero = pto.i8(0) + neg_one = pto.i8(-1) + full_mask_b8 = pto.pset_b8(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + active_low = pto.punpack(active_mask, pto.PredicatePart.LOWER) + active_high = pto.punpack(active_mask, pto.PredicatePart.HIGHER) + + vec_low = pto.vsunpack(vec, 0) + vec_high = pto.vsunpack(vec, 1) + scalar_low = pto.vsunpack(scalar_vec, 0) + scalar_high = pto.vsunpack(scalar_vec, 1) + + r_low = _tl_soft_vmod_i16(vec_low, scalar_low, active_low) + r_high = _tl_soft_vmod_i16(vec_high, scalar_high, active_high) + packed_low = pto.vpack(r_low, pto.PredicatePart.LOWER) + packed_high = pto.vpack(r_high, pto.PredicatePart.HIGHER) + r = pto.vbitcast(pto.vor(packed_low, packed_high, full_mask_b8), pto.i8) + return pto.vsel(pto.vbr(neg_one), r, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_u16(vec, scalar_vec, mask): + zero = pto.ui16(0) + one = pto.ui16(1) + zero_r = pto.ui16(0xFFFF) + fp32_one = pto.f32(1.0) + full_mask_b16 = pto.pset_b16(pto.PAT.ALL) + full_mask_b32 = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u16 = pto.vbr(zero) + vy_lower_u16, vy_higher_u16 = pto.vintlv(scalar_vec, zero_u16) + vy_lower_u32 = pto.vcvt(vy_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vy_higher_u32 = pto.vcvt(vy_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + active_low = pto.vcmps(vy_lower_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + active_high = pto.vcmps(vy_higher_u32, pto.ui32(0), full_mask_b32, pto.CmpMode.NE) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i32), pto.f32, active_low, rnd=pto.VcvtRoundMode.F) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i32), pto.f32, active_high, rnd=pto.VcvtRoundMode.F) + + vy_rec_lower = pto.vdiv(pto.vbr(fp32_one), vy_lower_f32, active_low) + vy_rec_higher = pto.vdiv(pto.vbr(fp32_one), vy_higher_f32, active_high) + vy_scale_lower = pto.vmul(vy_rec_lower, pto.vbr(pto.f32(65536.0)), active_low) + vy_scale_higher = pto.vmul(vy_rec_higher, pto.vbr(pto.f32(65536.0)), active_high) + + v_lower_i32 = pto.vcvt( + vy_scale_lower, + pto.i32, + active_low, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_higher_i32 = pto.vcvt( + vy_scale_higher, + pto.i32, + active_high, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + ) + v_lower_u32 = pto.vbitcast(v_lower_i32, pto.ui32) + v_higher_u32 = pto.vbitcast(v_higher_i32, pto.ui32) + + vx_lower_u16, vx_higher_u16 = pto.vintlv(vec, zero_u16) + vx_lower_u32 = pto.vcvt(vx_lower_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vx_higher_u32 = pto.vcvt(vx_higher_u16, pto.ui32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + q_tmp_lower = pto.vmul(v_lower_u32, vx_lower_u32, active_low) + q_tmp_higher = pto.vmul(v_higher_u32, vx_higher_u32, active_high) + _q_lower, q_tmp = pto.vdintlv(pto.vbitcast(q_tmp_lower, pto.ui16), pto.vbitcast(q_tmp_higher, pto.ui16)) + + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + return pto.vsel(pto.vbr(zero_r), r_tmp, zero_mask) + + +@pto.inline_proc +def _tl_soft_vdiv_u32(vec, scalar_vec, mask): + zero = pto.ui32(0) + one = pto.ui32(1) + zero_q = pto.ui32(0xFFFFFFFF) + fp32_one = pto.f32(1.0) + full_mask = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u32 = pto.vbr(zero) + zero_f32 = pto.vbr(pto.f32(0.0)) + vy_lower_u32, vy_higher_u32 = pto.vintlv(scalar_vec, zero_u32) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_float, _vy_waste = pto.vdintlv(vy_lower_f32, vy_higher_f32) + + vy_rec = pto.vdiv(pto.vbr(fp32_one), vy_float, full_mask) + vy_scale = pto.vmul(vy_rec, pto.vbr(pto.f32(4294966784.0)), full_mask) + + vy_scale_lower_f32, vy_scale_higher_f32 = pto.vintlv(vy_scale, zero_f32) + v_lower_i64 = pto.vcvt( + vy_scale_lower_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + v_higher_i64 = pto.vcvt( + vy_scale_higher_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(v_lower_i64, pto.ui32), pto.vbitcast(v_higher_i64, pto.ui32)) + + tmp_0 = pto.vmul(z, scalar_vec, full_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), full_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, full_mask) + z = pto.vadd(z, z_high, full_mask) + + _q_lower, q_tmp = pto.vmull(vec, z, full_mask) + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + return pto.vsel(pto.vbr(zero_q), q_tmp, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_i16(vec, scalar_vec, mask): + zero = pto.i16(0) + neg_one = pto.i16(-1) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui16) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui16) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + q_abs = _tl_soft_vdiv_u16(abs_x, abs_y, active_mask) + neg_q = pto.vneg(pto.vbitcast(q_abs, pto.i16), active_mask) + q = pto.vsel(pto.vbitcast(q_abs, pto.i16), neg_q, p_pos) + + qy = pto.vmul(q, scalar_vec, active_mask) + remainder = pto.vsub(vec, qy, active_mask) + + nonzero_remainder = pto.vcmps(remainder, zero, active_mask, pto.CmpMode.NE) + sign_x = pto.vcmps(vec, zero, active_mask, pto.CmpMode.GE) + sign_y = pto.vcmps(scalar_vec, zero, active_mask, pto.CmpMode.GE) + sign_diff = pto.pxor(sign_x, sign_y, active_mask) + need_floor_fix = pto.pand(sign_diff, nonzero_remainder, active_mask) + amended_remainder = pto.vadd(scalar_vec, remainder, active_mask) + remainder = pto.vsel(amended_remainder, remainder, need_floor_fix) + return pto.vsel(pto.vbr(neg_one), remainder, zero_mask) + + +@pto.inline_proc +def _tl_soft_vdiv_i32(vec, scalar_vec, mask): + zero = pto.i32(0) + neg_one = pto.i32(-1) + fp32_one = pto.f32(1.0) + false_mask = pto.pset_b32(pto.PAT.ALLF) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui32) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui32) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + y_float = pto.vcvt(pto.vbitcast(abs_y, pto.i32), pto.f32, active_mask, rnd=pto.VcvtRoundMode.R) + y_rec = pto.vdiv(pto.vbr(fp32_one), y_float, active_mask) + f_z_tmp_bits = pto.vadds(pto.vbitcast(y_rec, pto.ui32), pto.ui32(0x0FFFFFFE), active_mask) + + low_mask, high_mask = pto.pintlv_b32(active_mask, false_mask) + lower_bits, higher_bits = pto.vintlv(f_z_tmp_bits, pto.vbr(pto.ui32(0))) + lower_i64 = pto.vcvt( + pto.vbitcast(lower_bits, pto.f32), + pto.i64, + low_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + higher_i64 = pto.vcvt( + pto.vbitcast(higher_bits, pto.f32), + pto.i64, + high_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(lower_i64, pto.ui32), pto.vbitcast(higher_i64, pto.ui32)) + active_mask, _waste_mask = pto.pdintlv_b32(low_mask, high_mask) + + fz_negative = pto.vcmps(pto.vbitcast(f_z_tmp_bits, pto.f32), pto.f32(0.0), active_mask, pto.CmpMode.LT) + z = pto.vsel(pto.vbr(pto.ui32(0)), z, fz_negative) + + tmp_0 = pto.vmul(z, abs_y, active_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), active_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, active_mask) + z = pto.vadd(z, z_high, active_mask) + + _q_lower, q_tmp = pto.vmull(abs_x, z, active_mask) + yq_tmp = pto.vmul(q_tmp, abs_y, active_mask) + r_tmp = pto.vsub(abs_x, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + + neg_q = pto.vneg(pto.vbitcast(q_tmp, pto.i32), active_mask) + q = pto.vsel(pto.vbitcast(q_tmp, pto.i32), neg_q, p_pos) + return pto.vsel(pto.vbr(neg_one), q, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_u32(vec, scalar_vec, mask): + zero = pto.ui32(0) + one = pto.ui32(1) + zero_r = pto.ui32(0xFFFFFFFF) + fp32_one = pto.f32(1.0) + full_mask = pto.pset_b32(pto.PAT.ALL) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + zero_u32 = pto.vbr(zero) + zero_f32 = pto.vbr(pto.f32(0.0)) + vy_lower_u32, vy_higher_u32 = pto.vintlv(scalar_vec, zero_u32) + vy_lower_f32 = pto.vcvt(pto.vbitcast(vy_lower_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_higher_f32 = pto.vcvt(pto.vbitcast(vy_higher_u32, pto.i64), pto.f32, full_mask, rnd=pto.VcvtRoundMode.F, part=pto.VcvtPartMode.EVEN) + vy_float, _vy_waste = pto.vdintlv(vy_lower_f32, vy_higher_f32) + + vy_rec = pto.vdiv(pto.vbr(fp32_one), vy_float, full_mask) + vy_scale = pto.vmul(vy_rec, pto.vbr(pto.f32(4294966784.0)), full_mask) + + vy_scale_lower_f32, vy_scale_higher_f32 = pto.vintlv(vy_scale, zero_f32) + v_lower_i64 = pto.vcvt( + vy_scale_lower_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + v_higher_i64 = pto.vcvt( + vy_scale_higher_f32, + pto.i64, + full_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(v_lower_i64, pto.ui32), pto.vbitcast(v_higher_i64, pto.ui32)) + + tmp_0 = pto.vmul(z, scalar_vec, full_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), full_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, full_mask) + z = pto.vadd(z, z_high, full_mask) + + _q_lower, q_tmp = pto.vmull(vec, z, full_mask) + yq_tmp = pto.vmul(q_tmp, scalar_vec, active_mask) + r_tmp = pto.vsub(vec, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, scalar_vec, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, scalar_vec, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, one, active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + return pto.vsel(pto.vbr(zero_r), r_tmp, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod_i32(vec, scalar_vec, mask): + zero = pto.i32(0) + neg_one = pto.i32(-1) + fp32_one = pto.f32(1.0) + false_mask = pto.pset_b32(pto.PAT.ALLF) + + zero_mask = pto.vcmps(scalar_vec, zero, mask, pto.CmpMode.EQ) + active_mask = pto.pnot(zero_mask, mask) + + abs_x = pto.vbitcast(pto.vabs(vec, active_mask), pto.ui32) + abs_y = pto.vbitcast(pto.vabs(scalar_vec, active_mask), pto.ui32) + x_xor_y = pto.vxor(vec, scalar_vec, active_mask) + p_pos = pto.vcmps(x_xor_y, zero, active_mask, pto.CmpMode.GE) + + y_float = pto.vcvt(pto.vbitcast(abs_y, pto.i32), pto.f32, active_mask, rnd=pto.VcvtRoundMode.R) + y_rec = pto.vdiv(pto.vbr(fp32_one), y_float, active_mask) + f_z_tmp_bits = pto.vadds(pto.vbitcast(y_rec, pto.ui32), pto.ui32(0x0FFFFFFE), active_mask) + + low_mask, high_mask = pto.pintlv_b32(active_mask, false_mask) + lower_bits, higher_bits = pto.vintlv(f_z_tmp_bits, pto.vbr(pto.ui32(0))) + lower_i64 = pto.vcvt( + pto.vbitcast(lower_bits, pto.f32), + pto.i64, + low_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + higher_i64 = pto.vcvt( + pto.vbitcast(higher_bits, pto.f32), + pto.i64, + high_mask, + rnd=pto.VcvtRoundMode.F, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + z, _z_waste = pto.vdintlv(pto.vbitcast(lower_i64, pto.ui32), pto.vbitcast(higher_i64, pto.ui32)) + active_mask, _waste_mask = pto.pdintlv_b32(low_mask, high_mask) + + fz_negative = pto.vcmps(pto.vbitcast(f_z_tmp_bits, pto.f32), pto.f32(0.0), active_mask, pto.CmpMode.LT) + z = pto.vsel(pto.vbr(pto.ui32(0)), z, fz_negative) + + tmp_0 = pto.vmul(z, abs_y, active_mask) + tmp_0 = pto.vbitcast(pto.vneg(pto.vbitcast(tmp_0, pto.i32), active_mask), pto.ui32) + _z_lower, z_high = pto.vmull(z, tmp_0, active_mask) + z = pto.vadd(z, z_high, active_mask) + + _q_lower, q_tmp = pto.vmull(abs_x, z, active_mask) + yq_tmp = pto.vmul(q_tmp, abs_y, active_mask) + r_tmp = pto.vsub(abs_x, yq_tmp, active_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + ge_mask = pto.vcmp(r_tmp, abs_y, active_mask, pto.CmpMode.GE) + refined_r = pto.vsub(r_tmp, abs_y, active_mask) + r_tmp = pto.vsel(refined_r, r_tmp, ge_mask) + q_inc = pto.vadds(q_tmp, pto.ui32(1), active_mask) + q_tmp = pto.vsel(q_inc, q_tmp, ge_mask) + + neg_q = pto.vneg(pto.vbitcast(q_tmp, pto.i32), active_mask) + q = pto.vsel(pto.vbitcast(q_tmp, pto.i32), neg_q, p_pos) + + qy = pto.vmul(q, scalar_vec, active_mask) + remainder = pto.vsub(vec, qy, active_mask) + nonzero_remainder = pto.vcmps(pto.vbitcast(r_tmp, pto.i32), zero, active_mask, pto.CmpMode.NE) + sign_x = pto.vcmps(vec, zero, active_mask, pto.CmpMode.GE) + sign_y = pto.vcmps(scalar_vec, zero, active_mask, pto.CmpMode.GE) + sign_diff = pto.pxor(sign_x, sign_y, active_mask) + need_floor_fix = pto.pand(sign_diff, nonzero_remainder, active_mask) + amended_remainder = pto.vadd(scalar_vec, remainder, active_mask) + remainder = pto.vsel(amended_remainder, remainder, need_floor_fix) + return pto.vsel(pto.vbr(neg_one), remainder, zero_mask) + + +@pto.inline_proc +def _tl_soft_vmod(vec, scalar_vec, mask, dtype): + if pto.constexpr(dtype == pto.ui8): + result = _tl_soft_vmod_u8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i8): + result = _tl_soft_vmod_i8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui16): + result = _tl_soft_vmod_u16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i16): + result = _tl_soft_vmod_i16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui32): + result = _tl_soft_vmod_u32(vec, scalar_vec, mask) + else: + result = _tl_soft_vmod_i32(vec, scalar_vec, mask) + return result + + +@pto.inline_proc +def _tl_soft_vdiv(vec, scalar_vec, mask, dtype): + if pto.constexpr(dtype == pto.ui8): + result = _tl_soft_vdiv_u8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i8): + result = _tl_soft_vdiv_i8(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui16): + result = _tl_soft_vdiv_u16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.i16): + result = _tl_soft_vdiv_i16(vec, scalar_vec, mask) + elif pto.constexpr(dtype == pto.ui32): + result = _tl_soft_vdiv_u32(vec, scalar_vec, mask) + else: + result = _tl_soft_vdiv_i32(vec, scalar_vec, mask) + return result diff --git a/lib/TileOps/render_template_mlir.py b/lib/TileOps/render_template_mlir.py new file mode 100644 index 000000000..b21118961 --- /dev/null +++ b/lib/TileOps/render_template_mlir.py @@ -0,0 +1,378 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Materialize a TileLang DSL library template to authoring-form MLIR. + +Examples: + python3 lib/TileOps/render_template_mlir.py lib/TileOps/tload_template.py + python3 lib/TileOps/render_template_mlir.py lib/TileOps/tadd_template.py --tile dst=8x64@ub --tile src0=8x64@ub --tile src1=8x64@ub + python3 lib/TileOps/render_template_mlir.py lib/TileOps/tload_template.py --dtypes f16,f16 -o /tmp/tload.mlir +""" + +from __future__ import annotations + +import argparse +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +REPO_ROOT = Path(__file__).resolve().parents[2] +TILELANG_PYTHON_DIR = REPO_ROOT / "tilelang-dsl" / "python" +if str(TILELANG_PYTHON_DIR) not in sys.path: + sys.path.insert(0, str(TILELANG_PYTHON_DIR)) + +import tilelang_dsl as pto + + +_DTYPE_BY_NAME = { + "i1": pto.i1, + "i8": pto.i8, + "i16": pto.i16, + "i32": pto.i32, + "i64": pto.i64, + "f16": pto.f16, + "bf16": pto.bf16, + "f32": pto.f32, +} +_MEMORY_SPACE_BY_NAME = { + "gm": pto.MemorySpace.GM, + "ub": pto.MemorySpace.UB, +} + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Load a TileLang DSL template file and emit its corresponding MLIR text.", + ) + parser.add_argument("template", help="Path to the template Python file") + parser.add_argument( + "--kernel", + help="Descriptor symbol name inside the module when the file defines multiple @pto.vkernel templates", + ) + parser.add_argument( + "--op", + help="Concrete op to bind when the descriptor matches multiple ops; defaults to the first match op", + ) + parser.add_argument( + "--dtypes", + help="Concrete operand dtypes as a comma-separated list, for example: f32,f32 or f16,f16,f16", + ) + parser.add_argument( + "--tile", + action="append", + default=[], + metavar="PARAM=SHAPE[@SPACE][:VALID]", + help=( + "Tile specialization override, for example: dst=16x32@ub or " + "dst=16x32@ub:8x32. May be repeated." + ), + ) + parser.add_argument( + "--default-tile-shape", + default="16x32", + help="Default shape for every bare Tile parameter when no --tile override is given", + ) + parser.add_argument( + "--default-tile-space", + default="ub", + choices=sorted(_MEMORY_SPACE_BY_NAME), + help="Default memory space for every bare Tile parameter", + ) + parser.add_argument( + "-o", + "--output", + help="Optional output path; defaults to stdout", + ) + return parser.parse_args() + + +def _load_module(template_path: Path) -> ModuleType: + template_parent = template_path.parent.parent + if str(template_parent) not in sys.path: + sys.path.insert(0, str(template_parent)) + module_name = f"_tileops_template_{template_path.stem}" + spec = importlib.util.spec_from_file_location(module_name, template_path) + if spec is None or spec.loader is None: + raise ValueError(f"failed to load Python module from {template_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _find_descriptors(module: ModuleType) -> dict[str, pto.VKernelDescriptor]: + descriptors: dict[str, pto.VKernelDescriptor] = {} + for name, value in vars(module).items(): + if isinstance(value, pto.VKernelDescriptor): + descriptors[name] = value + return descriptors + + +def _select_descriptor( + descriptors: dict[str, pto.VKernelDescriptor], + kernel_name: str | None, +) -> tuple[str, pto.VKernelDescriptor]: + if not descriptors: + raise ValueError("no @pto.vkernel descriptor found in the template module") + if kernel_name is not None: + descriptor = descriptors.get(kernel_name) + if descriptor is None: + available = ", ".join(sorted(descriptors)) + raise ValueError( + f"kernel {kernel_name!r} was not found in the template module; available descriptors: {available}" + ) + return kernel_name, descriptor + if len(descriptors) == 1: + return next(iter(descriptors.items())) + available = ", ".join(sorted(descriptors)) + raise ValueError( + "the template module defines multiple @pto.vkernel descriptors; " + f"please pass --kernel. Available descriptors: {available}" + ) + + +def _parse_dtype_list(text: str) -> tuple[pto.ScalarType, ...]: + parts = [part.strip() for part in text.split(",") if part.strip()] + if not parts: + raise ValueError("--dtypes must contain at least one dtype") + try: + return tuple(_DTYPE_BY_NAME[part] for part in parts) + except KeyError as exc: + available = ", ".join(sorted(_DTYPE_BY_NAME)) + raise ValueError( + f"unsupported dtype {exc.args[0]!r}; available dtypes: {available}" + ) from exc + + +def _default_concrete_dtype(pattern: object) -> pto.ScalarType: + if isinstance(pattern, pto.ScalarType): + return pattern + if isinstance(pattern, pto.WildcardType): + if pattern.name in {"AnyType", "AnyFloat"}: + return pto.f32 + if pattern.name == "AnyInt": + return pto.i32 + if pattern.name == "AnyMask": + return pto.i1 + raise ValueError(f"unsupported wildcard dtype pattern {pattern!r}") + if isinstance(pattern, pto.TypeVariable): + return pto.f32 + raise ValueError(f"unsupported dtype pattern {pattern!r}") + + +def _default_parameter_dtype( + param_spec: object | None, + pattern: object, +) -> pto.ScalarType: + annotation = getattr(param_spec, "annotation", None) + if isinstance(annotation, pto.ScalarType): + return annotation + if isinstance(annotation, pto.WildcardType) and annotation.name != "AnyType": + return _default_concrete_dtype(annotation) + if isinstance(annotation, pto.MaskType): + return pto.i1 + return _default_concrete_dtype(pattern) + + +def _default_operand_types(descriptor: pto.VKernelDescriptor) -> tuple[pto.ScalarType, ...]: + if not descriptor.dtypes: + raise ValueError("descriptor does not declare any dtype signatures") + prototype = descriptor.dtypes[0] + parameter_specs = getattr(descriptor, "_parameter_specs", ()) + typevar_bindings: dict[str, pto.ScalarType] = {} + concrete: list[pto.ScalarType] = [] + for index, pattern in enumerate(prototype): + param_spec = parameter_specs[index] if index < len(parameter_specs) else None + if isinstance(pattern, pto.TypeVariable): + bound = typevar_bindings.get(pattern.name) + if bound is None: + bound = _default_parameter_dtype(param_spec, pattern) + typevar_bindings[pattern.name] = bound + concrete.append(bound) + continue + concrete.append(_default_parameter_dtype(param_spec, pattern)) + return tuple(concrete) + + +def _bind_descriptor( + descriptor: pto.VKernelDescriptor, + *, + op_name: str | None, + operand_types: tuple[pto.ScalarType, ...] | None, +) -> pto.VKernelDescriptor: + concrete_op = op_name + if concrete_op is None: + if descriptor.selected_op is not None: + concrete_op = descriptor.selected_op + elif len(descriptor.match_ops) == 1: + concrete_op = descriptor.match_ops[0] + else: + available = ", ".join(descriptor.match_ops) + raise ValueError( + f"descriptor matches multiple ops; pass --op. Available ops: {available}" + ) + + concrete_operand_types = operand_types + if concrete_operand_types is None: + if descriptor._selected_dtype_signature is not None: + concrete_operand_types = descriptor._selected_dtype_signature + else: + concrete_operand_types = _default_operand_types(descriptor) + + registry = pto.KernelRegistry((descriptor,)) + return pto.select_kernel( + target=descriptor.target, + op=concrete_op, + operand_types=concrete_operand_types, + registry=registry, + ) + + +def _parse_shape(text: str) -> tuple[int, ...]: + dims = [] + for part in text.split("x"): + part = part.strip() + if not part: + raise ValueError(f"invalid shape {text!r}") + value = int(part) + if value <= 0: + raise ValueError(f"shape dimensions must be positive integers, got {text!r}") + dims.append(value) + if not dims: + raise ValueError(f"invalid shape {text!r}") + return tuple(dims) + + +def _parse_tile_override(spec_text: str) -> tuple[str, pto.TileSpecialization]: + if "=" not in spec_text: + raise ValueError( + f"invalid --tile value {spec_text!r}; expected PARAM=SHAPE[@SPACE][:VALID]" + ) + param_name, payload = spec_text.split("=", 1) + param_name = param_name.strip() + payload = payload.strip() + if not param_name: + raise ValueError(f"invalid --tile value {spec_text!r}; missing parameter name") + + valid_shape = None + if ":" in payload: + payload, valid_text = payload.split(":", 1) + valid_shape = _parse_shape(valid_text.strip()) + + memory_space = pto.MemorySpace.UB + if "@" in payload: + shape_text, memory_space_text = payload.split("@", 1) + memory_space_key = memory_space_text.strip().lower() + try: + memory_space = _MEMORY_SPACE_BY_NAME[memory_space_key] + except KeyError as exc: + available = ", ".join(sorted(_MEMORY_SPACE_BY_NAME)) + raise ValueError( + f"unsupported memory space {memory_space_text!r}; available spaces: {available}" + ) from exc + else: + shape_text = payload + + shape = _parse_shape(shape_text.strip()) + if valid_shape is not None and len(valid_shape) != len(shape): + raise ValueError( + f"valid_shape rank {len(valid_shape)} does not match shape rank {len(shape)} for {param_name!r}" + ) + return ( + param_name, + pto.TileSpecialization( + shape=shape, + memory_space=memory_space, + valid_shape=valid_shape, + ), + ) + + +def _default_tile_specialization( + *, + shape: tuple[int, ...], + memory_space: pto.MemorySpace, +) -> pto.TileSpecialization: + return pto.TileSpecialization(shape=shape, memory_space=memory_space) + + +def _specialize_tiles( + descriptor: pto.VKernelDescriptor, + *, + tile_overrides: dict[str, pto.TileSpecialization], + default_shape: tuple[int, ...], + default_memory_space: pto.MemorySpace, +) -> pto.VKernelDescriptor: + if not descriptor.tile_parameters: + return descriptor + + specializations: dict[str, pto.TileSpecialization] = {} + for param in descriptor.tile_parameters: + specializations[param.name] = tile_overrides.get( + param.name, + _default_tile_specialization( + shape=default_shape, + memory_space=default_memory_space, + ), + ) + return descriptor.specialize(**specializations) + + +def _emit_output(text: str, output_path: str | None) -> None: + if output_path is None: + sys.stdout.write(text) + if not text.endswith("\n"): + sys.stdout.write("\n") + return + path = Path(output_path) + path.write_text(text, encoding="utf-8") + + +def main() -> int: + args = _parse_args() + template_path = Path(args.template).resolve() + if not template_path.is_file(): + print(f"error: template file not found: {template_path}", file=sys.stderr) + return 1 + + try: + module = _load_module(template_path) + _, descriptor = _select_descriptor(_find_descriptors(module), args.kernel) + operand_types = None if args.dtypes is None else _parse_dtype_list(args.dtypes) + bound = _bind_descriptor( + descriptor, + op_name=args.op, + operand_types=operand_types, + ) + tile_overrides = dict(_parse_tile_override(spec_text) for spec_text in args.tile) + specialized = _specialize_tiles( + bound, + tile_overrides=tile_overrides, + default_shape=_parse_shape(args.default_tile_shape), + default_memory_space=_MEMORY_SPACE_BY_NAME[args.default_tile_space], + ) + _emit_output(specialized.mlir_text(), args.output) + return 0 + except Exception as exc: + print(f"error: {exc}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/lib/TileOps/sqrt_hp.py b/lib/TileOps/sqrt_hp.py new file mode 100644 index 000000000..11f0ea97d --- /dev/null +++ b/lib/TileOps/sqrt_hp.py @@ -0,0 +1,83 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import tilelang_dsl as pto + +@pto.inline_proc +def _tl_sqrt_precision_f16(src, mask): + multiply_factor0 = pto.f16("0x6c00") + multiply_factor1 = pto.f16("0x2400") + subnormal_threshold = pto.f16("0x03ff") + + subnormal_mask = pto.vcmps(src, subnormal_threshold, mask, pto.CmpMode.LT) + + tmp = pto.vmuls(src, multiply_factor0, subnormal_mask) + src_adjusted = pto.vsel(tmp, src, subnormal_mask) + + dst = pto.vsqrt(src_adjusted, mask) + + tmp = pto.vmuls(dst, multiply_factor1, subnormal_mask) + result = pto.vsel(tmp, dst, subnormal_mask) + + return result + + +@pto.inline_proc +def _tl_sqrt_precision_f32(src, mask): + multiply_factor0 = pto.f32(16777216.0) + multiply_factor1 = pto.f32(0.000244140625) + subnormal_bound = pto.f32(1.0) + half_factor = pto.f32(0.5) + neg_one = pto.f32(-1.0) + + subnormal_mask = pto.vcmps(src, subnormal_bound, mask, pto.CmpMode.LT) + + tmp = pto.vmuls(src, multiply_factor0, subnormal_mask) + src_adjusted = pto.vsel(tmp, src, subnormal_mask) + + reg_one = pto.vbr(pto.f32(1.0)) + tmp_sqrt = pto.vsqrt(src_adjusted, mask) + dst = pto.vdiv(reg_one, tmp_sqrt, mask) + + reg_neg_one = pto.vmuls(dst, neg_one, mask) + err = pto.vmul(dst, src_adjusted, mask) + reg_one_adj = pto.vmula(reg_one, err, reg_neg_one, mask) + tmp_half = pto.vmuls(dst, half_factor, mask) + dst = pto.vmula(dst, reg_one_adj, tmp_half, mask) + + res = pto.vmul(dst, src_adjusted, mask) + tmp_neg = pto.vmuls(res, neg_one, mask) + err = pto.vmula(src_adjusted, res, tmp_neg, mask) + tmp_half = pto.vmuls(dst, half_factor, mask) + tmp = pto.vmul(err, tmp_half, mask) + tmp = pto.vadd(tmp, res, mask) + + tmp_scaled = pto.vmuls(tmp, multiply_factor1, mask) + result = pto.vsel(tmp_scaled, tmp, subnormal_mask) + + pos_inf = pto.ui32(0x7f800000) + neg_zero = pto.ui32(0x80000000) + + src_as_u32 = pto.vbitcast(src_adjusted, pto.ui32) + is_inf_mask = pto.vcmps(src_as_u32, pos_inf, mask, pto.CmpMode.EQ) + src_with_sign = pto.vor(src_as_u32, pto.vbr(neg_zero), mask) + is_zero_mask = pto.vcmps(src_with_sign, neg_zero, mask, pto.CmpMode.EQ) + special_mask = pto.por(is_zero_mask, is_inf_mask, mask) + + result = pto.vsel(src_adjusted, result, special_mask) + + return result + + +@pto.inline_proc +def _tl_sqrt_precision(src, mask, dtype): + if pto.constexpr(dtype == pto.f16): + result = _tl_sqrt_precision_f16(src, mask) + else: + result = _tl_sqrt_precision_f32(src, mask) + return result \ No newline at end of file diff --git a/lib/TileOps/tabs_template.py b/lib/TileOps/tabs_template.py new file mode 100644 index 000000000..6c6802ae5 --- /dev/null +++ b/lib/TileOps/tabs_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tabs""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tabs" +) +def template_tabs(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vabs(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tadd_template.py b/lib/TileOps/tadd_template.py new file mode 100644 index 000000000..8e247fd73 --- /dev/null +++ b/lib/TileOps/tadd_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tadd""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tadd" +) +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return diff --git a/lib/TileOps/tadds_template.py b/lib/TileOps/tadds_template.py new file mode 100644 index 000000000..7c3ddb06c --- /dev/null +++ b/lib/TileOps/tadds_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tadds""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tadds", +) +def template_tadds(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vadds(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tand_template.py b/lib/TileOps/tand_template.py new file mode 100644 index 000000000..6c1477197 --- /dev/null +++ b/lib/TileOps/tand_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tand""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tand" +) +def template_tand(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vand(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tands_template.py b/lib/TileOps/tands_template.py new file mode 100644 index 000000000..91258e793 --- /dev/null +++ b/lib/TileOps/tands_template.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tands + +Note: A5 hardware implements tands as: + TEXPANDS_IMPL(dst, scalar); // broadcast scalar to dst + TAND_IMPL(dst, src, dst); // dst = src & dst + +This template uses vbr + vand to achieve element-wise bitwise AND. +Only supports tile, scalar order. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tands", +) +def template_tands(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vand(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcmp_template.py b/lib/TileOps/tcmp_template.py new file mode 100644 index 000000000..6b6b2e88f --- /dev/null +++ b/lib/TileOps/tcmp_template.py @@ -0,0 +1,130 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tcmp + +Note: A5 hardware implements tcmp as packed comparison between two tiles: + dst = packed_mask(src0 cmp src1) + +Dst is i8 type with same shape as src (packed predicate mask bytes). +Uses vcmp + psts to produce packed predicate mask output. + +Implementation per TCmp.hpp: + - 32B types (f32, i32): uses TCmp_32B path with pdintlv_b8 + PK storage + - 16B types (f16, i16): uses TCmp_8B_16B path with PK storage + - 8B types (i8, u8): uses TCmp_8B_16B path with NORM storage + +Supported comparison modes (via cmp_mode attribute): + eq, ne, lt, gt, ge, le +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcmp", + dtypes=[ + (pto.f32, pto.f32, pto.i8), + (pto.i32, pto.i32, pto.i8), + (pto.f16, pto.f16, pto.i8), + (pto.i16, pto.i16, pto.i8), + (pto.i8, pto.i8, pto.i8), + (pto.ui8, pto.ui8, pto.i8), + ], + advanced=True, +) +def template_tcmp(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """src0 cmp src1 -> packed mask in dst (i8, same shape as src) + + TCmp.hpp structure: + - 32B: TCmp_32B with double iteration + pdintlv_b8 + PK + - 16B: TCmp_8B_16B with PK + - 8B: TCmp_8B_16B with NORM + """ + dtype = src0.element_type + valid_rows, valid_cols = src0.valid_shape + cmp_mode = pto.get_op_attr("cmp_mode", "eq") + + lanes = pto.get_lanes(dtype) + dst_ptr = dst.as_ptr() + dst_stride = dst.shape[1] + + if pto.constexpr(dtype == pto.f32 or dtype == pto.i32): + # 32B path: TCmp_32B implementation per TCmp.hpp + # repeatElm = CCE_VL / sizeof(uint32_t) = 64 + # repeatTimes = CeilDivision(validCol, repeatElm) + 1 + # iterations = repeatTimes // 2 + # Each iteration loads 2*64 elements (4 vlds), uses plt_b32 to split into + # two 32-lane comparisons, then pdintlv_b8 to interleave + repeat_times = (valid_cols + lanes - 1) // lanes + 1 + iterations = repeat_times // 2 + + for row in range(0, valid_rows, 1): + remained = valid_cols + for j in range(0, iterations, 1): + # Load 4 vector registers per TCmp.hpp structure + # src0Reg0, src0Reg1 from j*2*repeatElm offset + # src1Reg0, src1Reg1 from (j*2+1)*repeatElm offset + vec_src0_first = pto.vlds(src0[row, j * lanes * 2:]) + vec_src1_first = pto.vlds(src1[row, j * lanes * 2:]) + vec_src0_second = pto.vlds(src0[row, (j * 2 + 1) * lanes:]) + vec_src1_second = pto.vlds(src1[row, (j * 2 + 1) * lanes:]) + + # Use plt_b32 to create 32-lane masks (POST_UPDATE semantics) + mask_first, remained = pto.make_mask(dtype, remained) + cmp_first = pto.vcmp(vec_src0_first, vec_src1_first, mask_first, cmp_mode) + cmp_first_b8 = pto.pbitcast(cmp_first, pto.mask_b8) + + mask_second, remained = pto.make_mask(dtype, remained) + cmp_second = pto.vcmp(vec_src0_second, vec_src1_second, mask_second, cmp_mode) + cmp_second_b8 = pto.pbitcast(cmp_second, pto.mask_b8) + + # pdintlv_b8 interleave two mask_b8 results + packed_low, packed_high = pto.pdintlv_b8(cmp_first_b8, cmp_second_b8) + + # Store to dst: offset = (row * dstStride + j * 4) in uint32 units + # byte_offset = row * dst_stride + j * 16 + # For i8 dst, dstStride = RowStride / 4 = 16 uint32 units + byte_offset = row * dst_stride + j * 16 + pto.psts(packed_low, dst_ptr, byte_offset, pto.PredicateDist.PK) + + elif pto.constexpr(dtype == pto.f16 or dtype == pto.i16): + # 16B path: TCmp_8B_16B with PK + # vcmp returns mask_b16, cast to mask_b8 for psts PK + iters_per_row = (valid_cols + lanes - 1) // lanes + + for row in range(0, valid_rows, 1): + remained = valid_cols + for j in range(0, iters_per_row, 1): + mask, remained = pto.make_mask(dtype, remained) + vec0 = pto.vlds(src0[row, j * lanes:]) + vec1 = pto.vlds(src1[row, j * lanes:]) + cmp = pto.vcmp(vec0, vec1, mask, cmp_mode) + cmp_b8 = pto.pbitcast(cmp, pto.mask_b8) + byte_offset = row * dst_stride + j * 16 + pto.psts(cmp_b8, dst_ptr, byte_offset, pto.PredicateDist.PK) + + else: + # 8B path: TCmp_8B_16B with NORM + # vcmp returns mask_b8 directly, no cast needed + iters_per_row = (valid_cols + lanes - 1) // lanes + + for row in range(0, valid_rows, 1): + remained = valid_cols + for j in range(0, iters_per_row, 1): + mask, remained = pto.make_mask(dtype, remained) + vec0 = pto.vlds(src0[row, j * lanes:]) + vec1 = pto.vlds(src1[row, j * lanes:]) + cmp = pto.vcmp(vec0, vec1, mask, cmp_mode) + byte_offset = row * dst_stride + j * 32 + pto.psts(cmp, dst_ptr, byte_offset, pto.PredicateDist.NORM) + + return \ No newline at end of file diff --git a/lib/TileOps/tcmps_template.py b/lib/TileOps/tcmps_template.py new file mode 100644 index 000000000..fad499c5c --- /dev/null +++ b/lib/TileOps/tcmps_template.py @@ -0,0 +1,137 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tcmps + +Note: A5 hardware implements tcmps as packed comparison with scalar: + dst = packed_mask(src cmp scalar) + +Uses vcmps + psts to produce packed predicate mask output. +Implementation: + - 32B types (f32, i32): 64 elements per repeat, 32 bytes per iteration (NORM mode). + - 16B types (f16, i16): 128 elements per repeat, 16 bytes per iteration (PK mode). + - 8B types (i8, u8): 256 elements per repeat, 32 bytes per iteration (NORM mode). + +Supported comparison modes (via cmp_mode op attribute): + eq, ne, lt, gt, ge, le +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcmps", + dtypes=[ + (pto.f32, pto.f32, pto.ui8), (pto.i32, pto.i32, pto.ui8), + (pto.f16, pto.f16, pto.ui8), (pto.i16, pto.i16, pto.ui8), + (pto.i8, pto.i8, pto.ui8), (pto.ui8, pto.ui8, pto.ui8), + ], + advanced=True, +) +def template_tcmps(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + """src cmp scalar -> packed mask in dst (ui8) + + - 32B: 1 repeat per iteration, 32 bytes/store (NORM mode, 1 bit per element) + - 16B: 1 repeat per iteration, 16 bytes/store (PK mode) + - 8B: 1 repeat per iteration, 32 bytes/store (NORM mode) + """ + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + cmp_mode = pto.get_op_attr("cmp_mode") + + lanes = pto.get_lanes(dtype) + dst_ptr = dst.as_ptr() + + if pto.constexpr(dtype == pto.f32 or dtype == pto.i32): + # 32B path: 2 vcmps + pbitcast + dintlv_b8 -> psts(PK) + # Use 2D slicing for safety, convert linear offset to (row, col) + bytes_per_iter = 16 + elem_size = 4 + total_elm = valid_rows * valid_cols + repeat_elm = lanes + + # Calculate repeat times matching ISA: CeilDivision + 1 + # But limit iterations to avoid complete out-of-bounds access + repeat_times = (total_elm + repeat_elm - 1) // repeat_elm + 1 + + # Safety: limit iterations to avoid elem_offset beyond total_elm + repeat_elm + # ISA allows one extra repeat for odd elements, but we need to protect DSL slicing + iterations_needed = repeat_times // 2 + + for i in range(0, iterations_needed, 1): + # Convert linear element offsets to (row, col) coordinates + elem_offset0 = i * 2 * repeat_elm + elem_offset1 = (i * 2 + 1) * repeat_elm + + row0 = elem_offset0 // valid_cols + col0 = elem_offset0 % valid_cols + row1 = elem_offset1 // valid_cols + col1 = elem_offset1 % valid_cols + + # Remaining elements for each position (clamp to >= 0) + # When remaining <= 0, make_mask returns all-zero mask (safe) + remaining0 = total_elm - elem_offset0 + if remaining0 < 0: + remaining0 = 0 + remaining1 = total_elm - elem_offset1 + if remaining1 < 0: + remaining1 = 0 + + # Predicate for each compare + mask0, _ = pto.make_mask(dtype, remaining0) + mask1, _ = pto.make_mask(dtype, remaining1) + + # Load using 2D slicing (safer than pointer+offset) + # When row/col exceeds valid_shape, mask ensures no invalid data is used + vec0 = pto.vlds(src[row0, col0:]) + vec1 = pto.vlds(src[row1, col1:]) + + cmp0 = pto.vcmps(vec0, scalar, mask0, cmp_mode) + cmp1 = pto.vcmps(vec1, scalar, mask1, cmp_mode) + + # Convert mask_b32 to mask_b8 and interleave + cmp0_b8 = pto.pbitcast(cmp0, pto.mask_b8) + cmp1_b8 = pto.pbitcast(cmp1, pto.mask_b8) + cmp_interleaved, _ = pto.pdintlv_b8(cmp0_b8, cmp1_b8) + + # Linear byte offset for output + byte_offset = i * bytes_per_iter + pto.psts(cmp_interleaved, dst_ptr, byte_offset, pto.PredicateDist.PK) + elif pto.constexpr(dtype == pto.f16 or dtype == pto.i16): + # 16B path: 128 elements per repeat, 16 bytes per iteration (PK mode). + # Each vcmps produces 128 bits; PK mode packs them into 16 bytes, + # achieving 1 bit per element. + bytes_per_iter = 16 + iters_per_row = (valid_cols + lanes - 1) // lanes + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + cmp = pto.vcmps(vec, scalar, mask, cmp_mode) + byte_offset = (row * iters_per_row + col // lanes) * bytes_per_iter + pto.psts(cmp, dst_ptr, byte_offset, pto.PredicateDist.PK) + else: + # 8B path: 256 elements per repeat, 32 bytes packed per iteration + bytes_per_iter = 32 + iters_per_row = (valid_cols + lanes - 1) // lanes + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + cmp = pto.vcmps(vec, scalar, mask, cmp_mode) + byte_offset = (row * iters_per_row + col // lanes) * bytes_per_iter + pto.psts(cmp, dst_ptr, byte_offset, pto.PredicateDist.NORM) + + return diff --git a/lib/TileOps/tcolargmax_template.py b/lib/TileOps/tcolargmax_template.py new file mode 100644 index 000000000..e6781be27 --- /dev/null +++ b/lib/TileOps/tcolargmax_template.py @@ -0,0 +1,213 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _validate_tcolargmax( + src_shape=(), + src_valid_shape=(), + tmp_shape=(), + tmp_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + tmp_config=None, + dst_config=None, + src_dtype=None, + tmp_dtype=None, + dst_dtype=None, +): + if src_config is None or tmp_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if tmp_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if tmp_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + if src_dtype != tmp_dtype: + return False + return True + + +@pto.vkernel( + target="a5", + op="pto.tcolargmax", + dtypes=[ + (pto.ui8, pto.ui8, pto.i32), + (pto.i8, pto.i8, pto.i32), + ], + constraints=[_validate_tcolargmax], + advanced=True, +) +def template_tcolargmax_i8_to_i32(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_valid_rows, src_valid_cols = src.valid_shape + src_dtype = src.element_type + lanes_i8 = pto.get_lanes(src_dtype) + lanes_i32 = pto.get_lanes(pto.i32) + + if pto.constexpr(src_dtype == pto.ui8): + intermediate_dtype = pto.ui16 + final_cvt_dtype = pto.ui32 + else: + intermediate_dtype = pto.i16 + final_cvt_dtype = pto.i32 + + lanes_intermediate = pto.get_lanes(intermediate_dtype) + + with pto.vecscope(): + all_mask_b8 = pto.make_mask(src_dtype, pto.PAT.ALL) + all_mask_intermediate = pto.make_mask(intermediate_dtype, pto.PAT.ALL) + + for col in range(0, src_valid_cols, lanes_i8): + remained = src_valid_cols - col + mask_i32_0, remained = pto.make_mask(pto.i32, remained) + mask_i32_1, remained = pto.make_mask(pto.i32, remained) + mask_i32_2, remained = pto.make_mask(pto.i32, remained) + mask_i32_3, _ = pto.make_mask(pto.i32, remained) + + index_old_even = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + index_old_odd = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + index_new_even = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + index_new_odd = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + + vreg_old = pto.vlds(src[0, col:]) + vreg_old_even = pto.vcvt(vreg_old, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.EVEN) + vreg_old_odd = pto.vcvt(vreg_old, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.ODD) + + for row in range(1, src_valid_rows, 1): + index_new_even = pto.vadds(index_new_even, intermediate_dtype(1), all_mask_intermediate) + index_new_odd = pto.vadds(index_new_odd, intermediate_dtype(1), all_mask_intermediate) + vreg_new = pto.vlds(src[row, col:]) + vreg_new_even = pto.vcvt(vreg_new, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.EVEN) + vreg_new_odd = pto.vcvt(vreg_new, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.ODD) + + select_even = pto.vcmp(vreg_new_even, vreg_old_even, all_mask_intermediate, "gt") + select_odd = pto.vcmp(vreg_new_odd, vreg_old_odd, all_mask_intermediate, "gt") + + index_old_even = pto.vsel(index_new_even, index_old_even, select_even) + index_old_odd = pto.vsel(index_new_odd, index_old_odd, select_odd) + + vreg_old_even = pto.vmax(vreg_old_even, vreg_new_even, all_mask_intermediate) + vreg_old_odd = pto.vmax(vreg_old_odd, vreg_new_odd, all_mask_intermediate) + + index_output_0, index_output_1 = pto.vintlv(index_old_even, index_old_odd) + output_even = pto.vcvt(index_output_0, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.EVEN) + output_odd = pto.vcvt(index_output_0, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.ODD) + output_0, output_1 = pto.vintlv(output_even, output_odd) + + output_0 = pto.vbitcast(output_0, pto.i32) + output_1 = pto.vbitcast(output_1, pto.i32) + + pto.vsts(output_0, dst[0, col:], mask_i32_0) + pto.vsts(output_1, dst[0, col + lanes_i32:], mask_i32_1) + + output_even = pto.vcvt(index_output_1, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.EVEN) + output_odd = pto.vcvt(index_output_1, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.ODD) + output_0, output_1 = pto.vintlv(output_even, output_odd) + + output_0 = pto.vbitcast(output_0, pto.i32) + output_1 = pto.vbitcast(output_1, pto.i32) + + pto.vsts(output_0, dst[0, col + 2 * lanes_i32:], mask_i32_2) + pto.vsts(output_1, dst[0, col + 3 * lanes_i32:], mask_i32_3) + + return + + +@pto.vkernel( + target="a5", + op="pto.tcolargmax", + dtypes=[ + (pto.f16, pto.f16, pto.i32), + (pto.ui16, pto.ui16, pto.i32), + ], + constraints=[_validate_tcolargmax], + advanced=True, +) +def template_tcolargmax_f16_to_i32(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_valid_rows, src_valid_cols = src.valid_shape + lanes_f16 = pto.get_lanes(pto.f16) + lanes_i32 = pto.get_lanes(pto.i32) + + with pto.vecscope(): + all_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + for col in range(0, src_valid_cols, lanes_f16): + remained = src_valid_cols - col + + mask_f16, _ = pto.make_mask(pto.f16, remained) + + mask_i32_0, remained = pto.make_mask(pto.i32, remained) + mask_i32_1, _ = pto.make_mask(pto.i32, remained) + + index_old = pto.vdup(pto.i16(0), mask_f16) + index_new = pto.vdup(pto.i16(0), mask_f16) + max_vals = pto.vlds(src[0, col:]) + + for row in range(1, src_valid_rows, 1): + index_new = pto.vadds(index_new, pto.i16(1), mask_f16) + new_vals = pto.vlds(src[row, col:]) + gt_mask = pto.vcmp(new_vals, max_vals, all_mask, "gt") + index_old = pto.vsel(index_new, index_old, gt_mask) + max_vals = pto.vmax(max_vals, new_vals, mask_f16) + + index_even = pto.vcvt(index_old, pto.i32, all_mask, part=pto.VcvtPartMode.EVEN) + index_odd = pto.vcvt(index_old, pto.i32, all_mask, part=pto.VcvtPartMode.ODD) + index_lo, index_hi = pto.vintlv(index_even, index_odd) + + pto.vsts(index_lo, dst[0, col:], mask_i32_0) + pto.vsts(index_hi, dst[0, col + lanes_i32:], mask_i32_1) + + return + + +@pto.vkernel( + target="a5", + op="pto.tcolargmax", + dtypes=[ + (pto.f32, pto.f32, pto.i32), + (pto.ui32, pto.ui32, pto.i32), + ], + constraints=[_validate_tcolargmax], + advanced=True, +) +def template_tcolargmax_f32_to_i32(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_valid_rows, src_valid_cols = src.valid_shape + lanes = pto.get_lanes(pto.f32) + + with pto.vecscope(): + remained = src_valid_cols + for col in range(0, src_valid_cols, lanes): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + mask, remained = pto.make_mask(pto.f32, remained) + + index_old = pto.vdup(pto.i32(0), mask) + index_new = pto.vdup(pto.i32(0), mask) + max_vals = pto.vlds(src[0, col:]) + + for row in range(1, src_valid_rows, 1): + index_new = pto.vadds(index_new, pto.i32(1), mask) + new_vals = pto.vlds(src[row, col:]) + gt_mask = pto.vcmp(new_vals, max_vals, all_mask, "gt") + index_old = pto.vsel(index_new, index_old, gt_mask) + max_vals = pto.vmax(max_vals, new_vals, mask) + + pto.vsts(index_old, dst[0, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tcolargmin_template.py b/lib/TileOps/tcolargmin_template.py new file mode 100644 index 000000000..d5f6808ab --- /dev/null +++ b/lib/TileOps/tcolargmin_template.py @@ -0,0 +1,213 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _validate_tcolargmin( + src_shape=(), + src_valid_shape=(), + tmp_shape=(), + tmp_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + tmp_config=None, + dst_config=None, + src_dtype=None, + tmp_dtype=None, + dst_dtype=None, +): + if src_config is None or tmp_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if tmp_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if tmp_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + if src_dtype != tmp_dtype: + return False + return True + + +@pto.vkernel( + target="a5", + op="pto.tcolargmin", + dtypes=[ + (pto.ui8, pto.ui8, pto.i32), + (pto.i8, pto.i8, pto.i32), + ], + constraints=[_validate_tcolargmin], + advanced=True, +) +def template_tcolargmin_i8_to_i32(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_valid_rows, src_valid_cols = src.valid_shape + src_dtype = src.element_type + lanes_i8 = pto.get_lanes(src_dtype) + lanes_i32 = pto.get_lanes(pto.i32) + + if pto.constexpr(src_dtype == pto.ui8): + intermediate_dtype = pto.ui16 + final_cvt_dtype = pto.ui32 + else: + intermediate_dtype = pto.i16 + final_cvt_dtype = pto.i32 + + lanes_intermediate = pto.get_lanes(intermediate_dtype) + + with pto.vecscope(): + all_mask_b8 = pto.make_mask(src_dtype, pto.PAT.ALL) + all_mask_intermediate = pto.make_mask(intermediate_dtype, pto.PAT.ALL) + + for col in range(0, src_valid_cols, lanes_i8): + remained = src_valid_cols - col + mask_i32_0, remained = pto.make_mask(pto.i32, remained) + mask_i32_1, remained = pto.make_mask(pto.i32, remained) + mask_i32_2, remained = pto.make_mask(pto.i32, remained) + mask_i32_3, _ = pto.make_mask(pto.i32, remained) + + index_old_even = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + index_old_odd = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + index_new_even = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + index_new_odd = pto.vdup(intermediate_dtype(0), all_mask_intermediate) + + vreg_old = pto.vlds(src[0, col:]) + vreg_old_even = pto.vcvt(vreg_old, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.EVEN) + vreg_old_odd = pto.vcvt(vreg_old, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.ODD) + + for row in range(1, src_valid_rows, 1): + index_new_even = pto.vadds(index_new_even, intermediate_dtype(1), all_mask_intermediate) + index_new_odd = pto.vadds(index_new_odd, intermediate_dtype(1), all_mask_intermediate) + vreg_new = pto.vlds(src[row, col:]) + vreg_new_even = pto.vcvt(vreg_new, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.EVEN) + vreg_new_odd = pto.vcvt(vreg_new, intermediate_dtype, all_mask_b8, part=pto.VcvtPartMode.ODD) + + select_even = pto.vcmp(vreg_new_even, vreg_old_even, all_mask_intermediate, "lt") + select_odd = pto.vcmp(vreg_new_odd, vreg_old_odd, all_mask_intermediate, "lt") + + index_old_even = pto.vsel(index_new_even, index_old_even, select_even) + index_old_odd = pto.vsel(index_new_odd, index_old_odd, select_odd) + + vreg_old_even = pto.vmin(vreg_old_even, vreg_new_even, all_mask_intermediate) + vreg_old_odd = pto.vmin(vreg_old_odd, vreg_new_odd, all_mask_intermediate) + + index_output_0, index_output_1 = pto.vintlv(index_old_even, index_old_odd) + output_even = pto.vcvt(index_output_0, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.EVEN) + output_odd = pto.vcvt(index_output_0, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.ODD) + output_0, output_1 = pto.vintlv(output_even, output_odd) + + output_0 = pto.vbitcast(output_0, pto.i32) + output_1 = pto.vbitcast(output_1, pto.i32) + + pto.vsts(output_0, dst[0, col:], mask_i32_0) + pto.vsts(output_1, dst[0, col + lanes_i32:], mask_i32_1) + + output_even = pto.vcvt(index_output_1, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.EVEN) + output_odd = pto.vcvt(index_output_1, final_cvt_dtype, all_mask_intermediate, part=pto.VcvtPartMode.ODD) + output_0, output_1 = pto.vintlv(output_even, output_odd) + + output_0 = pto.vbitcast(output_0, pto.i32) + output_1 = pto.vbitcast(output_1, pto.i32) + + pto.vsts(output_0, dst[0, col + 2 * lanes_i32:], mask_i32_2) + pto.vsts(output_1, dst[0, col + 3 * lanes_i32:], mask_i32_3) + + return + + +@pto.vkernel( + target="a5", + op="pto.tcolargmin", + dtypes=[ + (pto.f16, pto.f16, pto.i32), + (pto.ui16, pto.ui16, pto.i32), + ], + constraints=[_validate_tcolargmin], + advanced=True, +) +def template_tcolargmin_f16_to_i32(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_valid_rows, src_valid_cols = src.valid_shape + lanes_f16 = pto.get_lanes(pto.f16) + lanes_i32 = pto.get_lanes(pto.i32) + + with pto.vecscope(): + all_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + for col in range(0, src_valid_cols, lanes_f16): + remained = src_valid_cols - col + + mask_f16, _ = pto.make_mask(pto.f16, remained) + + mask_i32_0, remained = pto.make_mask(pto.i32, remained) + mask_i32_1, _ = pto.make_mask(pto.i32, remained) + + index_old = pto.vdup(pto.i16(0), mask_f16) + index_new = pto.vdup(pto.i16(0), mask_f16) + min_vals = pto.vlds(src[0, col:]) + + for row in range(1, src_valid_rows, 1): + index_new = pto.vadds(index_new, pto.i16(1), mask_f16) + new_vals = pto.vlds(src[row, col:]) + lt_mask = pto.vcmp(new_vals, min_vals, all_mask, "lt") + index_old = pto.vsel(index_new, index_old, lt_mask) + min_vals = pto.vmin(min_vals, new_vals, mask_f16) + + index_even = pto.vcvt(index_old, pto.i32, all_mask, part=pto.VcvtPartMode.EVEN) + index_odd = pto.vcvt(index_old, pto.i32, all_mask, part=pto.VcvtPartMode.ODD) + index_lo, index_hi = pto.vintlv(index_even, index_odd) + + pto.vsts(index_lo, dst[0, col:], mask_i32_0) + pto.vsts(index_hi, dst[0, col + lanes_i32:], mask_i32_1) + + return + + +@pto.vkernel( + target="a5", + op="pto.tcolargmin", + dtypes=[ + (pto.f32, pto.f32, pto.i32), + (pto.ui32, pto.ui32, pto.i32), + ], + constraints=[_validate_tcolargmin], + advanced=True, +) +def template_tcolargmin_f32_to_i32(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_valid_rows, src_valid_cols = src.valid_shape + lanes = pto.get_lanes(pto.f32) + + with pto.vecscope(): + remained = src_valid_cols + for col in range(0, src_valid_cols, lanes): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + mask, remained = pto.make_mask(pto.f32, remained) + + index_old = pto.vdup(pto.i32(0), mask) + index_new = pto.vdup(pto.i32(0), mask) + min_vals = pto.vlds(src[0, col:]) + + for row in range(1, src_valid_rows, 1): + index_new = pto.vadds(index_new, pto.i32(1), mask) + new_vals = pto.vlds(src[row, col:]) + lt_mask = pto.vcmp(new_vals, min_vals, all_mask, "lt") + index_old = pto.vsel(index_new, index_old, lt_mask) + min_vals = pto.vmin(min_vals, new_vals, mask) + + pto.vsts(index_old, dst[0, col:], mask) + + return diff --git a/lib/TileOps/tcolexpand_template.py b/lib/TileOps/tcolexpand_template.py new file mode 100644 index 000000000..5d20fcbe1 --- /dev/null +++ b/lib/TileOps/tcolexpand_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpand""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpand" +) +def template_tcolexpand(src0: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[0, col:]) + pto.vsts(lhs, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandadd_template.py b/lib/TileOps/tcolexpandadd_template.py new file mode 100644 index 000000000..287be93dc --- /dev/null +++ b/lib/TileOps/tcolexpandadd_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandadd""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandadd" +) +def template_tcolexpandadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpanddiv_template.py b/lib/TileOps/tcolexpanddiv_template.py new file mode 100644 index 000000000..b08a74044 --- /dev/null +++ b/lib/TileOps/tcolexpanddiv_template.py @@ -0,0 +1,54 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tcolexpanddiv with IEEE 754 high-precision support + +Divide each column of src0 by a per-column scalar from src1[0, col]. +Semantics: dst[row, col] = src0[row, col] / src1[0, col] +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +# Import shared high-precision division algorithms +from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl + + +@pto.vkernel( + target="a5", + op="pto.tcolexpanddiv" +) +def template_tcolexpanddiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.tcolexpanddiv with optional high-precision mode.""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + if pto.constexpr(dtype == pto.f32): + result = _div_ieee754_f32_impl(lhs, rhs, mask) + else: # dtype == pto.f16 (guaranteed by MLIR validation) + result = _div_ieee754_f16_impl(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vdiv(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tcolexpandexpdif_template.py b/lib/TileOps/tcolexpandexpdif_template.py new file mode 100644 index 000000000..0ae28ee0c --- /dev/null +++ b/lib/TileOps/tcolexpandexpdif_template.py @@ -0,0 +1,57 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandexpdif""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandexpdif", + dtypes=[ + (pto.f16, pto.f16, pto.f16), + ], +) +def template_tcolexpandexpdif_f16(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + diff = pto.vsub(lhs, rhs, mask) + result = pto.vexp(diff, mask) + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandexpdif", + dtypes=[ + (pto.f32, pto.f32, pto.f32), + ], +) +def template_tcolexpandexpdif_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vexpdif(lhs, rhs, mask, pto.VcvtPartMode.ODD) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandmax_template.py b/lib/TileOps/tcolexpandmax_template.py new file mode 100644 index 000000000..79f3699b7 --- /dev/null +++ b/lib/TileOps/tcolexpandmax_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandmax" +) +def template_tcolexpandmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vmax(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandmin_template.py b/lib/TileOps/tcolexpandmin_template.py new file mode 100644 index 000000000..054b35dfc --- /dev/null +++ b/lib/TileOps/tcolexpandmin_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandmin" +) +def template_tcolexpandmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vmin(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandmul_template.py b/lib/TileOps/tcolexpandmul_template.py new file mode 100644 index 000000000..5dcacfa91 --- /dev/null +++ b/lib/TileOps/tcolexpandmul_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandmul""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandmul" +) +def template_tcolexpandmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vmul(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolexpandsub_template.py b/lib/TileOps/tcolexpandsub_template.py new file mode 100644 index 000000000..f46bbaf11 --- /dev/null +++ b/lib/TileOps/tcolexpandsub_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLang DSL template for pto.tcolexpandsub""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tcolexpandsub" +) +def template_tcolexpandsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[0, col:]) + result = pto.vsub(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tcolmax_template.py b/lib/TileOps/tcolmax_template.py new file mode 100644 index 000000000..f08df7eda --- /dev/null +++ b/lib/TileOps/tcolmax_template.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolmax( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +@pto.vkernel( + target="a5", + op="pto.tcolmax", + constraints=[_validate_tcolmax] +) +def template_tcolmax(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vmax(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/lib/TileOps/tcolmin_template.py b/lib/TileOps/tcolmin_template.py new file mode 100644 index 000000000..2a36dcdd5 --- /dev/null +++ b/lib/TileOps/tcolmin_template.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolmin( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +@pto.vkernel( + target="a5", + op="pto.tcolmin", + constraints=[_validate_tcolmin] +) +def template_tcolmin(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vmin(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/lib/TileOps/tcolprod_template.py b/lib/TileOps/tcolprod_template.py new file mode 100644 index 000000000..4ebb99f48 --- /dev/null +++ b/lib/TileOps/tcolprod_template.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolprod( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +@pto.vkernel( + target="a5", + op="pto.tcolprod", + constraints=[_validate_tcolprod] +) +def template_tcolprod(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vmul(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/lib/TileOps/tcolsum_template.py b/lib/TileOps/tcolsum_template.py new file mode 100644 index 000000000..b187b45a3 --- /dev/null +++ b/lib/TileOps/tcolsum_template.py @@ -0,0 +1,57 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys +from pathlib import Path +import tilelang_dsl as pto + +def _validate_tcolsum( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + +# Todo: This is the basic implementation. Later the binary colsum algorithm should be implemented also. +@pto.vkernel( + target="a5", + op="pto.tcolsum", + constraints=[_validate_tcolsum] +) +def template_tcolsum(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vadd(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) + + return diff --git a/lib/TileOps/tcvt_template.py b/lib/TileOps/tcvt_template.py new file mode 100644 index 000000000..45acd23ff --- /dev/null +++ b/lib/TileOps/tcvt_template.py @@ -0,0 +1,1112 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tcvt.""" + +import tilelang_dsl as pto + + +def _supports_basic_rowwise_tcvt( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None, +): + if tuple(src_shape) != tuple(dst_shape): + return False + if tuple(src_valid_shape) != tuple(dst_valid_shape): + return False + if len(src_shape) != 2 or len(dst_shape) != 2: + return False + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + return True + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.f16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f32_to_f16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + store_mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.f16, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.i32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f32_to_i32(src: pto.Tile, dst: pto.Tile): + dst_dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dst_dtype)): + mask, remained = pto.make_mask(dst_dtype, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + dst_dtype, + mask, + rnd=rnd, + sat=pto.VcvtSatMode.SAT, + ) + pto.vsts(converted, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i32, pto.f32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i32_to_f32(src: pto.Tile, dst: pto.Tile): + dst_dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dst_dtype)): + mask, remained = pto.make_mask(dst_dtype, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + dst_dtype, + mask, + rnd=rnd, + ) + pto.vsts(converted, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f16, pto.f32), + (pto.bf16, pto.f32), + (pto.i16, pto.f32), + (pto.i16, pto.i32), + (pto.i16, pto.ui32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_16_to_32(src: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + dst_dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(src_dtype, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dst_dtype)): + store_mask, remained = pto.make_mask(dst_dtype, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B16) + converted = pto.vcvt( + vec, + dst_dtype, + full_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i32, pto.i64), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i32_to_i64(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols * 2 # i64 requires double the mask + for col in range(0, valid_cols, pto.get_lanes(pto.i64)): + store_mask, remained = pto.make_mask(pto.i64, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B32) + converted = pto.vcvt( + vec, + pto.i64, + full_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.NORM_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.ui8, pto.f16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_ui8_to_f16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f16)): + store_mask, remained = pto.make_mask(pto.f16, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B8) + converted = pto.vcvt( + vec, + pto.f16, + full_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.ui8, pto.ui16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_ui8_to_ui16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.ui16)): + store_mask, remained = pto.make_mask(pto.ui16, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B8) + converted = pto.vcvt( + vec, + pto.ui16, + full_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.si8, pto.f16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_si8_to_f16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.si8, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f16)): + store_mask, remained = pto.make_mask(pto.f16, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B8) + converted = pto.vcvt( + vec, + pto.f16, + full_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.si8, pto.si16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_si8_to_si16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.si8, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.si16)): + store_mask, remained = pto.make_mask(pto.si16, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B8) + converted = pto.vcvt( + vec, + pto.si16, + full_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.NORM_B16) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.si8, pto.i32), + ], + constraints=[_supports_basic_rowwise_tcvt], + advanced=True, +) +def template_tcvt_si8_to_i32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + b8_mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + v_zero = pto.vdup(pto.ui8(0), b8_mask) + lanes_i32 = pto.get_lanes(pto.i32) + for row in range(0, valid_rows, 1): + remained = valid_cols + next_remained = 0 + if valid_cols > lanes_i32: + next_remained = valid_cols - lanes_i32 + for col in range(0, valid_cols, pto.get_lanes(pto.i16)): + mask_b16_cur, remained = pto.make_mask(pto.i16, remained) + mask_b16_next, next_remained = pto.make_mask(pto.i16, next_remained) + mask_b32_cur = pto.punpack(mask_b16_cur, pto.PredicatePart.LOWER) + mask_b32_next = pto.punpack(mask_b16_next, pto.PredicatePart.LOWER) + vec_si8_0 = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B8) + vec_ui8_0 = pto.vbitcast(vec_si8_0, pto.ui8) + vec_ui8_1, vec_ui8_2 = pto.vintlv(vec_ui8_0, v_zero) + vec_si8_1 = pto.vbitcast(vec_ui8_1, pto.si8) + vec_si8_2 = pto.vbitcast(vec_ui8_2, pto.si8) + output_0 = pto.vcvt(vec_si8_1, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + output_1 = pto.vcvt(vec_si8_2, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + pto.vsts(output_0, dst[row, col:], mask_b32_cur, dist=pto.VStoreDist.NORM_B32) + pto.vsts(output_1, dst[row, col + lanes_i32:], mask_b32_next, dist=pto.VStoreDist.NORM_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.f32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f32_to_f32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vtrc(vec, mask, rnd=rnd) + pto.vsts(converted, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f16, pto.i32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f16_to_i32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i32)): + store_mask, remained = pto.make_mask(pto.i32, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B16) + converted = pto.vcvt( + vec, + pto.i32, + full_mask, + rnd=rnd, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i16, pto.f16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i16_to_f16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.i16, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i16)): + store_mask, remained = pto.make_mask(pto.f16, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.f16, + full_mask, + rnd=rnd, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i64, pto.f32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i64_to_f32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + for row in range(0, valid_rows, 1): + remained = valid_cols * 2 # i64 requires double the mask + full_mask, _ = pto.make_mask(pto.i64, remained) + for col in range(0, valid_cols, pto.get_lanes(pto.i64)): + store_mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.f32, + full_mask, + rnd=rnd, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B64) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i16, pto.ui8), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i16_to_ui8(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.i16, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i16)): + store_mask, remained = pto.make_mask(pto.i16, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui8, + full_mask, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B16) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i32, pto.i16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i32_to_i16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i32)): + store_mask, remained = pto.make_mask(pto.i32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.i16, + full_mask, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i32, pto.ui16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i32_to_ui16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i32)): + store_mask, remained = pto.make_mask(pto.i32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui16, + full_mask, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i32, pto.ui8), + ], + constraints=[_supports_basic_rowwise_tcvt], + advanced=True, +) +def template_tcvt_i32_to_ui8(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + idx_mask_b8 = pto.pset_b8(pto.PAT.ALL) + idx_mask_b16 = pto.pbitcast(idx_mask_b8, pto.mask_b16) + lanes_i32 = pto.get_lanes(pto.i32) + v_idx = pto.vci(pto.i8(0), pto.OrderMode.ASC) + v_idx_i16 = pto.vbitcast(v_idx, pto.i16) + v_idx_i16 = pto.vmuls(v_idx_i16, pto.i16(4), idx_mask_b16) + v_idx_ui8 = pto.vbitcast(v_idx_i16, pto.ui8) + for row in range(0, valid_rows, 1): + mask_len_tail = valid_cols % lanes_i32 + if valid_cols % lanes_i32 == 0: + mask_len_tail = lanes_i32 + for col in range(0, valid_cols, lanes_i32): + mask_len = lanes_i32 + if valid_cols < lanes_i32 or col == valid_cols - lanes_i32: + mask_len = mask_len_tail + store_mask, _ = pto.make_mask(pto.ui8, mask_len) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui8, + full_mask, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.P0, + ) + result = pto.vselr(converted, v_idx_ui8) + pto.mem_bar(pto.BarrierType.VST_VST) + pto.vsts(result, dst[row, col:], store_mask, dist=pto.VStoreDist.NORM_B8) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.ui32, pto.i16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_ui32_to_i16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.ui32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.ui32)): + store_mask, remained = pto.make_mask(pto.ui32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.i16, + full_mask, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.ui32, pto.ui16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_ui32_to_ui16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.ui32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.ui32)): + store_mask, remained = pto.make_mask(pto.ui32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui16, + full_mask, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.ui32, pto.ui8), + ], + constraints=[_supports_basic_rowwise_tcvt], + advanced=True, +) +def template_tcvt_ui32_to_ui8(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.ui32, pto.PAT.ALL) + idx_mask_b8 = pto.pset_b8(pto.PAT.ALL) + idx_mask_b16 = pto.pbitcast(idx_mask_b8, pto.mask_b16) + lanes_ui32 = pto.get_lanes(pto.ui32) + v_idx = pto.vci(pto.i8(0), pto.OrderMode.ASC) + v_idx_i16 = pto.vbitcast(v_idx, pto.i16) + v_idx_i16 = pto.vmuls(v_idx_i16, pto.i16(4), idx_mask_b16) + v_idx_ui8 = pto.vbitcast(v_idx_i16, pto.ui8) + for row in range(0, valid_rows, 1): + mask_len_tail = valid_cols % lanes_ui32 + if valid_cols % lanes_ui32 == 0: + mask_len_tail = lanes_ui32 + for col in range(0, valid_cols, lanes_ui32): + mask_len = lanes_ui32 + if valid_cols < lanes_ui32 or col == valid_cols - lanes_ui32: + mask_len = mask_len_tail + store_mask, _ = pto.make_mask(pto.ui8, mask_len) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui8, + full_mask, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.P0, + ) + result = pto.vselr(converted, v_idx_ui8) + pto.mem_bar(pto.BarrierType.VST_VST) + pto.vsts(result, dst[row, col:], store_mask, dist=pto.VStoreDist.NORM_B8) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.i64, pto.i32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_i64_to_i32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + for row in range(0, valid_rows, 1): + remained = valid_cols * 2 # i64 requires double the mask + full_mask, _ = pto.make_mask(pto.i64, remained) + for col in range(0, valid_cols, pto.get_lanes(pto.i64)): + store_mask, remained = pto.make_mask(pto.i32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.i32, + full_mask, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B64) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.bf16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f32_to_bf16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + store_mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.bf16, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.i64), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f32_to_i64(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols * 2 # i64 requires double the mask + for col in range(0, valid_cols, pto.get_lanes(pto.i64)): + store_mask, remained = pto.make_mask(pto.i64, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B32) + converted = pto.vcvt( + vec, + pto.i64, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.NORM_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f16, pto.ui8), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f16_to_ui8(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f16)): + store_mask, remained = pto.make_mask(pto.f16, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui8, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B16) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.bf16, pto.i32), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_bf16_to_i32(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i32)): + store_mask, remained = pto.make_mask(pto.i32, remained) + vec = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B16) + converted = pto.vcvt( + vec, + pto.i32, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.bf16, pto.f16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_bf16_to_f16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.bf16)): + store_mask, remained = pto.make_mask(pto.f16, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.f16, + full_mask, + sat=pto.VcvtSatMode.SAT, + rnd=rnd, + ) + pto.vsts(converted, dst[row, col:], store_mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f32, pto.i16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f32_to_i16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + store_mask, remained = pto.make_mask(pto.f32, remained) + vec_f32 = pto.vlds(src[row, col:]) + # sat=OFF NonSatTorch + vec_i32 = pto.vcvt( + vec_f32, + pto.i32, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.NOSAT, + ) + vec_i16 = pto.vcvt( + vec_i32, + pto.i16, + full_mask, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(vec_i16, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f16, pto.i16), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f16_to_i16(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + full_mask_b16 = pto.make_mask(pto.f16, pto.PAT.ALL) + full_mask_b32 = pto.make_mask(pto.i32, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + store_mask, remained = pto.make_mask(pto.i32, remained) + vec_f16 = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B16) + # sat=OFF NonSatTorch + vec_i32 = pto.vcvt( + vec_f16, + pto.i32, + full_mask_b16, + rnd=rnd, + part=pto.VcvtPartMode.EVEN, + ) + vec_i16 = pto.vcvt( + vec_i32, + pto.i16, + full_mask_b32, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(vec_i16, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B32) + return + + +@pto.vkernel( + target="a5", + op="pto.tcvt", + dtypes=[ + (pto.f16, pto.si8), + ], + constraints=[_supports_basic_rowwise_tcvt], +) +def template_tcvt_f16_to_si8(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + round_mode = pto.get_op_attr("round_mode", "RINT") + rnd = pto.VcvtRoundMode.R + if pto.constexpr(round_mode == "ROUND"): + rnd = pto.VcvtRoundMode.A + elif pto.constexpr(round_mode == "FLOOR"): + rnd = pto.VcvtRoundMode.F + elif pto.constexpr(round_mode == "CEIL"): + rnd = pto.VcvtRoundMode.C + elif pto.constexpr(round_mode == "TRUNC"): + rnd = pto.VcvtRoundMode.Z + elif pto.constexpr(round_mode == "ODD"): + rnd = pto.VcvtRoundMode.O + + lanes_f16 = pto.get_lanes(pto.f16) + pg = pto.make_mask(pto.f16, pto.PAT.ALL) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes_f16): + full_mask, _ = pto.make_mask(pto.f16, lanes_f16) + store_mask, remained = pto.make_mask(pto.f16, remained) + vec_f16 = pto.vlds(src[row, col:]) + # sat=OFF NonSatTorch + vec_i16 = pto.vcvt( + vec_f16, + pto.i16, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.NOSAT, + ) + v_mask = pto.vdup(pto.i16(255), pg) + vec_i16_and = pto.vand(vec_i16, v_mask, store_mask) + vec_f16_temp = pto.vcvt( + vec_i16_and, + pto.f16, + full_mask, + rnd=rnd, + ) + vec_si8 = pto.vcvt( + vec_f16_temp, + pto.si8, + full_mask, + rnd=rnd, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(vec_si8, dst[row, col:], store_mask, dist=pto.VStoreDist.PK_B16) + return \ No newline at end of file diff --git a/lib/TileOps/tdiv_template.py b/lib/TileOps/tdiv_template.py new file mode 100644 index 000000000..2b841e13b --- /dev/null +++ b/lib/TileOps/tdiv_template.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tdiv with IEEE 754 high-precision support""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +# Import shared high-precision division algorithms +from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl + + +@pto.vkernel( + target="a5", + op="pto.tdiv" +) +def template_tdiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Element-wise division with optional high-precision mode""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + if pto.constexpr(dtype == pto.f32): + divided = _div_ieee754_f32_impl(lhs, rhs, mask) + else: # dtype == pto.f16 (guaranteed by MLIR validation) + divided = _div_ieee754_f16_impl(lhs, rhs, mask) + pto.vsts(divided, dst[row, col:], mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + divided = pto.vdiv(lhs, rhs, mask) + pto.vsts(divided, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tdivs_template.py b/lib/TileOps/tdivs_template.py new file mode 100644 index 000000000..fa35a9856 --- /dev/null +++ b/lib/TileOps/tdivs_template.py @@ -0,0 +1,92 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tdivs with IEEE 754 high-precision support + +Supports two operand orders: + 1. tdivs(src_tile, scalar, dst) -> src / scalar + 2. tdivs(scalar, src_tile, dst) -> scalar / src + +High-precision mode uses IEEE 754 compliant division algorithms from div_hp module +for improved accuracy with precision-sensitive, subnormal, and overflow boundary cases. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +# Import shared high-precision division algorithms +from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl + + +@pto.vkernel( + target="a5", + op="pto.tdivs", +) +def template_tdivs_tile_scalar(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + """src / scalar with optional high-precision mode""" + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + if pto.constexpr(dtype == pto.f32): + result = _div_ieee754_f32_impl(vec, scalar_vec, mask) + else: # dtype == pto.f16 (guaranteed by MLIR validation) + result = _div_ieee754_f16_impl(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vdiv(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.tdivs", +) +def template_tdivs_scalar_tile(scalar: pto.AnyType, src: pto.Tile, dst: pto.Tile): + """scalar / src with optional high-precision mode""" + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + if pto.constexpr(dtype == pto.f32): + result = _div_ieee754_f32_impl(scalar_vec, vec, mask) + else: # dtype == pto.f16 (guaranteed by MLIR validation) + result = _div_ieee754_f16_impl(scalar_vec, vec, mask) + pto.vsts(result, dst[row, col:], mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vdiv(scalar_vec, vec, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/texp_template.py b/lib/TileOps/texp_template.py new file mode 100644 index 000000000..609cb2fad --- /dev/null +++ b/lib/TileOps/texp_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.texp""" + +import tilelang_dsl as pto +from exp_hp import _tl_exp_precision + +@pto.inline_proc +def template_texp_hp_impl(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = _tl_exp_precision(vinput, mask, dtype) + pto.vsts(result, dst[row, col:], mask) + return + +@pto.inline_proc +def template_texp_impl(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vexp(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return + +@pto.vkernel( + target="a5", + op="pto.texp" +) +def template_texp(src: pto.Tile, dst: pto.Tile): + hp_mode = pto.get_op_attr("precision_mode") + if pto.constexpr(hp_mode == "HIGH_PRECISION"): + template_texp_hp_impl(src, dst) + else: + template_texp_impl(src, dst) + return diff --git a/lib/TileOps/texpand_template.py b/lib/TileOps/texpand_template.py new file mode 100644 index 000000000..d2e07360f --- /dev/null +++ b/lib/TileOps/texpand_template.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.texpands + +This template implements scalar broadcast expansion for location=VEC tiles. +It fills dst.valid_shape region with the broadcasted scalar value. + +Location constraint: + - This template is designed for tiles with location=VEC (vector buffer) + - In PTO-ISA, texpands has separate implementations for VEC and MAT locations + - For MAT location tiles, a different template/implementation path should be used + - Current tilelang_dsl MemorySpace only distinguishes GM and UB, where UB maps to + both VEC and MAT locations. The constraint checks memory_space=="ub" as a proxy + - Future enhancement: tilelang_dsl should support explicit location distinction + (e.g., MemorySpace.VEC vs MemorySpace.MAT) for more precise constraint matching + +Layout considerations: + - PTO-ISA has both rowmajor and colmajor expands implementations + - However, expands (scalar broadcast) is layout-agnostic: it simply fills + the tile with a scalar value using vector stores + - The vector store (vsts) writes data according to the tile's physical layout, + which is handled by the underlying DMA engine + - Therefore, this single template covers both rowmajor and colmajor cases +""" + +import tilelang_dsl as pto + + +def _texpands_vec_location_constraint(scalar, dst) -> bool: + """Constraint: dst tile must have location=VEC (represented as memory_space=ub). + + PTO-ISA defines texpands for both MAT and VEC locations: + - MAT location: expands matrix tiles (different implementation path, not supported here) + - VEC location: expands vector tiles (this template) + + Current tilelang_dsl limitation: + MemorySpace only has UB and GM. VEC and MAT both map to UB. + We check memory_space=="ub" as a proxy for VEC location. + MAT tiles should use a different op/template path and won't match here. + """ + # Check memory_space is "ub" (VEC/MAT location, not GM) + # In current tilelang_dsl, VEC location tiles have memory_space="ub" + ms = dst.memory_space + if isinstance(ms, str): + return ms == "ub" + return hasattr(ms, "value") and ms.value == "ub" + + +@pto.vkernel( + target="a5", + op="pto.texpands", + constraints=[_texpands_vec_location_constraint], +) +def template_texpands(scalar: pto.AnyType, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Use vdup for scalar broadcast + vec = pto.vdup(scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tfillpad_expand_template.py b/lib/TileOps/tfillpad_expand_template.py new file mode 100644 index 000000000..6a685f707 --- /dev/null +++ b/lib/TileOps/tfillpad_expand_template.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfillpad_expand + +Expand mode semantics: + - TFILLPAD_EXPAND: src rows may be less than dst rows + - Copy src.valid data to dst + - Fill cols from src.valid_cols to dst.valid_cols with FillPadVal + - Fill rows from src.rows to dst.rows with FillPadVal + +Strategy: + - Phase 1: Copy aligned valid blocks (cols 0 to aligned_col-1) + - Phase 2: Fill cols aligned_col to dst_valid_cols-1 with FillPadVal + - Phase 3: Copy tail valid lanes (cols aligned_col to src_valid_cols-1) + - Phase 4: Fill row expansion + +Address alignment and unaligned handling: + - vlds/vsts require 32-byte aligned base addresses + - Phase 1: col=0 is always aligned (tile base address is aligned), each iteration + accesses col + lanes which maintains alignment + - Phase 2/3/4: handle non-aligned lengths using make_mask() to control active lanes + - make_mask approach: simpler than vldus/vstus for isolated tail operations, no need + for alignment state management (vldas/vldus/vsta sequence) + - vldus/vstus is suitable for continuous unaligned streams; for single tail ops, + mask-controlled vlds/vsts is more direct and efficient +""" + +import tilelang_dsl as pto + +_NEG1_F32 = -1.0 + +# All supported dtype pairs for tfillpad_expand +_DTYPE_SIGNATURES = [ + (pto.f32, pto.f32), + (pto.i16, pto.i16), + (pto.si16, pto.si16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.si32, pto.si32), + (pto.ui32, pto.ui32), + (pto.i8, pto.i8), + (pto.si8, pto.si8), + (pto.ui8, pto.ui8), +] + + +@pto.vkernel( + target="a5", + op="pto.tfillpad_expand", + dtypes=_DTYPE_SIGNATURES, +) +def template_tfillpad_expand(src: pto.Tile, dst: pto.Tile): + """Unified tfillpad_expand template for all dtypes. + + Main logic is identical across dtypes; only PadValue handling differs: + - f32: ZERO + expansion uses -1.0 (special encoding), otherwise eval() or 0.0 + - integer families: eval() or dtype-specific zero constant + """ + dtype = dst.element_type + src_rows, _ = src.shape + src_valid_rows, src_valid_cols = src.valid_shape + dst_rows, _ = dst.shape + dst_valid_rows, dst_valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + aligned_col = (src_valid_cols // lanes) * lanes + has_tail = src_valid_cols > aligned_col + has_valid_expansion = (src_valid_cols < dst_valid_cols) or (src_valid_rows < dst_valid_rows) + + # PadValue handling - dtype-specific + if pto.constexpr(dtype == pto.f32): + if pto.constexpr(dst.pad_value == pto.PadValue.ZERO and has_valid_expansion): + fill_scalar = pto.f32(_NEG1_F32) + elif pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.f32(0.0) + elif pto.constexpr(dtype == pto.ui16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui16(0) + elif pto.constexpr(dtype == pto.si16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si16(0) + elif pto.constexpr(dtype == pto.i16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i16(0) + elif pto.constexpr(dtype == pto.ui32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui32(0) + elif pto.constexpr(dtype == pto.si32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si32(0) + elif pto.constexpr(dtype == pto.i32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i32(0) + elif pto.constexpr(dtype == pto.ui8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui8(0) + elif pto.constexpr(dtype == pto.si8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si8(0) + elif pto.constexpr(dtype == pto.i8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i8(0) + + # Phase 1: Copy aligned valid blocks + for row in range(0, src_valid_rows, 1): + remained = aligned_col + for col in range(0, aligned_col, lanes): + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, col:]) + pto.vsts(data, dst[row, col:], mask) + + # Phase 2: Fill col padding + if pto.constexpr(aligned_col < dst_valid_cols): + for row in range(0, dst_valid_rows, 1): + remained = dst_valid_cols - aligned_col + for col in range(aligned_col, dst_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + # Phase 3: Copy tail valid lanes + if pto.constexpr(has_tail): + for row in range(0, src_valid_rows, 1): + remained = src_valid_cols - aligned_col + mask_copy, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, aligned_col:]) + pto.vsts(data, dst[row, aligned_col:], mask_copy) + + # Phase 4: Fill row expansion + if pto.constexpr(src_rows < dst_rows): + for row in range(src_rows, dst_rows, 1): + remained = dst_valid_cols + for col in range(0, dst_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tfillpad_inplace_template.py b/lib/TileOps/tfillpad_inplace_template.py new file mode 100644 index 000000000..6214476c6 --- /dev/null +++ b/lib/TileOps/tfillpad_inplace_template.py @@ -0,0 +1,165 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfillpad_inplace + +Semantic (based on TFillPad.hpp reference): + - TFILLPAD_INPLACE: same physical buffer (src == dst), skips copy phase, only fills expansion + - Inplace mode: src and dst share the same physical UB address + +Strategy (inplace mode): + - Skip Phase 1+3: Copy phases (data already in buffer) + - Phase 2: Fill cols from src_valid_cols to dst_valid_cols-1 with FillPadVal + - Phase 4: Fill rows from src_valid_rows to dst_valid_rows-1 with FillPadVal +""" + +import tilelang_dsl as pto + +_NEG1_F32 = -1.0 + +# All supported dtype pairs +_DTYPE_SIGNATURES = [ + (pto.f32, pto.f32), + (pto.i16, pto.i16), + (pto.si16, pto.si16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.si32, pto.si32), + (pto.ui32, pto.ui32), + (pto.i8, pto.i8), + (pto.si8, pto.si8), + (pto.ui8, pto.ui8), +] + + +@pto.vkernel( + target="a5", + op="pto.tfillpad_inplace", + dtypes=_DTYPE_SIGNATURES, + advanced=True, # Required for as_ptr() +) +def template_tfillpad_inplace(src: pto.Tile, dst: pto.Tile): + """tfillpad_inplace: skip copy phase, only fill expansion regions. + +Uses vstus+vstas for unaligned column fill, matching TFillPad.hpp. +""" + dtype = dst.element_type + _, cols = dst.shape + src_valid_rows, src_valid_cols = src.valid_shape + dst_valid_rows, dst_valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + byte_width = pto.bytewidth(dtype) + has_valid_expansion = (src_valid_cols < dst_valid_cols) or (src_valid_rows < dst_valid_rows) + + # PadValue handling - same as tfillpad_template.py + # Note: dtype and pad_value are compile-time known, so constexpr is valid for those. + # has_valid_expansion is a runtime value derived from dynamic shapes, so split the condition. + if pto.constexpr(dtype == pto.f32): + if pto.constexpr(dst.pad_value == pto.PadValue.ZERO): + # For ZERO pad_value, use -1.0 encoding only when there's valid expansion + if has_valid_expansion: + fill_scalar = pto.f32(_NEG1_F32) + else: + fill_scalar = pto.f32(0.0) + elif pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.f32(0.0) + elif pto.constexpr(dtype == pto.ui16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui16(0) + elif pto.constexpr(dtype == pto.si16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si16(0) + elif pto.constexpr(dtype == pto.i16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i16(0) + elif pto.constexpr(dtype == pto.ui32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui32(0) + elif pto.constexpr(dtype == pto.si32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si32(0) + elif pto.constexpr(dtype == pto.i32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i32(0) + elif pto.constexpr(dtype == pto.ui8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui8(0) + elif pto.constexpr(dtype == pto.si8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si8(0) + elif pto.constexpr(dtype == pto.i8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i8(0) + + # Phase 2: Fill cols from src_valid_cols to cols-1 (physical buffer end) + # Use vstus+vstas for unaligned starting column, matching TFillPad.hpp + pad_cols = cols - src_valid_cols # TileDataDst::Cols - srcValidCol + + # Create fill vector once (reused across all rows) + fill_vec = pto.vdup(fill_scalar, pto.make_mask(dtype, pto.PAT.ALL)) + + # Get base pointer to UB buffer + base_ptr = dst.as_ptr() + + for row in range(0, src_valid_rows, 1): + # Initialize align register for this row + ureg = pto.init_align() + + # Pointer to dst[row, src_valid_cols]: base_ptr + (row * cols + src_valid_cols) * byte_width + # Matching: dstPtr + i * dstStride + srcValidCol + row_offset = (row * cols + src_valid_cols) * byte_width + row_ptr = pto.addptr(base_ptr, row_offset) + + # Simple loop: iterate pad_cols times with step lanes + # Use vstus + addptr in each branch to simulate POST_UPDATE behavior + # ureg, remaining, row_ptr are all loop-carried, updated in every iteration + remaining = pad_cols + for _ in range(0, pad_cols, lanes): + if remaining >= lanes: + ureg = pto.vstus(ureg, lanes, fill_vec, row_ptr) + row_ptr = pto.addptr(row_ptr, lanes * byte_width) + remaining = remaining - lanes + else: + ureg = pto.vstus(ureg, remaining, fill_vec, row_ptr) + row_ptr = pto.addptr(row_ptr, remaining * byte_width) + remaining = 0 + + # vstas: flush buffered bytes + pto.vstas(ureg, row_ptr, 0) + + # Phase 4: Fill rows from src_valid_rows to dst_valid_rows-1 + # Fill entire physical rows (cols elements), matching: padRows * dstStride + for row in range(src_valid_rows, dst_valid_rows, 1): + remained = cols + for col in range(0, cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tfillpad_template.py b/lib/TileOps/tfillpad_template.py new file mode 100644 index 000000000..77730ad11 --- /dev/null +++ b/lib/TileOps/tfillpad_template.py @@ -0,0 +1,170 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfillpad + +Semantic (based on C++ TFillPad.hpp reference): + - TFILLPAD: copies src.valid data to dst, then fills dst expansion with FillPadVal + - TFILLPAD_INPLACE: same physical buffer (src == dst), skips copy phase, only fills expansion + +Key logic from C++: + if constexpr (!inplace) { + CopyValidElementsVec(dst, src, ...); // Phase 1+3: copy valid data + } + // Phase 2+4: fill expansion (always executed if dst has larger valid region) + FillExpansion(dst, padCols, padRows, padValue); + +Strategy: + - Phase 1: Copy aligned valid blocks (cols 0 to aligned_col-1) [only if !inplace] + - Phase 2: Fill cols aligned_col to dst_valid_cols-1 with FillPadVal + - Phase 3: Copy tail valid lanes (cols aligned_col to src_valid_cols-1) [only if !inplace] + - Phase 4: Fill row expansion + +Address alignment and unaligned handling: + - vlds/vsts require 32-byte aligned base addresses + - Phase 1: col=0 is always aligned (tile base address is aligned) + - Phase 2/4: handle non-aligned lengths using make_mask() to control active lanes + +Note: There is no separate pto.tfillpad_inplace operation in PTO IR. + In-place mode is expressed via tfillpad with src and dst being the same SSA value. + This template handles both cases by detecting if copy phase is needed. +""" + +import tilelang_dsl as pto + +_NEG1_F32 = -1.0 + +# All supported dtype pairs +_DTYPE_SIGNATURES = [ + (pto.f32, pto.f32), + (pto.i16, pto.i16), + (pto.si16, pto.si16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.si32, pto.si32), + (pto.ui32, pto.ui32), + (pto.i8, pto.i8), + (pto.si8, pto.si8), + (pto.ui8, pto.ui8), +] + + +@pto.vkernel( + target="a5", + op="pto.tfillpad", + dtypes=_DTYPE_SIGNATURES, +) +def template_tfillpad(src: pto.Tile, dst: pto.Tile): + """tfillpad: copy src.valid to dst and fill expansion regions. + + Based on C++ TFillPad.hpp reference: + - TFILLPAD (non-inplace): CopyValidElementsVec + FillExpansion + + tfillpad requires src.shape == dst.shape (same physical size). + If dst.valid > src.valid, fill the expansion regions. + """ + dtype = dst.element_type + src_rows, _ = src.shape + src_valid_rows, src_valid_cols = src.valid_shape + dst_rows, dst_cols = dst.shape + dst_valid_rows, dst_valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + aligned_col = (src_valid_cols // lanes) * lanes + has_tail = src_valid_cols > aligned_col + has_valid_expansion = (src_valid_cols < dst_valid_cols) or (src_valid_rows < dst_valid_rows) + + # PadValue handling - dtype-specific (inline to avoid external call) + if pto.constexpr(dtype == pto.f32): + if pto.constexpr(dst.pad_value == pto.PadValue.ZERO and has_valid_expansion): + fill_scalar = pto.f32(_NEG1_F32) + elif pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.f32(0.0) + elif pto.constexpr(dtype == pto.ui16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui16(0) + elif pto.constexpr(dtype == pto.si16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si16(0) + elif pto.constexpr(dtype == pto.i16): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i16(0) + elif pto.constexpr(dtype == pto.ui32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui32(0) + elif pto.constexpr(dtype == pto.si32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si32(0) + elif pto.constexpr(dtype == pto.i32): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i32(0) + elif pto.constexpr(dtype == pto.ui8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.ui8(0) + elif pto.constexpr(dtype == pto.si8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.si8(0) + elif pto.constexpr(dtype == pto.i8): + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + fill_scalar = dst.pad_value.eval() + else: + fill_scalar = pto.i8(0) + + # Phase 1: Copy aligned valid blocks + for row in range(0, src_valid_rows, 1): + remained = aligned_col + for col in range(0, aligned_col, lanes): + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, col:]) + pto.vsts(data, dst[row, col:], mask) + + # Phase 2: Fill cols from aligned_col to dst_cols-1 + if pto.constexpr(aligned_col < dst_cols): + for row in range(0, src_valid_rows, 1): + remained = dst_cols - aligned_col + for col in range(aligned_col, dst_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + # Phase 3: Copy tail valid lanes + if pto.constexpr(has_tail): + for row in range(0, src_valid_rows, 1): + remained = src_valid_cols - aligned_col + mask_copy, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, aligned_col:]) + pto.vsts(data, dst[row, aligned_col:], mask_copy) + + # Phase 4: Fill row expansion + if pto.constexpr(src_rows < dst_rows): + for row in range(src_rows, dst_rows, 1): + remained = dst_cols + for col in range(0, dst_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vdup(fill_scalar, mask) + pto.vsts(vec, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tfmod_template.py b/lib/TileOps/tfmod_template.py new file mode 100644 index 000000000..32f7cb65f --- /dev/null +++ b/lib/TileOps/tfmod_template.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfmod + +Aligned with pto-isa/include/pto/npu/a5/TFMod.hpp: +- float: vdiv -> vtrc(ROUND_Z) -> vmul -> vsub +- half: vcvt(half->float, PART_EVEN/ODD) -> vdiv -> vtrc(ROUND_Z) -> vmul -> vsub + -> vcvt(float->half, ROUND_Z, RS_ENABLE, PART_EVEN/ODD) -> vor +- other (i16/ui16): vdiv -> vmul -> vsub (no vtrc, integer div is trunc by nature) +""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tfmod", + dtypes=[ + (pto.f32, pto.f32, pto.f32), + (pto.f16, pto.f16, pto.f16), + (pto.i16, pto.i16, pto.i16), + (pto.ui16, pto.ui16, pto.ui16), + ], +) +def template_tfmod(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + + if pto.constexpr(dtype == pto.f32): + quotient = pto.vdiv(lhs, rhs, mask) + quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.Z) + truncated_mul = pto.vmul(quotient, rhs, mask) + result = pto.vsub(lhs, truncated_mul, mask) + elif pto.constexpr(dtype == pto.f16): + lhs_even = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + rhs_even = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + quotient_even = pto.vdiv(lhs_even, rhs_even, mask) + quotient_even = pto.vtrc(quotient_even, mask, rnd=pto.VcvtRoundMode.Z) + truncated_mul_even = pto.vmul(quotient_even, rhs_even, mask) + result_even = pto.vsub(lhs_even, truncated_mul_even, mask) + dst_even = pto.vcvt(result_even, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.EVEN) + + lhs_odd = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + rhs_odd = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + quotient_odd = pto.vdiv(lhs_odd, rhs_odd, mask) + quotient_odd = pto.vtrc(quotient_odd, mask, rnd=pto.VcvtRoundMode.Z) + truncated_mul_odd = pto.vmul(quotient_odd, rhs_odd, mask) + result_odd = pto.vsub(lhs_odd, truncated_mul_odd, mask) + dst_odd = pto.vcvt(result_odd, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.SAT, part=pto.VcvtPartMode.ODD) + + result = pto.vor(dst_even, dst_odd, mask) + else: + quotient = pto.vdiv(lhs, rhs, mask) + truncated_mul = pto.vmul(quotient, rhs, mask) + result = pto.vsub(lhs, truncated_mul, mask) + + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tfmods_template.py b/lib/TileOps/tfmods_template.py new file mode 100644 index 000000000..9d9606625 --- /dev/null +++ b/lib/TileOps/tfmods_template.py @@ -0,0 +1,111 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tfmods + +Note: A5 hardware implements tfmods as: + - float: dst = src - trunc(src / scalar) * scalar (f32 precision) + - half: dst = src - trunc(src / scalar) * scalar (computed in f32 precision, then converted back to f16) + - integer: dst = src - (src / scalar) * scalar (integer division already truncates) + +f16 path: convert f16 to f32 (even/odd), compute in f32 with vtrc(ROUND_Z), +convert back to f16 with ROUND_Z, merge with vor. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tfmods", + dtypes=[ + (pto.f32, pto.f32, pto.f32), + (pto.f16, pto.f16, pto.f16), + (pto.i32, pto.i32, pto.i32), + (pto.i16, pto.i16, pto.i16), + ], + advanced=True, +) +def template_tfmods(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + """dst = src - trunc(src / scalar) * scalar""" + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + if pto.constexpr(dtype == pto.f32): + # f32 path: direct f32 computation + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + quotient = pto.vdiv(vec, scalar_vec, mask) + truncated = pto.vtrc(quotient, mask, rnd="Z") + product = pto.vmuls(truncated, scalar, mask) + result = pto.vsub(vec, product, mask) + pto.vsts(result, dst[row, col:], mask) + elif pto.constexpr(dtype == pto.f16): + # f16 path: compute in f32 precision, then convert back to f16 + full_mask_b16 = pto.make_mask(pto.f16, pto.PAT.ALL) + full_mask_b32 = pto.make_mask(pto.f32, pto.PAT.ALL) + scalar_vec_f16 = pto.vbr(scalar) + scalar_f32_vec = pto.vcvt(scalar_vec_f16, pto.f32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + + for row in range(0, valid_rows, 1): + remained = valid_cols + remained_f32 = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f16)): + mask_f16, remained = pto.make_mask(pto.f16, remained) + mask_f32, remained_f32 = pto.make_mask(pto.f32, remained_f32) + vec = pto.vlds(src[row, col:]) + + # Convert f16 to f32 (even and odd parts) + vec_even = pto.vcvt(vec, pto.f32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vec_odd = pto.vcvt(vec, pto.f32, full_mask_b16, part=pto.VcvtPartMode.ODD) + + # Even part: f32 computation + scalar_f32 = pto.f32(scalar) + quotient_even = pto.vdiv(vec_even, scalar_f32_vec, mask_f32) + truncated_even = pto.vtrc(quotient_even, mask_f32, rnd="Z") + product_even = pto.vmuls(truncated_even, scalar_f32, mask_f32) + result_even = pto.vsub(vec_even, product_even, mask_f32) + + # Odd part: f32 computation + quotient_odd = pto.vdiv(vec_odd, scalar_f32_vec, mask_f32) + truncated_odd = pto.vtrc(quotient_odd, mask_f32, rnd="Z") + product_odd = pto.vmuls(truncated_odd, scalar_f32, mask_f32) + result_odd = pto.vsub(vec_odd, product_odd, mask_f32) + + # Convert f32 results back to f16 with ROUND_Z + saturation + result_f16_even = pto.vcvt(result_even, pto.f16, full_mask_b32, + rnd=pto.VcvtRoundMode.Z, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN) + result_f16_odd = pto.vcvt(result_odd, pto.f16, full_mask_b32, + rnd=pto.VcvtRoundMode.Z, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD) + + # Merge even and odd parts + result_f16 = pto.vor(result_f16_even, result_f16_odd, mask_f16) + pto.vsts(result_f16, dst[row, col:], mask_f16) + else: + # Integer path: vdiv already truncates towards zero, no vtrc needed + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + quotient = pto.vdiv(vec, scalar_vec, mask) + product = pto.vmuls(quotient, scalar, mask) + result = pto.vsub(vec, product, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tload_template.py b/lib/TileOps/tload_template.py new file mode 100644 index 000000000..7be0971a3 --- /dev/null +++ b/lib/TileOps/tload_template.py @@ -0,0 +1,302 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""`pto.tload` 的 TileLang DSL 模板""" + +import tilelang_dsl as pto + + +def _constraint_scalar(value): + return value.value if hasattr(value, "value") else value + + +def _known_eq(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value == rhs_value + + +def _known_le(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value <= rhs_value + + +def _match_tile_layout(dst, *, row_major: bool, s_layout) -> bool: + b_layout_ok = ( + dst.config.b_layout == pto.BLayout.ROW_MAJOR + if row_major + else dst.config.b_layout != pto.BLayout.ROW_MAJOR + ) + return b_layout_ok and dst.config.s_layout == s_layout + + +def _check_load_bounds(src, dst, *, logical_rows, logical_cols=None, stride_axis=None) -> bool: + if src.rank != 5: + return False + if stride_axis is not None and not _known_eq(src.strides[stride_axis], 1): + return False + if not _known_le(dst.valid_shape[0], logical_rows): + return False + if not _known_le(logical_rows, dst.shape[0]): + return False + if not _known_le(dst.valid_shape[0], dst.shape[0]): + return False + if logical_cols is not None: + if not _known_le(dst.valid_shape[1], logical_cols): + return False + if not _known_le(logical_cols, dst.shape[1]): + return False + if not _known_le(dst.valid_shape[1], dst.shape[1]): + return False + return True + + +def _tload_preconditions_nd2nd(src, dst) -> bool: + logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] + logical_cols = src.shape[4] + return _match_tile_layout( + dst, row_major=True, s_layout=pto.SLayout.NONE_BOX + ) and _check_load_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=4 + ) + + +def _tload_preconditions_dn2dn(src, dst) -> bool: + logical_rows = src.shape[3] + logical_cols = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[4] + return _match_tile_layout( + dst, row_major=False, s_layout=pto.SLayout.NONE_BOX + ) and _check_load_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=3 + ) + +def _tload_preconditions_nz2nz(src, dst) -> bool: + logical_rows = src.shape[2] + return _match_tile_layout( + dst, row_major=False, s_layout=pto.SLayout.ROW_MAJOR + ) and _check_load_bounds( + src, dst, logical_rows=logical_rows + ) + + +@pto.vkernel( + target="a5", + op="pto.tload", + advanced=True, + constraints=[_tload_preconditions_nd2nd], +) +def template_tload_nd2nd(src: pto.PartitionTensorView, dst: pto.Tile): + dtype = dst.element_type + elem_bytes = pto.bytewidth(dtype) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.set_mov_pad_val(dst.pad_value.eval()) + + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + + valid_rows, valid_cols = dst.valid_shape + ub_rows, ub_cols = dst.shape + + n_burst = g3 + len_burst = g4 * elem_bytes + gm_stride = s3 * elem_bytes + ub_stride = ub_cols * elem_bytes + + dst_stride2 = g3 * ub_cols + dst_stride1 = g2 * dst_stride2 + dst_stride0 = g1 * dst_stride1 + + loop1 = g2 + loop2 = g1 + loop1_src_stride = s2 * elem_bytes + loop1_dst_stride = dst_stride2 * elem_bytes + loop2_src_stride = s1 * elem_bytes + loop2_dst_stride = dst_stride1 * elem_bytes + + gm_ptr = src.as_ptr() + ub_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_outtoub( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_outtoub( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_outtoub(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(gm_ptr, i * s0) + dst_i = pto.addptr(ub_ptr, i * dst_stride0) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=True, + ) + else: + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_outtoub(loop1=1, loop2=1) + return + +@pto.vkernel( + target="a5", + op="pto.tload", + advanced=True, + constraints=[_tload_preconditions_dn2dn], +) +def template_tload_dn2dn(src: pto.PartitionTensorView, dst: pto.Tile): + dtype = dst.element_type + elem_bytes = pto.bytewidth(dtype) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.set_mov_pad_val(dst.pad_value.eval()) + + # rank-5 partition view 元信息。 + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + + tile_rows, tile_cols = dst.shape + valid_rows, valid_cols = dst.valid_shape + + n_burst = g4 + len_burst = valid_rows * elem_bytes + gm_stride = s4 * elem_bytes + ub_stride = tile_rows * elem_bytes + + # UB 目标 tile 是列高为 `tile_rows` 的紧凑 col-major 布局, + # 从最内层 `g4 × tile_rows` 块递推出三层阶梯 stride。 + dst_stride2 = g4 * tile_rows + dst_stride1 = g2 * dst_stride2 + dst_stride0 = g1 * dst_stride1 + + # loop1 ↔ g2(内层),loop2 ↔ g1(外层),软件 for ↔ g0。 + loop1 = g2 + loop2 = g1 + loop1_src_stride = s2 * elem_bytes + loop1_dst_stride = dst_stride2 * elem_bytes + loop2_src_stride = s1 * elem_bytes + loop2_dst_stride = dst_stride1 * elem_bytes + + gm_ptr = src.as_ptr() + ub_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_outtoub( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_outtoub( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_outtoub(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(gm_ptr, i * s0) + dst_i = pto.addptr(ub_ptr, i * dst_stride0) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=True, + ) + else: + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_outtoub(loop1=1, loop2=1) + return + +@pto.vkernel( + target="a5", + op="pto.tload", + advanced=True, + constraints=[_tload_preconditions_nz2nz], +) +def template_tload_nz2nz(src: pto.PartitionTensorView, dst: pto.Tile): + dtype = dst.element_type + elem_bytes = pto.bytewidth(dtype) + + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.set_mov_pad_val(dst.pad_value.eval()) + + # rank-5 partition view 元信息。NZ 静态分块约束(g3/g4 与 dtype 的关系) + # 由更高层 schema/static-check 保证,这里只保留运行时搬运公式。 + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + + tile_rows, tile_cols = dst.shape + valid_rows, valid_cols = dst.valid_shape + + c0_size_bytes = 32 + n_burst = g1 + len_burst = valid_rows * c0_size_bytes + gm_stride = s1 * elem_bytes + ub_stride = tile_rows * c0_size_bytes + + # 每个 g0 block 在 UB 中包含 `g1` 个 NZ 小块;每块的列宽是 `g4` elems。 + tile_stride = g1 * tile_rows * g4 + + gm_ptr = src.as_ptr() + ub_ptr = dst.as_ptr() + + # NZ2NZ 对应实现始终走 normal mode,不复用 loop1/loop2 寄存器。 + pto.set_loop_size_outtoub(loop1=1, loop2=1) + for i in range(0, g0, 1): + src_i = pto.addptr(gm_ptr, i * s0) + dst_i = pto.addptr(ub_ptr, i * tile_stride) + if pto.constexpr(dst.pad_value != pto.PadValue.NULL): + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=True, + ) + else: + pto.copy_gm_to_ubuf( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + enable_ub_pad=False, + ) + return diff --git a/lib/TileOps/tlog_template.py b/lib/TileOps/tlog_template.py new file mode 100644 index 000000000..9b3dfabcd --- /dev/null +++ b/lib/TileOps/tlog_template.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tlog""" + +import tilelang_dsl as pto + + +@pto.inline_proc +def _tlog_high_precision(src: pto.Tile, dst: pto.Tile, dtype, valid_rows, valid_cols): + if pto.constexpr(dtype == pto.f16): + subnormal_threshold = pto.f16("0x03FF") + mul_factor = pto.f16("0x6400") + compensation = pto.f16(-6.931471805599453094172) + elif pto.constexpr(dtype == pto.f32): + subnormal_threshold = pto.f32("0x007FFFFF") + mul_factor = pto.f32("0x4B000000") + compensation = pto.f32(-15.9423851528787421) + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + cmp_mask = pto.vcmps(vinput, subnormal_threshold, mask, pto.CmpMode.LT) + scaled = pto.vmuls(vinput, mul_factor, mask) + selected_input = pto.vsel(scaled, vinput, cmp_mask) + log_result = pto.vln(selected_input, mask) + compensated = pto.vadds(log_result, compensation, mask) + result = pto.vsel(compensated, log_result, cmp_mask) + pto.vsts(result, dst[row, col:], mask) + return None + + +@pto.inline_proc +def _tlog_default(src: pto.Tile, dst: pto.Tile, dtype, valid_rows, valid_cols): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vln(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return None + + +@pto.vkernel( + target="a5", + op="pto.tlog", + advanced=True +) +def template_tlog(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + _tlog_high_precision(src, dst, dtype, valid_rows, valid_cols) + else: + _tlog_default(src, dst, dtype, valid_rows, valid_cols) + return \ No newline at end of file diff --git a/lib/TileOps/tlrelu_template.py b/lib/TileOps/tlrelu_template.py new file mode 100644 index 000000000..33087ca51 --- /dev/null +++ b/lib/TileOps/tlrelu_template.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tlrelu (Leaky ReLU with scalar slope)""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tlrelu", + advanced=True +) +def template_tlrelu(src: pto.Tile, slope: pto.f32, dst: pto.Tile): + """Leaky ReLU: dst = src if src > 0 else src * slope. + + Semantics: + For each element (i, j): + dst[i, j] = src[i, j] > 0 ? src[i, j] : slope * src[i, j] + + Supported data types: f16, f32 + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + if pto.constexpr(dtype == pto.f16): + slope_scalar = pto.f16(slope) + else: + slope_scalar = slope + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + src_vec = pto.vlds(src[row, col:]) + result = pto.vlrelu(src_vec, slope_scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tmatmul_template.py b/lib/TileOps/tmatmul_template.py new file mode 100644 index 000000000..96ba7ea9b --- /dev/null +++ b/lib/TileOps/tmatmul_template.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmatmul.""" + +import tilelang_dsl as pto + + +@pto.ckernel( + target="a5", + op="pto.tmatmul", + dtypes=[ + (pto.f16, pto.f16, pto.f32), + (pto.bf16, pto.bf16, pto.f32), + (pto.f32, pto.f32, pto.f32), + ], +) +def template_tmatmul(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + m, k = lhs.valid_shape + n, _ = rhs.valid_shape + pto.mad(lhs.as_ptr(), rhs.as_ptr(), acc.as_ptr(), m, n, k, disable_gemv=True) + return None diff --git a/lib/TileOps/tmax_template.py b/lib/TileOps/tmax_template.py new file mode 100644 index 000000000..8831d3eef --- /dev/null +++ b/lib/TileOps/tmax_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmax" +) +def template_tmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + max_val = pto.vmax(lhs, rhs, mask) + pto.vsts(max_val, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tmaxs_template.py b/lib/TileOps/tmaxs_template.py new file mode 100644 index 000000000..5c9e409a3 --- /dev/null +++ b/lib/TileOps/tmaxs_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmaxs""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmaxs", +) +def template_tmaxs(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vmaxs(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tmin_template.py b/lib/TileOps/tmin_template.py new file mode 100644 index 000000000..61664b14d --- /dev/null +++ b/lib/TileOps/tmin_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmin" +) +def template_tmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + min_val = pto.vmin(lhs, rhs, mask) + pto.vsts(min_val, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tmins_template.py b/lib/TileOps/tmins_template.py new file mode 100644 index 000000000..bda0df5f9 --- /dev/null +++ b/lib/TileOps/tmins_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmins""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmins", +) +def template_tmins(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vmins(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tmov_template.py b/lib/TileOps/tmov_template.py new file mode 100644 index 000000000..4b1567f91 --- /dev/null +++ b/lib/TileOps/tmov_template.py @@ -0,0 +1,95 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmov - tile data movement + +This template implements UB2UB ND2ND tile data movement: + - UB2UB: Both src and dst must be in Unified Buffer (memory_space="ub") + - ND2ND: Both tiles must have N-dimensional layout (s_layout=NONE_BOX) + +For other transfer scenarios (GM2UB, UB2GM, or specialized layouts), +different templates/implementation paths should be used. +""" + +import tilelang_dsl as pto + + +def _tmov_ub2ub_nd2nd_constraint(src: pto.Tile, dst: pto.Tile) -> bool: + """Constraint: Both src and dst must be UB location with ND layout. + + Supported scenario: + - UB2UB: src and dst both have memory_space="ub" + - ND2ND: src and dst both have s_layout=NONE_BOX (N-dimensional format) + + Unsupported scenarios (require different implementation paths): + - GM2UB: src from Global Memory, dst to Unified Buffer + - UB2GM: src from Unified Buffer, dst to Global Memory + - Specialized layouts (e.g., cube formats with non-NONE_BOX s_layout) + """ + # Check memory_space for both tiles (UB2UB constraint) + src_ms = src.memory_space + dst_ms = dst.memory_space + if isinstance(src_ms, str): + src_is_ub = src_ms == "ub" + else: + src_is_ub = hasattr(src_ms, "value") and src_ms.value == "ub" + if isinstance(dst_ms, str): + dst_is_ub = dst_ms == "ub" + else: + dst_is_ub = hasattr(dst_ms, "value") and dst_ms.value == "ub" + + if not (src_is_ub and dst_is_ub): + return False + + # Check s_layout for both tiles (ND2ND constraint) + # ND layout uses NONE_BOX, specialized layouts (cube, etc.) use different values + src_config = src.config + dst_config = dst.config + if src_config is None or dst_config is None: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + + return True + + +@pto.vkernel( + target="a5", + op="pto.tmov", + constraints=[_tmov_ub2ub_nd2nd_constraint], + advanced=True, +) +def template_tmov_basic(src: pto.Tile, dst: pto.Tile): + """Basic tile-to-tile data movement using vlds/vsts. + + Based on TMovVecToVec in TMov.hpp (lines 378-405): + - Iterate over valid_row rows + - Each row processed in chunks of nRepeatElem elements + - Use predicate mask for partial chunks + + Args: + src: Source tile (Vec location) + dst: Destination tile (Vec location) + """ + dtype = dst.element_type + lanes = pto.get_lanes(dtype) + + # Use dst.valid_shape as the copy dimensions + # The dst tile defines how many elements to write + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(src[row, col:]) + pto.vsts(data, dst[row, col:], mask) + + return None \ No newline at end of file diff --git a/lib/TileOps/tmrgsort_template.py b/lib/TileOps/tmrgsort_template.py new file mode 100644 index 000000000..ac0c8e1ca --- /dev/null +++ b/lib/TileOps/tmrgsort_template.py @@ -0,0 +1,273 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmrgsort""" + +import tilelang_dsl as pto + +STRUCT_SIZE = 8 # bytes per structure (value + index) +STRUCT_SIZE_SHIFT = 3 # log2(8) +BLOCK_NUM = 4 + + +@pto.inline_proc +def tmrgsort_single_list_instr(dst: pto.Tile, src: pto.Tile, + num_structures, repeat_times): + dst_ptr = dst.as_ptr() + src_ptr = src.as_ptr() + + count = pto.i64(num_structures) + count = count | (pto.i64(num_structures) << pto.i64(16)) + count = count | (pto.i64(num_structures) << pto.i64(32)) + count = count | (pto.i64(num_structures) << pto.i64(48)) + + offset = num_structures * STRUCT_SIZE // pto.bytewidth(dst.element_type) + src0 = src_ptr + src1 = pto.addptr(src_ptr, offset) + src2 = pto.addptr(src_ptr, offset * 2) + src3 = pto.addptr(src_ptr, offset * 3) + + config = pto.i64(repeat_times) + config = config | (pto.i64(0b1111) << pto.i64(8)) + config = config | (pto.i64(0b0) << pto.i64(12)) + + pto.vmrgsort4(dst_ptr, src0, src1, src2, src3, pto.i64(count), pto.i64(config)) + return + + +@pto.inline_proc +def tmrgsort_multi_list2_instr(tmp: pto.Tile, src0: pto.Tile, src1: pto.Tile, + src0_structures: int, src1_structures: int): + tmp_ptr = tmp.as_ptr() + src0_ptr = src0.as_ptr() + src1_ptr = src1.as_ptr() + + count = pto.i64(src0_structures) + count = count | (pto.i64(src1_structures) << pto.i64(16)) + + repeat_time = 1 + list_mask = 0b0011 + exhausted_bit = 0 + + exhausted_str = pto.get_op_attr("exhausted", "0") + if pto.constexpr(exhausted_str == "1"): + exhausted_bit = 1 + + config = pto.i64(repeat_time) + config = config | (pto.i64(list_mask) << pto.i64(8)) + config = config | (pto.i64(exhausted_bit) << pto.i64(12)) + + pto.vmrgsort4(tmp_ptr, src0_ptr, src1_ptr, src0_ptr, src0_ptr, + count, config) + + return + + +@pto.inline_proc +def tmrgsort_multi_list3_instr(tmp: pto.Tile, src0: pto.Tile, src1: pto.Tile, src2: pto.Tile, + src0_structures: int, src1_structures: int, src2_structures: int): + tmp_ptr = tmp.as_ptr() + src0_ptr = src0.as_ptr() + src1_ptr = src1.as_ptr() + src2_ptr = src2.as_ptr() + + count = pto.i64(src0_structures) + count = count | (pto.i64(src1_structures) << pto.i64(16)) + count = count | (pto.i64(src2_structures) << pto.i64(32)) + + repeat_time = 1 + list_mask = 0b0111 + exhausted_bit = 0 + + exhausted_str = pto.get_op_attr("exhausted", "0") + if pto.constexpr(exhausted_str == "1"): + exhausted_bit = 1 + + config = pto.i64(repeat_time) + config = config | (pto.i64(list_mask) << pto.i64(8)) + config = config | (pto.i64(exhausted_bit) << pto.i64(12)) + + pto.vmrgsort4(tmp_ptr, src0_ptr, src1_ptr, src2_ptr, src0_ptr, + count, config) + + return + + +@pto.inline_proc +def tmrgsort_multi_list4_instr(tmp: pto.Tile, src0: pto.Tile, src1: pto.Tile, + src2: pto.Tile, src3: pto.Tile, + src0_structures: int, src1_structures: int, + src2_structures: int, src3_structures: int): + dtype = tmp.element_type + + tmp_ptr = tmp.as_ptr() + src0_ptr = src0.as_ptr() + src1_ptr = src1.as_ptr() + src2_ptr = src2.as_ptr() + src3_ptr = src3.as_ptr() + + count = pto.i64(src0_structures) + count = count | (pto.i64(src1_structures) << pto.i64(16)) + count = count | (pto.i64(src2_structures) << pto.i64(32)) + count = count | (pto.i64(src3_structures) << pto.i64(48)) + + repeat_time = 1 + list_mask = 0b1111 + exhausted_bit = 0 + + exhausted_str = pto.get_op_attr("exhausted", "0") + if pto.constexpr(exhausted_str == "1"): + exhausted_bit = 1 + + config = pto.i64(repeat_time) + config = config | (pto.i64(list_mask) << pto.i64(8)) + config = config | (pto.i64(exhausted_bit) << pto.i64(12)) + + pto.vmrgsort4(tmp_ptr, src0_ptr, src1_ptr, src2_ptr, src3_ptr, + count, config) + + return + + +@pto.vkernel( + target="a5", + op="pto.tmrgsort", + advanced=True, +) +def template_tmrgsort_single_list(src: pto.Tile, block_len: pto.AnyInt, dst: pto.Tile): + """Format1 template: single list internal block sorting. + + Standard Format1: single vmrgsort4 for block sorting. + TopK variant is handled by ST kernel via iterative tmrgsort + tmov calls. + """ + src_valid_col = src.valid_shape[1] + + # Block length in structures + num_structures = block_len * pto.bytewidth(src.element_type) >> STRUCT_SIZE_SHIFT + + # Repeat times: how many groups of 4 blocks need merging + repeat_times = src_valid_col // (block_len * BLOCK_NUM) + + # Standard Format1: single merge operation + tmrgsort_single_list_instr(dst, src, num_structures, repeat_times) + + return None + + +@pto.vkernel( + target="a5", + op="pto.tmrgsort", + advanced=True, +) +def template_tmrgsort_multi_list2(src0: pto.Tile, src1: pto.Tile, + tmp: pto.Tile, dst: pto.Tile, ex_vec: pto.AnyInt): + dtype = dst.element_type + bw = pto.bytewidth(dtype) + + src0_valid_col = src0.valid_shape[1] + src1_valid_col = src1.valid_shape[1] + dst_valid_col = dst.valid_shape[1] + + if pto.constexpr(bw == 4): + src0_structures = src0_valid_col // 2 + src1_structures = src1_valid_col // 2 + else: + src0_structures = src0_valid_col // 4 + src1_structures = src1_valid_col // 4 + + dst_elements = dst_valid_col + + tmrgsort_multi_list2_instr(tmp, src0, src1, src0_structures, src1_structures) + + lanes = pto.get_lanes(dtype) + for col in range(0, dst_elements, lanes): + remained = dst_elements - col + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(tmp[0, col:]) + pto.vsts(data, dst[0, col:], mask) + + return None + + +@pto.vkernel( + target="a5", + op="pto.tmrgsort", + advanced=True, +) +def template_tmrgsort_multi_list3(src0: pto.Tile, src1: pto.Tile, src2: pto.Tile, + tmp: pto.Tile, dst: pto.Tile, ex_vec: pto.AnyInt): + dtype = dst.element_type + bw = pto.bytewidth(dtype) + + src0_valid_col = src0.valid_shape[1] + src1_valid_col = src1.valid_shape[1] + src2_valid_col = src2.valid_shape[1] + dst_valid_col = dst.valid_shape[1] + + if pto.constexpr(bw == 4): + src0_structures = src0_valid_col // 2 + src1_structures = src1_valid_col // 2 + src2_structures = src2_valid_col // 2 + else: + src0_structures = src0_valid_col // 4 + src1_structures = src1_valid_col // 4 + src2_structures = src2_valid_col // 4 + + dst_elements = dst_valid_col + + tmrgsort_multi_list3_instr(tmp, src0, src1, src2, src0_structures, src1_structures, src2_structures) + + lanes = pto.get_lanes(dtype) + for col in range(0, dst_elements, lanes): + remained = dst_elements - col + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(tmp[0, col:]) + pto.vsts(data, dst[0, col:], mask) + + return None + + +@pto.vkernel( + target="a5", + op="pto.tmrgsort", + advanced=True, +) +def template_tmrgsort_multi_list4(src0: pto.Tile, src1: pto.Tile, src2: pto.Tile, src3: pto.Tile, + tmp: pto.Tile, dst: pto.Tile, ex_vec: pto.AnyInt): + dtype = dst.element_type + bw = pto.bytewidth(dtype) + + src0_valid_col = src0.valid_shape[1] + src1_valid_col = src1.valid_shape[1] + src2_valid_col = src2.valid_shape[1] + src3_valid_col = src3.valid_shape[1] + dst_valid_col = dst.valid_shape[1] + + if pto.constexpr(bw == 4): + src0_structures = src0_valid_col // 2 + src1_structures = src1_valid_col // 2 + src2_structures = src2_valid_col // 2 + src3_structures = src3_valid_col // 2 + else: + src0_structures = src0_valid_col // 4 + src1_structures = src1_valid_col // 4 + src2_structures = src2_valid_col // 4 + src3_structures = src3_valid_col // 4 + + dst_elements = dst_valid_col + + tmrgsort_multi_list4_instr(tmp, src0, src1, src2, src3, src0_structures, src1_structures, src2_structures, src3_structures) + + lanes = pto.get_lanes(dtype) + for col in range(0, dst_elements, lanes): + remained = dst_elements - col + mask, remained = pto.make_mask(dtype, remained) + data = pto.vlds(tmp[0, col:]) + pto.vsts(data, dst[0, col:], mask) + + return None \ No newline at end of file diff --git a/lib/TileOps/tmul_template.py b/lib/TileOps/tmul_template.py new file mode 100644 index 000000000..ae7adf44e --- /dev/null +++ b/lib/TileOps/tmul_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmul""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmul" +) +def template_tmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + multiplied = pto.vmul(lhs, rhs, mask) + pto.vsts(multiplied, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tmuls_template.py b/lib/TileOps/tmuls_template.py new file mode 100644 index 000000000..8d02ea826 --- /dev/null +++ b/lib/TileOps/tmuls_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tmuls""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tmuls", +) +def template_tmuls(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vmuls(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tneg_template.py b/lib/TileOps/tneg_template.py new file mode 100644 index 000000000..8e10ce4ca --- /dev/null +++ b/lib/TileOps/tneg_template.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tneg""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tneg" +) +def template_tneg(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vneg(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tnot_template.py b/lib/TileOps/tnot_template.py new file mode 100644 index 000000000..f4728e853 --- /dev/null +++ b/lib/TileOps/tnot_template.py @@ -0,0 +1,30 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tnot""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tnot", + dtypes=[(pto.AnyInt, pto.AnyInt)] +) +def template_tnot(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vnot(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tor_template.py b/lib/TileOps/tor_template.py new file mode 100644 index 000000000..e8be63d5e --- /dev/null +++ b/lib/TileOps/tor_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tor""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tor" +) +def template_tor(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vor(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tors_template.py b/lib/TileOps/tors_template.py new file mode 100644 index 000000000..542ee4044 --- /dev/null +++ b/lib/TileOps/tors_template.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tors + +Note: A5 hardware implements tors as: + TEXPANDS_IMPL(dst, scalar); // broadcast scalar to dst + TOR_IMPL(dst, src, dst); // dst = src | dst + +This template uses vbr + vor to achieve element-wise bitwise OR. +Only supports tile, scalar order. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tors", +) +def template_tors(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vor(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tpartadd_template.py b/lib/TileOps/tpartadd_template.py new file mode 100644 index 000000000..8e9716389 --- /dev/null +++ b/lib/TileOps/tpartadd_template.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartadd""" + +import tilelang_dsl as pto + + +@pto.inline_proc +def tpart_op_instr(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, valid_rows, valid_cols): + dtype = dst.element_type + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_copy_instr(dst: pto.Tile, src: pto.Tile, valid_rows, valid_cols, start_row): + dtype = dst.element_type + for row in range(start_row, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + val = pto.vlds(src[row, col:]) + pto.vsts(val, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_op(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, + dst_valid_rows, dst_valid_cols, + src1_valid_rows, src1_valid_cols): + + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_row_lt_dst = (src1_valid_rows < dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_col_lt_dst = (src1_valid_rows <= dst_valid_rows and src1_valid_cols < dst_valid_cols) + + if src1_eq_dst: + tpart_op_instr(dst, src0, src1, dst_valid_rows, dst_valid_cols) + elif src1_col_lt_dst: + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, 0) + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + elif src1_row_lt_dst: + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, src1_valid_rows) + + return + +@pto.vkernel( + target="a5", + op="pto.tpartadd" +) +def template_tpartadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dst_valid_rows, dst_valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + + src0_eq_dst = (src0_valid_rows == dst_valid_rows and src0_valid_cols == dst_valid_cols) + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + + if src0_eq_dst or src1_eq_dst: + if src0_eq_dst: + tpart_op(dst, src0, src1, dst_valid_rows, dst_valid_cols, src1_valid_rows, src1_valid_cols) + elif src1_eq_dst: + tpart_op(dst, src1, src0, dst_valid_rows, dst_valid_cols, src0_valid_rows, src0_valid_cols) + # TODO: raise an error later + + return \ No newline at end of file diff --git a/lib/TileOps/tpartmax_template.py b/lib/TileOps/tpartmax_template.py new file mode 100644 index 000000000..87060cd70 --- /dev/null +++ b/lib/TileOps/tpartmax_template.py @@ -0,0 +1,54 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartmax""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tpartmax", + advanced=True, +) +def template_tpartmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + lanes = pto.get_lanes(dtype) + + pad_scalar = pto.PadValue.MIN.eval(dtype) + pad_vec = pto.vbr(pad_scalar) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + pto.vsts(pad_vec, dst[row, col:], mask) + + pto.mem_bar(pto.BarrierType.VST_VLD) + + for row in range(0, src0_valid_rows, 1): + remained = src0_valid_cols + for col in range(0, src0_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec0 = pto.vlds(src0[row, col:]) + pto.vsts(vec0, dst[row, col:], mask) + + pto.mem_bar(pto.BarrierType.VST_VLD) + + for row in range(0, src1_valid_rows, 1): + remained = src1_valid_cols + for col in range(0, src1_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec_dst = pto.vlds(dst[row, col:]) + vec1 = pto.vlds(src1[row, col:]) + result = pto.vmax(vec_dst, vec1, mask) + pto.vsts(result, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tpartmin_template.py b/lib/TileOps/tpartmin_template.py new file mode 100644 index 000000000..cc21c609c --- /dev/null +++ b/lib/TileOps/tpartmin_template.py @@ -0,0 +1,54 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartmin""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tpartmin", + advanced=True, +) +def template_tpartmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + lanes = pto.get_lanes(dtype) + + pad_scalar = pto.PadValue.MAX.eval(dtype) + pad_vec = pto.vbr(pad_scalar) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + pto.vsts(pad_vec, dst[row, col:], mask) + + pto.mem_bar(pto.BarrierType.VST_VLD) + + for row in range(0, src0_valid_rows, 1): + remained = src0_valid_cols + for col in range(0, src0_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec0 = pto.vlds(src0[row, col:]) + pto.vsts(vec0, dst[row, col:], mask) + + pto.mem_bar(pto.BarrierType.VST_VLD) + + for row in range(0, src1_valid_rows, 1): + remained = src1_valid_cols + for col in range(0, src1_valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec_dst = pto.vlds(dst[row, col:]) + vec1 = pto.vlds(src1[row, col:]) + result = pto.vmin(vec_dst, vec1, mask) + pto.vsts(result, dst[row, col:], mask) + + return \ No newline at end of file diff --git a/lib/TileOps/tpartmul_template.py b/lib/TileOps/tpartmul_template.py new file mode 100644 index 000000000..fe39597ca --- /dev/null +++ b/lib/TileOps/tpartmul_template.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tpartmul""" + +import tilelang_dsl as pto + + +@pto.inline_proc +def tpart_op_instr(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, valid_rows, valid_cols): + dtype = dst.element_type + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vmul(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_copy_instr(dst: pto.Tile, src: pto.Tile, valid_rows, valid_cols, start_row): + dtype = dst.element_type + for row in range(start_row, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + val = pto.vlds(src[row, col:]) + pto.vsts(val, dst[row, col:], mask) + return None + +@pto.inline_proc +def tpart_op(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, + dst_valid_rows, dst_valid_cols, + src1_valid_rows, src1_valid_cols): + + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_row_lt_dst = (src1_valid_rows < dst_valid_rows and src1_valid_cols == dst_valid_cols) + src1_col_lt_dst = (src1_valid_rows <= dst_valid_rows and src1_valid_cols < dst_valid_cols) + + if src1_eq_dst: + tpart_op_instr(dst, src0, src1, dst_valid_rows, dst_valid_cols) + elif src1_col_lt_dst: + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, 0) + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + elif src1_row_lt_dst: + if src1_valid_cols > 0: + tpart_op_instr(dst, src0, src1, src1_valid_rows, src1_valid_cols) + tpart_copy_instr(dst, src0, dst_valid_rows, dst_valid_cols, src1_valid_rows) + + return + +@pto.vkernel( + target="a5", + op="pto.tpartmul" +) +def template_tpartmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dst_valid_rows, dst_valid_cols = dst.valid_shape + src0_valid_rows, src0_valid_cols = src0.valid_shape + src1_valid_rows, src1_valid_cols = src1.valid_shape + + src0_eq_dst = (src0_valid_rows == dst_valid_rows and src0_valid_cols == dst_valid_cols) + src1_eq_dst = (src1_valid_rows == dst_valid_rows and src1_valid_cols == dst_valid_cols) + + if src0_eq_dst or src1_eq_dst: + if src0_eq_dst: + tpart_op(dst, src0, src1, dst_valid_rows, dst_valid_cols, src1_valid_rows, src1_valid_cols) + elif src1_eq_dst: + tpart_op(dst, src1, src0, dst_valid_rows, dst_valid_cols, src0_valid_rows, src0_valid_cols) + # TODO: raise an error later + + return \ No newline at end of file diff --git a/lib/TileOps/tprelu_template.py b/lib/TileOps/tprelu_template.py new file mode 100644 index 000000000..45f83be10 --- /dev/null +++ b/lib/TileOps/tprelu_template.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tprelu (Parametric ReLU)""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tprelu", + dtypes=[(pto.f16, pto.f16, pto.f16, pto.f16), (pto.f32, pto.f32, pto.f32, pto.f32), + (pto.f16, pto.f16, pto.i8, pto.f16), (pto.f32, pto.f32, pto.i8, pto.f32)], + advanced=True +) +def template_tprelu(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + """Parametric ReLU: dst = src0 > 0 ? src0 : src0 * src1. + + Semantics: + For each element (i, j): + dst[i, j] = src0[i, j] > 0 ? src0[i, j] : src0[i, j] * src1[i, j] + + Supported data types: f16, f32 + tmp is a workspace buffer with same dtype as src0. + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + vec0 = pto.vlds(src0[row, col:]) + vec1 = pto.vlds(src1[row, col:]) + result = pto.vprelu(vec0, vec1, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trandom_template.py b/lib/TileOps/trandom_template.py new file mode 100644 index 000000000..af7bdd0c6 --- /dev/null +++ b/lib/TileOps/trandom_template.py @@ -0,0 +1,285 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trandom - unified template using constexpr""" + +import tilelang_dsl as pto + +TRANDOM_ONCE_REPEAT = 4 +TRANDOM_CONST_0 = 0xD2511F53 +TRANDOM_CONST_1 = 0xCD9E8D57 +TRANDOM_CONST_KEY_ADD_0 = 0x9E3779B9 +TRANDOM_CONST_KEY_ADD_1 = 0xBB67AE85 + + +def _check_row_major(dst) -> bool: + return dst.config.b_layout == pto.BLayout.ROW_MAJOR + + +@pto.vkernel( + target="a5", + op="pto.trandom", + dtypes=[ + (pto.i32, pto.i32, pto.i32, pto.i32, pto.i32, pto.i32, pto.ui32), + ], + constraints=[_check_row_major], + advanced=True, +) +def template_trandom( + key0: pto.i32, + key1: pto.i32, + counter0: pto.i32, + counter1: pto.i32, + counter2: pto.i32, + counter3: pto.i32, + dst: pto.Tile, +): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + lanes = pto.get_lanes(dtype) + n_loop = (valid_cols + TRANDOM_ONCE_REPEAT * lanes - 1) // (TRANDOM_ONCE_REPEAT * lanes) + rounds_str = pto.get_op_attr("rounds", "10") + + pg = pto.pset_b32(pto.PAT.ALL) + + ctr0_init = pto.vbitcast(pto.vbr(counter0), pto.ui32) + ctr1_init = pto.vbitcast(pto.vbr(counter1), pto.ui32) + ctr2_init = pto.vbitcast(pto.vbr(counter2), pto.ui32) + ctr3_init = pto.vbitcast(pto.vbr(counter3), pto.ui32) + key0_v = pto.vbitcast(pto.vbr(key0), pto.ui32) + key1_v = pto.vbitcast(pto.vbr(key1), pto.ui32) + zeros = pto.vbr(pto.ui32(0)) + const0 = pto.vbr(pto.ui32(TRANDOM_CONST_0)) + const1 = pto.vbr(pto.ui32(TRANDOM_CONST_1)) + inc_idx = pto.vbitcast(pto.vci(pto.i32(0)), pto.ui32) + + ctr0, pd = pto.vaddc(ctr0_init, inc_idx, pg) + ctr1, pd = pto.vaddcs(ctr1_init, zeros, pd, pg) + ctr2, pd = pto.vaddcs(ctr2_init, zeros, pd, pg) + ctr3, pd = pto.vaddcs(ctr3_init, zeros, pd, pg) + + for i in range(0, valid_rows, 1): + s_reg = valid_cols + counter_add_val = lanes + for j in range(0, n_loop, 1): + tmp_ctr0 = ctr0 + tmp_ctr1 = ctr1 + tmp_ctr2 = ctr2 + tmp_ctr3 = ctr3 + tmp_key0 = key0_v + tmp_key1 = key1_v + + if pto.constexpr(rounds_str == "10"): + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + elif pto.constexpr(rounds_str == "7"): + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + tmpL0, tmpH0 = pto.vmull(tmp_ctr0, const0, pg) + tmpL1, tmpH1 = pto.vmull(tmp_ctr2, const1, pg) + tmpH1 = pto.vxor(tmpH1, tmp_ctr1, pg) + tmp_ctr0 = pto.vxor(tmpH1, tmp_key0, pg) + tmpH0 = pto.vxor(tmpH0, tmp_ctr3, pg) + tmp_ctr2 = pto.vxor(tmpH0, tmp_key1, pg) + tmp_key0 = pto.vadds(tmp_key0, pto.ui32(TRANDOM_CONST_KEY_ADD_0), pg) + tmp_key1 = pto.vadds(tmp_key1, pto.ui32(TRANDOM_CONST_KEY_ADD_1), pg) + tmp_ctr1 = tmpL1 + tmp_ctr3 = tmpL0 + + tmpL0, tmpH0 = pto.vintlv(tmp_ctr0, tmp_ctr2) + tmpL1, tmpH1 = pto.vintlv(tmp_ctr1, tmp_ctr3) + tmp_ctr0, tmp_ctr1 = pto.vintlv(tmpL0, tmpL1) + tmp_ctr2, tmp_ctr3 = pto.vintlv(tmpH0, tmpH1) + + remained = s_reg + mask0, remained = pto.make_mask(dtype, remained) + mask1, remained = pto.make_mask(dtype, remained) + mask2, remained = pto.make_mask(dtype, remained) + mask3, remained = pto.make_mask(dtype, remained) + + pto.vsts(tmp_ctr0, dst[i, TRANDOM_ONCE_REPEAT * j * lanes:], mask0) + pto.vsts(tmp_ctr1, dst[i, (TRANDOM_ONCE_REPEAT * j + 1) * lanes:], mask1) + pto.vsts(tmp_ctr2, dst[i, (TRANDOM_ONCE_REPEAT * j + 2) * lanes:], mask2) + pto.vsts(tmp_ctr3, dst[i, (TRANDOM_ONCE_REPEAT * j + 3) * lanes:], mask3) + + if s_reg >= TRANDOM_ONCE_REPEAT * lanes: + s_reg = s_reg - TRANDOM_ONCE_REPEAT * lanes + else: + s_reg = 0 + + if j != n_loop - 1: + counter_add_val = lanes + else: + counter_add_val = (valid_cols - 1) % lanes + 1 + v_ele_stride = pto.vbr(pto.ui32(counter_add_val)) + ctr0_next, pd_next = pto.vaddc(ctr0, v_ele_stride, pg) + ctr1_next, pd_next2 = pto.vaddcs(ctr1, zeros, pd_next, pg) + ctr2_next, pd_next3 = pto.vaddcs(ctr2, zeros, pd_next2, pg) + ctr3_next, pd_next4 = pto.vaddcs(ctr3, zeros, pd_next3, pg) + ctr0 = ctr0_next + ctr1 = ctr1_next + ctr2 = ctr2_next + ctr3 = ctr3_next + + return diff --git a/lib/TileOps/trecip_template.py b/lib/TileOps/trecip_template.py new file mode 100644 index 000000000..44947f335 --- /dev/null +++ b/lib/TileOps/trecip_template.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trecip with IEEE 754 high-precision support + +Computes reciprocal: dst = 1 / src +High-precision mode uses IEEE 754 compliant division algorithms. +""" + +import tilelang_dsl as pto + +# Import shared high-precision division algorithms +from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl + + +@pto.vkernel( + target="a5", + op="pto.trecip", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32)] +) +def template_trecip(src: pto.Tile, dst: pto.Tile): + """Reciprocal with optional high-precision mode: dst = 1 / src""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + if pto.constexpr(dtype == pto.f16): + one_scalar = pto.f16(1.0) + else: + one_scalar = pto.f32(1.0) + one = pto.vbr(one_scalar) + if pto.constexpr(dtype == pto.f32): + result = _div_ieee754_f32_impl(one, vinput, mask) + else: # dtype == pto.f16 (guaranteed by MLIR validation) + result = _div_ieee754_f16_impl(one, vinput, mask) + pto.vsts(result, dst[row, col:], mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + if pto.constexpr(dtype == pto.f16): + one_scalar = pto.f16(1.0) + else: + one_scalar = pto.f32(1.0) + one = pto.vbr(one_scalar) + result = pto.vdiv(one, vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trelu_template.py b/lib/TileOps/trelu_template.py new file mode 100644 index 000000000..0cdd0e7eb --- /dev/null +++ b/lib/TileOps/trelu_template.py @@ -0,0 +1,43 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trelu (Elementwise ReLU)""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trelu", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32), (pto.i32, pto.i32)], + advanced=True +) +def template_trelu(src: pto.Tile, dst: pto.Tile): + """Elementwise ReLU: dst = max(0, src). + + Semantics: + For each element (i, j): + dst[i, j] = max(0, src[i, j]) + + Supported data types: f16, f32, i32 + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + src_vec = pto.vlds(src[row, col:]) + result = pto.vrelu(src_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trem_template.py b/lib/TileOps/trem_template.py new file mode 100644 index 000000000..7073816d9 --- /dev/null +++ b/lib/TileOps/trem_template.py @@ -0,0 +1,64 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trem""" + +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trem", + dtypes=[ + (pto.f32, pto.f32, pto.f32, pto.f32), + (pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i32, pto.i32, pto.i32, pto.i32), + ], + advanced=True, +) +def template_trem(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + if pto.constexpr(dtype == pto.f32): + quotient = pto.vdiv(lhs, rhs, mask) + quotient = pto.vtrc(quotient, mask, rnd=pto.VcvtRoundMode.F) + floored_mul = pto.vmul(quotient, rhs, mask) + result = pto.vsub(lhs, floored_mul, mask) + sign_diff_mask = pto.vcmps(pto.vmul(rhs, result, mask), 0.0, mask, pto.CmpMode.LT) + corrected = pto.vadd(result, rhs, sign_diff_mask) + result = pto.vsel(corrected, result, sign_diff_mask) + elif pto.constexpr(dtype == pto.f16): + lhs_even = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + rhs_even = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.EVEN) + lhs_odd = pto.vcvt(lhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + rhs_odd = pto.vcvt(rhs, pto.f32, mask, part=pto.VcvtPartMode.ODD) + q_even = pto.vdiv(lhs_even, rhs_even, mask) + q_odd = pto.vdiv(lhs_odd, rhs_odd, mask) + q_even = pto.vtrc(q_even, mask, rnd=pto.VcvtRoundMode.F) + q_odd = pto.vtrc(q_odd, mask, rnd=pto.VcvtRoundMode.F) + fm_even = pto.vmul(q_even, rhs_even, mask) + fm_odd = pto.vmul(q_odd, rhs_odd, mask) + r_even = pto.vsub(lhs_even, fm_even, mask) + r_odd = pto.vsub(lhs_odd, fm_odd, mask) + dst_even = pto.vcvt(r_even, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.RS_ENABLE, part=pto.VcvtPartMode.EVEN) + dst_odd = pto.vcvt(r_odd, pto.f16, mask, rnd=pto.VcvtRoundMode.Z, sat=pto.VcvtSatMode.RS_ENABLE, part=pto.VcvtPartMode.ODD) + result = pto.vor(dst_even, dst_odd, mask) + sign_diff_mask = pto.vcmps(pto.vmul(rhs, result, mask), 0.0, mask, pto.CmpMode.LT) + corrected = pto.vadd(result, rhs, sign_diff_mask) + result = pto.vsel(corrected, result, sign_diff_mask) + elif pto.constexpr(dtype == pto.i32): + result = pto.vmod(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trems_template.py b/lib/TileOps/trems_template.py new file mode 100644 index 000000000..29cc7038c --- /dev/null +++ b/lib/TileOps/trems_template.py @@ -0,0 +1,94 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trems + +Note: A5 hardware implements trems as: + - float: dst = src - trunc(src/scalar) * scalar (f32 precision) + - half: dst = src - trunc(src/scalar) * scalar (computed in f32 precision, then converted back to f16) + - integer: dst = src % scalar (using vmod) - NOT YET SUPPORTED + +f16 path aligns: convert f16 to f32 (even/odd), compute in f32 with vtrc(ROUND_F), +convert back to f16 with ROUND_Z, merge with vor. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trems", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32), (pto.f16, pto.f16, pto.f16, pto.f16)], + advanced=True, +) +def template_trems(src: pto.Tile, scalar: pto.AnyType, tmp: pto.Tile, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + if pto.constexpr(dtype == pto.f32): + # f32 path: direct f32 computation + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + quotient = pto.vdiv(vec, scalar_vec, mask) + truncated = pto.vtrc(quotient, mask, rnd="Z") + product = pto.vmuls(truncated, scalar, mask) + result = pto.vsub(vec, product, mask) + pto.vsts(result, dst[row, col:], mask) + else: + # f16 path: compute in f32 precision, then convert back to f16 + # The predicate register is shared across f16 and f32 operations. + full_mask_b16 = pto.make_mask(pto.f16, pto.PAT.ALL) + full_mask_b32 = pto.make_mask(pto.f32, pto.PAT.ALL) + # Broadcast f16 scalar, convert to f32 vector for f32 arithmetic + scalar_vec_f16 = pto.vbr(scalar) + scalar_f32_vec = pto.vcvt(scalar_vec_f16, pto.f32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + + for row in range(0, valid_rows, 1): + remained = valid_cols + remained_f32 = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f16)): + mask_f16, remained = pto.make_mask(pto.f16, remained) + mask_f32, remained_f32 = pto.make_mask(pto.f32, remained_f32) + vec = pto.vlds(src[row, col:]) + + # Convert f16 to f32 (even and odd parts) + vec_even = pto.vcvt(vec, pto.f32, full_mask_b16, part=pto.VcvtPartMode.EVEN) + vec_odd = pto.vcvt(vec, pto.f32, full_mask_b16, part=pto.VcvtPartMode.ODD) + + # Even part: f32 computation (use vmul with f32 scalar vector) + quotient_even = pto.vdiv(vec_even, scalar_f32_vec, mask_f32) + truncated_even = pto.vtrc(quotient_even, mask_f32, rnd="F") + product_even = pto.vmul(truncated_even, scalar_f32_vec, mask_f32) + result_even = pto.vsub(vec_even, product_even, mask_f32) + + # Odd part: f32 computation + quotient_odd = pto.vdiv(vec_odd, scalar_f32_vec, mask_f32) + truncated_odd = pto.vtrc(quotient_odd, mask_f32, rnd="F") + product_odd = pto.vmul(truncated_odd, scalar_f32_vec, mask_f32) + result_odd = pto.vsub(vec_odd, product_odd, mask_f32) + + # Convert f32 results back to f16 with ROUND_Z + saturation + result_f16_even = pto.vcvt(result_even, pto.f16, full_mask_b32, + rnd=pto.VcvtRoundMode.Z, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN) + result_f16_odd = pto.vcvt(result_odd, pto.f16, full_mask_b32, + rnd=pto.VcvtRoundMode.Z, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD) + + # Merge even and odd parts + result_f16 = pto.vor(result_f16_even, result_f16_odd, mask_f16) + pto.vsts(result_f16, dst[row, col:], mask_f16) + return diff --git a/lib/TileOps/trowargmax_template.py b/lib/TileOps/trowargmax_template.py new file mode 100644 index 000000000..c0df597ed --- /dev/null +++ b/lib/TileOps/trowargmax_template.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowargmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowargmax", + advanced=True, +) +def template_trowargmax(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + idx_dtype = dst.element_type + lanes = pto.get_lanes(src_dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(idx_dtype) + + # Initialize with dtype-specific minimum value (aligned with pto-isa Padding::Min) + init_val = pto.PadValue.MIN.eval(src_dtype) + + # Select one-point store dist based on index dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + for row in range(0, valid_rows, 1): + remained = valid_cols + + v_val_acc = pto.vbr(init_val) + init_zero_idx = idx_dtype(0) + v_idx_acc = pto.vbr(init_zero_idx) + + # Masks: src_dtype for data ops and final store (matches pto-isa CreatePredicate) + # idx_dtype for index arithmetic operations + mask_1, _ = pto.make_mask(src_dtype, 1) + mask_1_idx, _ = pto.make_mask(idx_dtype, 1) + + # Process all column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(src_dtype, remained) + v_src = pto.vlds(src[row, col:]) + v_reduced = pto.vcmax(v_src, mask) + + v_val, v_idx = pto.vdintlv(v_reduced, pto.vbr(src_dtype(0))) + v_idx = pto.vbitcast(v_idx, idx_dtype) + + # Add absolute col offset to the chunk's local index + col_offset = idx_dtype(col) + v_idx = pto.vadds(v_idx, col_offset, mask_1_idx) + + # Compare current chunk max with global max so far + cmp_mask = pto.vcmp(v_val_acc, v_val, mask_1, "lt") + + # Update global max and global argmax + v_val_acc = pto.vsel(v_val, v_val_acc, cmp_mask) + # v_idx_acc is ui32, requires b32 mask; convert cmp_mask from src_dtype's mask to b32 + cmp_mask_b32 = pto.pbitcast(cmp_mask, pto.mask_b32) + v_idx_acc = pto.vsel(v_idx, v_idx_acc, cmp_mask_b32) + + # Store index accumulator to destination tile using one-point mode + pto.vsts(v_idx_acc, dst[row, 0:], mask_1_idx, dist=store_dist) + return \ No newline at end of file diff --git a/lib/TileOps/trowargmin_template.py b/lib/TileOps/trowargmin_template.py new file mode 100644 index 000000000..f23ab7137 --- /dev/null +++ b/lib/TileOps/trowargmin_template.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowargmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowargmin", + advanced=True, +) +def template_trowargmin(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + idx_dtype = dst.element_type + lanes = pto.get_lanes(src_dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(idx_dtype) + + # Initialize with dtype-specific maximum value (aligned with pto-isa Padding::Max) + init_val = pto.PadValue.MAX.eval(src_dtype) + + # Select one-point store dist based on index dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + for row in range(0, valid_rows, 1): + remained = valid_cols + + v_val_acc = pto.vbr(init_val) + init_zero_idx = idx_dtype(0) + v_idx_acc = pto.vbr(init_zero_idx) + + # Masks: src_dtype for data ops and final store (matches pto-isa CreatePredicate) + # idx_dtype for index arithmetic operations + mask_1, _ = pto.make_mask(src_dtype, 1) + mask_1_idx, _ = pto.make_mask(idx_dtype, 1) + + # Process all column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(src_dtype, remained) + v_src = pto.vlds(src[row, col:]) + v_reduced = pto.vcmin(v_src, mask) + + v_val, v_idx = pto.vdintlv(v_reduced, pto.vbr(src_dtype(0))) + v_idx = pto.vbitcast(v_idx, idx_dtype) + + # Add absolute col offset to the chunk's local index + col_offset = idx_dtype(col) + v_idx = pto.vadds(v_idx, col_offset, mask_1_idx) + + # Compare current chunk min with global min so far + cmp_mask = pto.vcmp(v_val_acc, v_val, mask_1, "gt") + + # Update global min and global argmin + v_val_acc = pto.vsel(v_val, v_val_acc, cmp_mask) + # v_idx_acc is ui32, requires b32 mask; cast cmp_mask from src_dtype's mask to b32 + cmp_mask_b32 = pto.pbitcast(cmp_mask, pto.mask_b32) + v_idx_acc = pto.vsel(v_idx, v_idx_acc, cmp_mask_b32) + + # Store index accumulator to destination tile using one-point mode + pto.vsts(v_idx_acc, dst[row, 0:], mask_1_idx, dist=store_dist) + return diff --git a/lib/TileOps/trowexpand_template.py b/lib/TileOps/trowexpand_template.py new file mode 100644 index 000000000..4dc6c6a79 --- /dev/null +++ b/lib/TileOps/trowexpand_template.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpand""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpand_row_major(src: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpand template.""" + # Both src and dst must be RowMajor layout + src_row_major = src.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpand", + dtypes=[(pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpand_row_major], +) +def template_trowexpand(src: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpand. + + Broadcast src[row, 0] to entire dst[row, :] for each row. + Semantics: dst[row, col] = src[row, 0] for all col. + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the first element of each row (src has cols=1, so entire row is the scalar) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + pto.vsts(broadcasted, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandadd_template.py b/lib/TileOps/trowexpandadd_template.py new file mode 100644 index 000000000..96019d501 --- /dev/null +++ b/lib/TileOps/trowexpandadd_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandadd""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandadd_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandadd template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandadd", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandadd_row_major], +) +def template_trowexpandadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandadd. + + Add a per-row scalar from src1[row, 0] to each row of src0. + Semantics: dst[row, col] = src0[row, col] + src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vadd(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpanddiv_template.py b/lib/TileOps/trowexpanddiv_template.py new file mode 100644 index 000000000..627a118aa --- /dev/null +++ b/lib/TileOps/trowexpanddiv_template.py @@ -0,0 +1,107 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpanddiv with IEEE 754 high-precision support + +Divide each row of src0 by a per-row scalar from src1[row, 0]. +Semantics: dst[row, col] = src0[row, col] / src1[row, 0] +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +# Import shared high-precision division algorithms +from div_hp import _div_ieee754_f32_impl, _div_ieee754_f16_impl + + +def _constraint_trowexpanddiv_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpanddiv template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpanddiv", + dtypes=[(pto.f32, pto.f32, pto.f32)], + constraints=[_constraint_trowexpanddiv_row_major], +) +def template_trowexpanddiv_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpanddiv with f32 dtype and optional high-precision mode.""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = _div_ieee754_f32_impl(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vdiv(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.trowexpanddiv", + dtypes=[(pto.f16, pto.f16, pto.f16)], + constraints=[_constraint_trowexpanddiv_row_major], +) +def template_trowexpanddiv_f16(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpanddiv with f16 dtype and optional high-precision mode.""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + precision_mode = pto.get_op_attr("precision_mode", "DEFAULT") + if pto.constexpr(precision_mode == "HIGH_PRECISION"): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 16 for f16) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = _div_ieee754_f16_impl(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vdiv(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandexpdif_template.py b/lib/TileOps/trowexpandexpdif_template.py new file mode 100644 index 000000000..fed08feab --- /dev/null +++ b/lib/TileOps/trowexpandexpdif_template.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandexpdif""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandexpdif_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandexpdif template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandexpdif", + dtypes=[(pto.f32, pto.f32, pto.f32)], + constraints=[_constraint_trowexpandexpdif_row_major], +) +def template_trowexpandexpdif_f32(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandexpdif with f32 dtype. + + Compute exp(src0 - scalar) for each row using per-row scalars from src1[row, 0]. + Semantics: dst[row, col] = exp(src0[row, col] - src1[row, 0]) + Used in numerically stable softmax computation. + """ + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.f32)): + mask, remained = pto.make_mask(pto.f32, remained) + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vexpdif(lhs, broadcasted, mask, pto.VcvtPartMode.EVEN) + pto.vsts(result, dst[row, col:], mask) + return + + +@pto.vkernel( + target="a5", + op="pto.trowexpandexpdif", + dtypes=[(pto.f16, pto.f16, pto.f16)], + constraints=[_constraint_trowexpandexpdif_row_major], +) +def template_trowexpandexpdif_f16(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandexpdif with f16 dtype. + + Compute exp(src0 - scalar) for each row using per-row scalars from src1[row, 0]. + Semantics: dst[row, col] = exp(src0[row, col] - src1[row, 0]) + Used in numerically stable softmax computation. + """ + dtype = pto.f16 + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + diff = pto.vsub(lhs, broadcasted, mask) + result = pto.vexp(diff, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/trowexpandmax_template.py b/lib/TileOps/trowexpandmax_template.py new file mode 100644 index 000000000..2e5d53124 --- /dev/null +++ b/lib/TileOps/trowexpandmax_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandmax_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandmax template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandmax", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandmax_row_major], +) +def template_trowexpandmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandmax. + + Compute element-wise max of each row of src0 with a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = max(src0[row, col], src1[row, 0]) + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vmax(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandmin_template.py b/lib/TileOps/trowexpandmin_template.py new file mode 100644 index 000000000..eae99ff30 --- /dev/null +++ b/lib/TileOps/trowexpandmin_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandmin_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandmin template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandmin", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandmin_row_major], +) +def template_trowexpandmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandmin. + + Compute element-wise min of each row of src0 with a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = min(src0[row, col], src1[row, 0]) + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vmin(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandmul_template.py b/lib/TileOps/trowexpandmul_template.py new file mode 100644 index 000000000..593420125 --- /dev/null +++ b/lib/TileOps/trowexpandmul_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandmul""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandmul_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandmul template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandmul", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandmul_row_major], +) +def template_trowexpandmul(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandmul. + + Multiply each row of src0 by a per-row scalar from src1[row, 0]. + Semantics: dst[row, col] = src0[row, col] * src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vmul(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowexpandsub_template.py b/lib/TileOps/trowexpandsub_template.py new file mode 100644 index 000000000..ed44e8613 --- /dev/null +++ b/lib/TileOps/trowexpandsub_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowexpandsub""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +def _constraint_trowexpandsub_row_major(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile) -> bool: + """Constraint for RowMajor layout trowexpandsub template.""" + # All tiles must be RowMajor layout + src0_row_major = src0.config.b_layout == pto.BLayout.ROW_MAJOR + src1_row_major = src1.config.b_layout == pto.BLayout.ROW_MAJOR + dst_row_major = dst.config.b_layout == pto.BLayout.ROW_MAJOR + return src0_row_major and src1_row_major and dst_row_major + + +@pto.vkernel( + target="a5", + op="pto.trowexpandsub", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), (pto.AnyInt, pto.AnyInt, pto.AnyInt)], + constraints=[_constraint_trowexpandsub_row_major], +) +def template_trowexpandsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Template for pto.trowexpandsub. + + Subtract a per-row scalar from src1[row, 0] from each row of src0. + Semantics: dst[row, col] = src0[row, col] - src1[row, 0] + """ + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + # Load the scalar vector from src1[row, :] + # For row-major src1, valid_shape[1] is 32/sizeof(dtype) (e.g., 8 for f32) + # vdup broadcasts the first element to the full vector width + scalar_vec = pto.vlds(src1[row, :]) + broadcasted = pto.vdup(scalar_vec, mask) + lhs = pto.vlds(src0[row, col:]) + result = pto.vsub(lhs, broadcasted, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/trowmax_template.py b/lib/TileOps/trowmax_template.py new file mode 100644 index 000000000..522ce699e --- /dev/null +++ b/lib/TileOps/trowmax_template.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowmax""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowmax", + advanced=True, +) +def template_trowmax(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + lanes = pto.get_lanes(dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(dtype) + + # Initialize with dtype-specific minimum value (aligned with pto-isa Padding::Min) + init_val = pto.PadValue.MIN.eval(dtype) + + # Select one-point store dist based on dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + for row in range(0, valid_rows, 1): + remained = valid_cols + + mask_1, _ = pto.make_mask(dtype, 1) + + # Initialize the accumulator for ROWMAX + v_acc = pto.vbr(init_val) + + # Process column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # vcmax reduces src_dtype to acc_dtype + v_reduced = pto.vcmax(v_src, mask) + + # Clear masked lanes to init_val for float types so vmax doesn't see NaN + if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + v_reduced = pto.vsel(v_reduced, v_acc, mask) + + v_acc = pto.vmax(v_acc, v_reduced, mask_1) + + # Write final reduction to dest buffer once using one-point mode + pto.vsts(v_acc, dst[row, 0:], mask_1, dist=store_dist) + return diff --git a/lib/TileOps/trowmin_template.py b/lib/TileOps/trowmin_template.py new file mode 100644 index 000000000..f74b798c4 --- /dev/null +++ b/lib/TileOps/trowmin_template.py @@ -0,0 +1,63 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowmin""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trowmin", + advanced=True, +) +def template_trowmin(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + lanes = pto.get_lanes(dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(dtype) + + # Initialize with dtype-specific maximum value (aligned with pto-isa Padding::Max) + init_val = pto.PadValue.MAX.eval(dtype) + + # Select one-point store dist based on dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + mask_1, _ = pto.make_mask(dtype, 1) + + for row in range(0, valid_rows, 1): + remained = valid_cols + + # Initialize the accumulator for ROWMIN + v_acc = pto.vbr(init_val) + + # Process column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # vcmin reduces src_dtype to acc_dtype + v_reduced = pto.vcmin(v_src, mask) + + # Clear masked lanes to init_val for float types so vmin doesn't see NaN + if pto.constexpr(dtype == pto.f32 or dtype == pto.f16): + v_reduced = pto.vsel(v_reduced, v_acc, mask) + + # accumulate using the accumulator's mask logic + v_acc = pto.vmin(v_acc, v_reduced, mask_1) + + # Write final reduction to dest buffer once using one-point mode + pto.vsts(v_acc, dst[row, 0:], mask_1, dist=store_dist) + return diff --git a/lib/TileOps/trowprod_template.py b/lib/TileOps/trowprod_template.py new file mode 100644 index 000000000..55208d328 --- /dev/null +++ b/lib/TileOps/trowprod_template.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowprod""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.trowprod", + advanced=True, +) +def template_trowprod(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + lanes = pto.get_lanes(dtype) + valid_rows, valid_cols = src.valid_shape + elem_bytes = pto.bytewidth(dtype) + + # nLoop from C++ constants: TROW_PROD_LOOP_B16=7, TROW_PROD_LOOP_B32=6 + TROW_PROD_LOOP_B16 = 7 + TROW_PROD_LOOP_B32 = 6 + if pto.constexpr(dtype == pto.f16 or dtype == pto.i16): + n_loop = TROW_PROD_LOOP_B16 + else: + n_loop = TROW_PROD_LOOP_B32 + + # Select one-point store dist based on dtype size + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + mask_1, _ = pto.make_mask(dtype, 1) + + for row in range(0, valid_rows, 1): + remained = valid_cols + + one_val = dtype(1) + v_acc = pto.vbr(one_val) + v_one = pto.vbr(one_val) + + # Multiply across column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # Element-wise product + v_prod = pto.vmul(v_acc, v_src, mask) + + # Simulate MODE_MERGING with vsel (keep v_acc outside mask) + v_acc = pto.vsel(v_prod, v_acc, mask) + + # Log2 reduction phase across the vector + reduce_mask, _ = pto.make_mask(dtype, lanes) # all lanes active for inner reduction + + for k in range(0, n_loop, 1): + v_intlv1, v_intlv2 = pto.vintlv(v_acc, v_one) + v_acc = pto.vmul(v_intlv1, v_intlv2, reduce_mask) + + # Write final result at lane 0 using one-point mode + pto.vsts(v_acc, dst[row, 0:], mask_1, dist=store_dist) + return diff --git a/lib/TileOps/trowsum_template.py b/lib/TileOps/trowsum_template.py new file mode 100644 index 000000000..cada565ff --- /dev/null +++ b/lib/TileOps/trowsum_template.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trowsum""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.trowsum", +) +def template_trowsum(src: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + src_dtype = src.element_type + dst_dtype = dst.element_type + + # vcadd widens i16 -> i32; floats/i32 unchanged + if pto.constexpr(src_dtype == pto.i16): + acc_dtype = pto.i32 + else: + acc_dtype = src_dtype + + lanes = pto.get_lanes(src_dtype) + valid_rows, valid_cols = src.valid_shape + + # Use type-appropriate zero for accumulator initialization + zero_val = acc_dtype(0) + + # Select one-point store dist based on dst dtype size + elem_bytes = pto.bytewidth(dst_dtype) + if pto.constexpr(elem_bytes == 4): + store_dist = pto.VStoreDist.ONE_POINT_B32 + elif pto.constexpr(elem_bytes == 2): + store_dist = pto.VStoreDist.ONE_POINT_B16 + else: + store_dist = pto.VStoreDist.ONE_POINT_B8 + + dst_mask_1, _ = pto.make_mask(dst_dtype, 1) + + for row in range(0, valid_rows, 1): + remained = valid_cols + + acc_mask_1, _ = pto.make_mask(acc_dtype, 1) + + # Initialize the accumulator with type-appropriate zero + v_acc = pto.vbr(zero_val) + + # Process column chunks + for col in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(src_dtype, remained) + v_src = pto.vlds(src[row, col:]) + + # vcadd widens src_dtype to acc_dtype for integer types + v_reduced = pto.vcadd(v_src, mask) + + # accumulate using the accumulator's mask logic + v_acc = pto.vadd(v_acc, v_reduced, acc_mask_1) + + # Store the accumulated result safely once per row using one-point mode + if pto.constexpr(src_dtype == pto.i16): + # Truncate i32 accumulator back to i16 + # Non-saturation mode (wrap-around), matching pto-isa CTRL[59:60] behavior + acc_mask_for_cvt, _ = pto.make_mask(acc_dtype, 1) + v_acc_casted = pto.vcvt(v_acc, dst_dtype, acc_mask_for_cvt, sat=pto.VcvtSatMode.NOSAT, part=pto.VcvtPartMode.EVEN) + pto.vsts(v_acc_casted, dst[row, 0:], dst_mask_1, dist=store_dist) + else: + pto.vsts(v_acc, dst[row, 0:], dst_mask_1, dist=store_dist) + return diff --git a/lib/TileOps/trsqrt_template.py b/lib/TileOps/trsqrt_template.py new file mode 100644 index 000000000..87368adca --- /dev/null +++ b/lib/TileOps/trsqrt_template.py @@ -0,0 +1,36 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.trsqrt""" + +import tilelang_dsl as pto + +# TODO: Add implementation for HIGH_PRECISION type +@pto.vkernel( + target="a5", + op="pto.trsqrt", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32)] +) +def template_trsqrt(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + if pto.constexpr(dtype == pto.f16): + one_scalar = pto.f16(1.0) + else: + one_scalar = pto.f32(1.0) + one = pto.vbr(one_scalar) + sqrt_result = pto.vsqrt(vinput, mask) + result = pto.vdiv(one, sqrt_result, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tsel_template.py b/lib/TileOps/tsel_template.py new file mode 100644 index 000000000..92285716b --- /dev/null +++ b/lib/TileOps/tsel_template.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsel + +NOTE: This template uses pto.plds for mask loading which directly +loads predicate mask from UB without vcmps comparison. +This approach matches the TSel.hpp implementation in pto-isa. + +Mask tile format: +- Packed predicate bytes in UB (`i8` tile data). +- Each row stores `ceil(valid_cols / 8)` valid bytes; tile row stride may be padded. + +REQUIRES: tilelang_dsl support for plds, astype(mask), pintlv_b16, castptr operations +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsel", + dtypes=[ + (pto.i8, pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i8, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i8, pto.i8, pto.i8, pto.i8, pto.i8), + ], + advanced=True +) +def template_tsel(mask: pto.Tile, src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + lanes = pto.get_lanes(dtype) + mask_row_stride = mask.shape[1] + mask_ptr = pto.castptr(mask.as_ptr(), pto.ptr(pto.ui8, pto.MemorySpace.UB)) + + if pto.constexpr(dtype == pto.f32): + full_mask_b16 = pto.pset_b16(pto.MaskPattern.ALL) + pair_width = lanes * 2 + paired_cols = (valid_cols // pair_width) * pair_width + for row in range(0, valid_rows, 1): + for col in range(0, paired_cols, pair_width): + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + pred0, _ = pto.make_mask(dtype, pair_width) + pred1, _ = pto.make_mask(dtype, lanes) + select_mask0, select_mask1 = pto.pintlv_b16(select_mask, full_mask_b16) + select_mask0 = select_mask0.astype(pto.mask_b32) + select_mask1 = select_mask1.astype(pto.mask_b32) + lhs0 = pto.vlds(src0[row, col:]) + rhs0 = pto.vlds(src1[row, col:]) + lhs1 = pto.vlds(src0[row, col + lanes:]) + rhs1 = pto.vlds(src1[row, col + lanes:]) + selected0 = pto.vsel(lhs0, rhs0, select_mask0) + selected1 = pto.vsel(lhs1, rhs1, select_mask1) + pto.vsts(selected0, dst[row, col:], pred0) + pto.vsts(selected1, dst[row, col + lanes:], pred1) + tail_cols = valid_cols - paired_cols + if tail_cols > 0: + col = paired_cols + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + select_mask0 = pto.punpack(select_mask, pto.PredicatePart.LOWER) + select_mask0 = select_mask0.astype(pto.mask_b32) + pred0, _ = pto.make_mask(dtype, tail_cols) + lhs0 = pto.vlds(src0[row, col:]) + rhs0 = pto.vlds(src1[row, col:]) + selected0 = pto.vsel(lhs0, rhs0, select_mask0) + pto.vsts(selected0, dst[row, col:], pred0) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + pred_mask, remained = pto.make_mask(dtype, remained) + mask_offset = row * mask_row_stride + col // 8 + if pto.constexpr(dtype == pto.f16): + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + selected = pto.vsel(lhs, rhs, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + else: + select_mask = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.NORM) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + selected = pto.vsel(lhs, rhs, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + return \ No newline at end of file diff --git a/lib/TileOps/tsels_template.py b/lib/TileOps/tsels_template.py new file mode 100644 index 000000000..905d0dd06 --- /dev/null +++ b/lib/TileOps/tsels_template.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsels + +NOTE: This template uses pto.plds for mask loading which directly +loads predicate mask from UB without vcmps comparison. + +TSels: Select between source tile and scalar based on mask. +- mask=true: select from src +- mask=false: select scalar value + +Mask tile format: +- Packed predicate bytes in UB. +- Each row stores ceil(valid_cols / 8) valid bytes; tile row stride may be padded. +- mask_dtype determines the storage format (i8/i16/i32), but the actual + predicate bits are packed and accessed as bytes. + +IMPORTANT: mask_row_stride is always mask.shape[1] (element count), +because mask tile stride equals cols in element units regardless of mask_dtype. +Byte offset for plds is col // 8 (one byte covers 8 elements). +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsels", + dtypes=[ + (pto.i8, pto.i8, pto.i8, pto.i8, pto.i8), + (pto.i16, pto.i8, pto.i8, pto.i8, pto.i8), + (pto.i32, pto.i8, pto.i8, pto.i8, pto.i8), + (pto.i8, pto.i16, pto.i16, pto.i16, pto.i16), + (pto.i16, pto.i16, pto.i16, pto.i16, pto.i16), + (pto.i32, pto.i16, pto.i16, pto.i16, pto.i16), + (pto.i8, pto.i32, pto.i32, pto.i32, pto.i32), + (pto.i16, pto.i32, pto.i32, pto.i32, pto.i32), + (pto.i32, pto.i32, pto.i32, pto.i32, pto.i32), + (pto.i8, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i16, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i32, pto.f16, pto.f16, pto.f16, pto.f16), + (pto.i8, pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i16, pto.f32, pto.f32, pto.f32, pto.f32), + (pto.i32, pto.f32, pto.f32, pto.f32, pto.f32), + ], + advanced=True +) +def template_tsels(mask: pto.Tile, src: pto.Tile, tmp: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + mask_dtype = mask.element_type + + lanes = pto.get_lanes(dtype) + mask_row_stride = mask.shape[1] * pto.bytewidth(mask_dtype) + mask_ptr = pto.castptr(mask.as_ptr(), pto.ptr(pto.ui8, pto.MemorySpace.UB)) + + scalar_mask, _ = pto.make_mask(dtype, lanes) + vreg_scalar = pto.vdup(scalar, scalar_mask) + + if pto.constexpr(lanes == 64): + full_mask_b16 = pto.pset_b16(pto.MaskPattern.ALL) + pair_width = lanes * 2 + paired_cols = (valid_cols // pair_width) * pair_width + for row in range(0, valid_rows, 1): + for col in range(0, paired_cols, pair_width): + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + pred0, _ = pto.make_mask(dtype, pair_width) + pred1, _ = pto.make_mask(dtype, lanes) + select_mask0, select_mask1 = pto.pintlv_b16(select_mask, full_mask_b16) + select_mask0 = select_mask0.astype(pto.mask_b32) + select_mask1 = select_mask1.astype(pto.mask_b32) + src0 = pto.vlds(src[row, col:]) + src1 = pto.vlds(src[row, col + lanes:]) + selected0 = pto.vsel(src0, vreg_scalar, select_mask0) + selected1 = pto.vsel(src1, vreg_scalar, select_mask1) + pto.vsts(selected0, dst[row, col:], pred0) + pto.vsts(selected1, dst[row, col + lanes:], pred1) + tail_cols = valid_cols - paired_cols + if tail_cols > 0: + col = paired_cols + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + select_mask0 = pto.punpack(select_mask, pto.PredicatePart.LOWER) + select_mask0 = select_mask0.astype(pto.mask_b32) + pred0, _ = pto.make_mask(dtype, tail_cols) + src0 = pto.vlds(src[row, col:]) + selected0 = pto.vsel(src0, vreg_scalar, select_mask0) + pto.vsts(selected0, dst[row, col:], pred0) + elif pto.constexpr(lanes == 128): + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + pred_mask, remained = pto.make_mask(dtype, remained) + mask_offset = row * mask_row_stride + col // 8 + select_mask_raw = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.US) + select_mask = select_mask_raw.astype(pto.mask_b16) + src_vec = pto.vlds(src[row, col:]) + selected = pto.vsel(src_vec, vreg_scalar, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + else: + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes): + pred_mask, remained = pto.make_mask(dtype, remained) + mask_offset = row * mask_row_stride + col // 8 + select_mask = pto.plds(mask_ptr, mask_offset, pto.PredicateDist.NORM) + src_vec = pto.vlds(src[row, col:]) + selected = pto.vsel(src_vec, vreg_scalar, select_mask) + pto.vsts(selected, dst[row, col:], pred_mask) + return \ No newline at end of file diff --git a/lib/TileOps/tshl_template.py b/lib/TileOps/tshl_template.py new file mode 100644 index 000000000..d236c8940 --- /dev/null +++ b/lib/TileOps/tshl_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tshl""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshl" +) +def template_tshl(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vshl(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tshls_template.py b/lib/TileOps/tshls_template.py new file mode 100644 index 000000000..def0b0353 --- /dev/null +++ b/lib/TileOps/tshls_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tshls""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshls", +) +def template_tshls(src: pto.Tile, scalar: pto.i16, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vshls(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tshr_template.py b/lib/TileOps/tshr_template.py new file mode 100644 index 000000000..f16ba9abe --- /dev/null +++ b/lib/TileOps/tshr_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tshr""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshr" +) +def template_tshr(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vshr(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tshrs_template.py b/lib/TileOps/tshrs_template.py new file mode 100644 index 000000000..8366a638b --- /dev/null +++ b/lib/TileOps/tshrs_template.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tshrs""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tshrs", +) +def template_tshrs(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + result = pto.vshrs(vec, scalar, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/tsort32_template.py b/lib/TileOps/tsort32_template.py new file mode 100644 index 000000000..aaf5b9f06 --- /dev/null +++ b/lib/TileOps/tsort32_template.py @@ -0,0 +1,241 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsort32""" + +import tilelang_dsl as pto + +BLOCK_SIZE = 32 +FLOAT_DST_STRIDE_COEF = 2 +HALF_DST_STRIDE_COEF = 4 +MAX_UB_TMP = 32 * 255 # 8160 bytes +REPEAT_MAX = 255 + + +def _constraint_aligned( + src_shape=(), + src_valid_shape=(), + idx_shape=(), + idx_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + idx_config=None, + dst_config=None, +) -> bool: + """Constraint for Format1: valid_cols % 32 == 0 (aligned, no tmp needed).""" + if len(src_valid_shape) != 2: + return False + valid_cols = src_valid_shape[1] + return valid_cols % BLOCK_SIZE == 0 + + +def _constraint_unaligned( + src_shape=(), + src_valid_shape=(), + idx_shape=(), + idx_valid_shape=(), + tmp_shape=(), + tmp_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + idx_config=None, + tmp_config=None, + dst_config=None, +) -> bool: + """Constraint for Format2: valid_cols % 32 != 0 (unaligned, tmp needed).""" + if len(src_valid_shape) != 2: + return False + valid_cols = src_valid_shape[1] + return valid_cols % BLOCK_SIZE != 0 + + +@pto.vkernel( + target="a5", + advanced=True, + op="pto.tsort32", + constraints=[_constraint_aligned] +) +def template_tsort32(src: pto.Tile, idx: pto.Tile, dst: pto.Tile): + """ + TSort32 Format1: Bitonic sort for aligned cols (valid_cols % 32 == 0). + + Semantics (matching pto-isa TSort32.hpp Format1): + - Sorts src values into dst, generating indices in idx + - Direct sort without tmp buffer when src.valid_cols % 32 == 0 + - No padding needed + """ + dtype = dst.element_type + valid_rows = dst.valid_shape[0] + valid_cols = src.valid_shape[1] + + dst_ptr = dst.as_ptr() + src_ptr = src.as_ptr() + idx_ptr = idx.as_ptr() + + elem_bytes = pto.bytewidth(dtype) + dst_stride = ((dst.shape[1] * elem_bytes + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) // elem_bytes + src_stride = ((src.shape[1] * elem_bytes + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) // elem_bytes + idx_stride = ((idx.shape[1] * 4 + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) // 4 + if idx.valid_shape[0] == 1: + idx_stride = 0 + + type_coef = HALF_DST_STRIDE_COEF + if pto.constexpr(dtype == pto.f32): + type_coef = FLOAT_DST_STRIDE_COEF + + repeat_num_per_row = (valid_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + + if repeat_num_per_row <= REPEAT_MAX: + for i in range(0, valid_rows, 1): + pto.vbitsort( + pto.addptr(dst_ptr, i * dst_stride), + pto.addptr(src_ptr, i * src_stride), + pto.addptr(idx_ptr, i * idx_stride), + repeat_num_per_row + ) + else: + loop_num = (repeat_num_per_row + REPEAT_MAX - 1) // REPEAT_MAX + tail_repeat_num = repeat_num_per_row % REPEAT_MAX + for i in range(0, valid_rows, 1): + for j in range(0, loop_num, 1): + repeat_num = REPEAT_MAX + if j == loop_num - 1: + repeat_num = tail_repeat_num + + pto.vbitsort( + pto.addptr(dst_ptr, i * dst_stride + j * REPEAT_MAX * BLOCK_SIZE * type_coef), + pto.addptr(src_ptr, i * src_stride + j * REPEAT_MAX * BLOCK_SIZE), + pto.addptr(idx_ptr, i * idx_stride + j * REPEAT_MAX * BLOCK_SIZE), + repeat_num + ) + return + + +@pto.vkernel( + target="a5", + advanced=True, + op="pto.tsort32", + constraints=[_constraint_unaligned] +) +def template_tsort32_with_tmp(src: pto.Tile, idx: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + """ + TSort32 Format2: Bitonic sort with tmp buffer for unaligned cols. + + Semantics (matching pto-isa TSort32.hpp Format2): + - Sorts src values into dst, generating indices in idx + - Uses tmp buffer when src.valid_cols % 32 != 0 (padding needed) + - Pads unaligned tail with NaN to ensure correct sorting + """ + dtype = dst.element_type + valid_rows = dst.valid_shape[0] + valid_cols = src.valid_shape[1] + + dst_ptr = dst.as_ptr() + src_ptr = src.as_ptr() + idx_ptr = idx.as_ptr() + tmp_ptr = tmp.as_ptr() + + elem_bytes = pto.bytewidth(dtype) + dst_stride = ((dst.shape[1] * elem_bytes + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) // elem_bytes + src_stride = ((src.shape[1] * elem_bytes + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) // elem_bytes + idx_stride = ((idx.shape[1] * 4 + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) // 4 + if idx.valid_shape[0] == 1: + idx_stride = 0 + + type_coef = HALF_DST_STRIDE_COEF + if pto.constexpr(dtype == pto.f32): + type_coef = FLOAT_DST_STRIDE_COEF + + repeat_num_per_row = (valid_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + src_tail_per_row = valid_cols % BLOCK_SIZE + src_tail_repeat_num = ((valid_cols + BLOCK_SIZE - 1) // BLOCK_SIZE) % REPEAT_MAX + + if pto.constexpr(dtype == pto.f16): + min_val = pto.f16(0xFC00) + elif pto.constexpr(dtype == pto.bf16): + min_val = pto.bf16(0xFF80) + else: + min_val = pto.f32(0xFF800000) + + src_shape_bytes_per_row = valid_cols * elem_bytes + + if src_shape_bytes_per_row <= MAX_UB_TMP: + len_burst = (src_shape_bytes_per_row + BLOCK_SIZE - 1) // BLOCK_SIZE + + for i in range(0, valid_rows, 1): + pto.copy_ubuf_to_ubuf( + pto.addptr(src_ptr, i * src_stride), + tmp_ptr, + 0, 1, len_burst, 0, 0 + ) + + tmp_last_offset = ((valid_cols + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) - BLOCK_SIZE + + vec = pto.vlds(tmp[0, tmp_last_offset:]) + pad_mask, _ = pto.make_mask(dtype, BLOCK_SIZE - src_tail_per_row) + vec = pto.vdup(min_val, pad_mask) + pto.vsts(vec, tmp[0, tmp_last_offset:], pad_mask) + + pto.vbitsort( + pto.addptr(dst_ptr, i * dst_stride), + tmp_ptr, + pto.addptr(idx_ptr, i * idx_stride), + repeat_num_per_row + ) + else: + loop_num = (repeat_num_per_row + REPEAT_MAX - 1) // REPEAT_MAX + + for i in range(0, valid_rows, 1): + for j in range(0, loop_num, 1): + if j < loop_num - 1: + pto.vbitsort( + pto.addptr(dst_ptr, i * dst_stride + j * REPEAT_MAX * BLOCK_SIZE * type_coef), + pto.addptr(src_ptr, i * src_stride + j * REPEAT_MAX * BLOCK_SIZE), + pto.addptr(idx_ptr, i * idx_stride + j * REPEAT_MAX * BLOCK_SIZE), + REPEAT_MAX + ) + else: + if src_tail_repeat_num > 0: + sort_repeat_num = 0 + if src_tail_repeat_num > 1: + sort_repeat_num = src_tail_repeat_num - 1 + + pto.vbitsort( + pto.addptr(dst_ptr, i * dst_stride + j * REPEAT_MAX * BLOCK_SIZE * type_coef), + pto.addptr(src_ptr, i * src_stride + j * REPEAT_MAX * BLOCK_SIZE), + pto.addptr(idx_ptr, i * idx_stride + j * REPEAT_MAX * BLOCK_SIZE), + sort_repeat_num + ) + + tail_src_offset = (j * REPEAT_MAX + (src_tail_repeat_num - 1)) * BLOCK_SIZE + tail_dst_offset = (j * REPEAT_MAX + (src_tail_repeat_num - 1)) * BLOCK_SIZE * type_coef + len_burst = (src_tail_per_row * elem_bytes + BLOCK_SIZE - 1) // BLOCK_SIZE + + pto.copy_ubuf_to_ubuf( + pto.addptr(src_ptr, i * src_stride + tail_src_offset), + tmp_ptr, + 0, 1, len_burst, 0, 0 + ) + + tmp_last_offset = ((src_tail_per_row + BLOCK_SIZE - 1) // BLOCK_SIZE * BLOCK_SIZE) - BLOCK_SIZE + + vec = pto.vlds(tmp[0, tmp_last_offset:]) + pad_mask, _ = pto.make_mask(dtype, BLOCK_SIZE - src_tail_per_row) + vec = pto.vdup(min_val, pad_mask) + pto.vsts(vec, tmp[0, tmp_last_offset:], pad_mask) + + pto.vbitsort( + pto.addptr(dst_ptr, i * dst_stride + tail_dst_offset), + tmp_ptr, + pto.addptr(idx_ptr, i * idx_stride + tail_src_offset), + 1 + ) + + return \ No newline at end of file diff --git a/lib/TileOps/tsqrt_template.py b/lib/TileOps/tsqrt_template.py new file mode 100644 index 000000000..7381c50aa --- /dev/null +++ b/lib/TileOps/tsqrt_template.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsqrt""" + +import tilelang_dsl as pto +from sqrt_hp import _tl_sqrt_precision + +@pto.inline_proc +def template_tsqrt_hp_impl(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = _tl_sqrt_precision(vinput, mask, dtype) + pto.vsts(result, dst[row, col:], mask) + return + +@pto.inline_proc +def template_tsqrt_impl(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vinput = pto.vlds(src[row, col:]) + result = pto.vsqrt(vinput, mask) + pto.vsts(result, dst[row, col:], mask) + return + +@pto.vkernel( + target="a5", + op="pto.tsqrt" +) +def template_tsqrt(src: pto.Tile, dst: pto.Tile): + hp_mode = pto.get_op_attr("precision_mode") + if pto.constexpr(hp_mode == "HIGH_PRECISION"): + template_tsqrt_hp_impl(src, dst) + else: + template_tsqrt_impl(src, dst) + return \ No newline at end of file diff --git a/lib/TileOps/tstore_template.py b/lib/TileOps/tstore_template.py new file mode 100644 index 000000000..37597a299 --- /dev/null +++ b/lib/TileOps/tstore_template.py @@ -0,0 +1,253 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""`pto.tstore` 的 TileLang DSL 模板""" + +import tilelang_dsl as pto + + +def _constraint_scalar(value): + return value.value if hasattr(value, "value") else value + + +def _known_eq(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value == rhs_value + + +def _known_le(lhs, rhs) -> bool: + lhs_value = _constraint_scalar(lhs) + rhs_value = _constraint_scalar(rhs) + if lhs_value is None or rhs_value is None: + return True + return lhs_value <= rhs_value + + +def _match_store_tile_layout(src, *, row_major: bool, s_layout) -> bool: + b_layout_ok = ( + src.config.b_layout == pto.BLayout.ROW_MAJOR + if row_major + else src.config.b_layout != pto.BLayout.ROW_MAJOR + ) + return b_layout_ok and src.config.s_layout == s_layout + + +def _check_store_bounds(src, dst, *, logical_rows, logical_cols, stride_axis=None) -> bool: + if dst.rank != 5: + return False + if stride_axis is not None and not _known_eq(dst.strides[stride_axis], 1): + return False + if not _known_eq(src.valid_shape[0], logical_rows): + return False + if not _known_eq(src.valid_shape[1], logical_cols): + return False + if not _known_le(src.valid_shape[0], src.shape[0]): + return False + if not _known_le(src.valid_shape[1], src.shape[1]): + return False + return True + + +def _tstore_preconditions_nd(src, dst) -> bool: + logical_rows = dst.shape[0] * dst.shape[1] * dst.shape[2] * dst.shape[3] + logical_cols = dst.shape[4] + return _match_store_tile_layout( + src, row_major=True, s_layout=pto.SLayout.NONE_BOX + ) and _check_store_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=4 + ) + +def _tstore_preconditions_dn(src, dst) -> bool: + logical_rows = dst.shape[3] + logical_cols = dst.shape[0] * dst.shape[1] * dst.shape[2] * dst.shape[4] + return _match_store_tile_layout( + src, row_major=False, s_layout=pto.SLayout.NONE_BOX + ) and _check_store_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols, stride_axis=3 + ) + +def _tstore_preconditions_nz(src, dst) -> bool: + logical_rows = dst.shape[2] * dst.shape[3] + logical_cols = dst.shape[0] * dst.shape[1] * dst.shape[4] + return _match_store_tile_layout( + src, row_major=False, s_layout=pto.SLayout.ROW_MAJOR + ) and _check_store_bounds( + src, dst, logical_rows=logical_rows, logical_cols=logical_cols + ) + +@pto.vkernel( + target="a5", + op="pto.tstore", + advanced=True, + constraints=[_tstore_preconditions_nd], +) +def template_tstore_nd(src: pto.Tile, dst: pto.PartitionTensorView): + dtype = src.element_type + elem_bytes = pto.bytewidth(dtype) + + g0, g1, g2, g3, g4 = dst.shape + s0, s1, s2, s3, s4 = dst.strides + + valid_rows, valid_cols = src.valid_shape + ub_rows, ub_cols = src.shape + + # These preconditions are expressed through the descriptor-level constraint + # callable above, using direct `src.*` / `dst.*` metadata syntax. + + n_burst = g3 + len_burst = valid_cols * elem_bytes + ub_stride = ub_cols * elem_bytes + gm_stride = s3 * elem_bytes + + src_stride2 = g3 * ub_cols + src_stride1 = g2 * src_stride2 + src_stride0 = g1 * src_stride1 + + loop1 = g2 + loop2 = g1 + loop1_src_stride = src_stride2 * elem_bytes + loop1_dst_stride = s2 * elem_bytes + loop2_src_stride = src_stride1 * elem_bytes + loop2_dst_stride = s1 * elem_bytes + + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_ubtoout( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_ubtoout( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_ubtoout(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(ub_ptr, i * src_stride0) + dst_i = pto.addptr(gm_ptr, i * s0) + pto.copy_ubuf_to_gm( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + return + +@pto.vkernel( + target="a5", + op="pto.tstore", + advanced=True, + constraints=[_tstore_preconditions_dn], +) +def template_tstore_dn(src: pto.Tile, dst: pto.PartitionTensorView): + dtype = src.element_type + elem_bytes = pto.bytewidth(dtype) + + g0, g1, g2, g3, g4 = dst.shape + s0, s1, s2, s3, s4 = dst.strides + + valid_rows, valid_cols = src.valid_shape + ub_rows, ub_cols = src.shape + + n_burst = g4 + len_burst = valid_rows * elem_bytes + gm_stride = s4 * elem_bytes + ub_stride = ub_rows * elem_bytes + + # UB 源 tile 是列高 `ub_rows` 的紧凑 col-major 布局, + # 与 `TStoreVecDN` 一样由 `g4` / `g2` / `g1` 递推出三级 stride。 + src_stride2 = ub_rows * g4 + src_stride1 = g2 * src_stride2 + src_stride0 = g1 * src_stride1 + + loop1 = g2 + loop2 = g1 + loop1_src_stride = src_stride2 * elem_bytes + loop1_dst_stride = s2 * elem_bytes + loop2_src_stride = src_stride1 * elem_bytes + loop2_dst_stride = s1 * elem_bytes + + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + if loop1 != 1 or loop2 != 1: + pto.set_loop2_stride_ubtoout( + src_stride=loop2_src_stride, dst_stride=loop2_dst_stride + ) + pto.set_loop1_stride_ubtoout( + src_stride=loop1_src_stride, dst_stride=loop1_dst_stride + ) + pto.set_loop_size_ubtoout(loop1=loop1, loop2=loop2) + + for i in range(0, g0, 1): + src_i = pto.addptr(ub_ptr, i * src_stride0) + dst_i = pto.addptr(gm_ptr, i * s0) + pto.copy_ubuf_to_gm( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + ) + + if loop1 != 1 or loop2 != 1: + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + return + +@pto.vkernel( + target="a5", + op="pto.tstore", + advanced=True, + constraints=[_tstore_preconditions_nz], +) +def template_tstore_nz(src: pto.Tile, dst: pto.PartitionTensorView): + dtype = src.element_type + elem_bytes = pto.bytewidth(dtype) + + g0, g1, g2, g3, g4 = dst.shape + s0, s1, s2, s3, s4 = dst.strides + + valid_rows, valid_cols = src.valid_shape + ub_rows, ub_cols = src.shape + + # 对应 C++ `C0_SIZE_BYTE`。NZ 每个 burst 始终写一个完整 C0 block。 + c0_size_bytes = 32 + n_burst = g1 + len_burst = valid_rows * c0_size_bytes + gm_stride = s1 * elem_bytes + ub_stride = ub_rows * c0_size_bytes + + # 每个 g0 block 在 UB 中由 `g1` 个 NZ block 串接组成。 + tile_stride = g1 * ub_rows * g4 + + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + # NZ path 本身不使用 loop1/loop2,主动切回 normal mode 避免继承旧状态。 + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + for i in range(0, g0, 1): + src_i = pto.addptr(ub_ptr, i * tile_stride) + dst_i = pto.addptr(gm_ptr, i * s0) + pto.copy_ubuf_to_gm( + dst=dst_i, + src=src_i, + n_burst=n_burst, + len_burst=len_burst, + gm_stride=gm_stride, + ub_stride=ub_stride, + ) + return diff --git a/lib/TileOps/tsub_template.py b/lib/TileOps/tsub_template.py new file mode 100644 index 000000000..81d1b13dd --- /dev/null +++ b/lib/TileOps/tsub_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsub""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsub" +) +def template_tsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + subtracted = pto.vsub(lhs, rhs, mask) + pto.vsts(subtracted, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/tsubs_template.py b/lib/TileOps/tsubs_template.py new file mode 100644 index 000000000..84dc8bfbd --- /dev/null +++ b/lib/TileOps/tsubs_template.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.tsubs + +Note: A5 hardware implements tsubs as vadds with negated scalar: + dst = src - scalar = src + (-scalar) +This template uses vbr + vsub to achieve element-wise subtraction. +TODO: Use vadds(vec, -scalar) when DSL supports unary negation on scalars. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.tsubs", +) +def template_tsubs(src: pto.Tile, scalar: pto.AnyType, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vsub(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/lib/TileOps/txor_template.py b/lib/TileOps/txor_template.py new file mode 100644 index 000000000..d2ca4f1f7 --- /dev/null +++ b/lib/TileOps/txor_template.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.txor""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.txor" +) +def template_txor(src0: pto.Tile, src1: pto.Tile, tmp: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + result = pto.vxor(lhs, rhs, mask) + pto.vsts(result, dst[row, col:], mask) + return \ No newline at end of file diff --git a/lib/TileOps/txors_template.py b/lib/TileOps/txors_template.py new file mode 100644 index 000000000..c06020950 --- /dev/null +++ b/lib/TileOps/txors_template.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL template for pto.txors + +Note: A5 hardware implements txors as: + TEXPANDS_IMPL(dst, scalar); // broadcast scalar to dst + TXOR_IMPL(dst, src, dst, tmp); // dst = src ^ dst + +This template uses vbr + vxor to achieve element-wise bitwise XOR. +""" + +import sys +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + target="a5", + op="pto.txors", +) +def template_txors(src: pto.Tile, scalar: pto.AnyType, tmp: pto.Tile, dst: pto.Tile): + dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + vec = pto.vlds(src[row, col:]) + scalar_vec = pto.vbr(scalar) + result = pto.vxor(vec, scalar_vec, mask) + pto.vsts(result, dst[row, col:], mask) + return diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index a0b5037c7..81b4e37ec 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -8,6 +8,7 @@ import importlib import importlib.util +import functools from pathlib import Path from mlir import ir as _ods_ir @@ -85,6 +86,8 @@ def get_op_result_or_value(value): ReduceOpAttr = _pto_mod.ReduceOpAttr RoundMode = _pto_mod.RoundMode RoundModeAttr = _pto_mod.RoundModeAttr +PrecisionMode = _pto_mod.PrecisionMode +PrecisionModeAttr = _pto_mod.PrecisionModeAttr SaturationMode = _pto_mod.SaturationMode SaturationModeAttr = _pto_mod.SaturationModeAttr CmpMode = _pto_mod.CmpMode @@ -102,6 +105,59 @@ def get_op_result_or_value(value): QuantType = _pto_mod.QuantType QuantTypeAttr = _pto_mod.QuantTypeAttr + +_ptr_type_get_impl = PtrType.get +_ods_get_default_loc_context = getattr(_pto_ops_gen, "_ods_get_default_loc_context") + + +def _ptr_type_get_compat(cls, element_type, memory_space=None, context=None): + if isinstance(memory_space, _ods_ir.Context): + if context is not None: + raise TypeError("PtrType.get got multiple context arguments") + context = memory_space + memory_space = None + return _ptr_type_get_impl( + element_type, memory_space=memory_space, context=context + ) + + +PtrType.get = classmethod(_ptr_type_get_compat) + + +def _default_precision_mode_attr(loc=None): + ctx = _ods_get_default_loc_context(loc) + return PrecisionModeAttr.get(PrecisionMode.DEFAULT, ctx) + + +def _install_default_precision_mode_compat(): + for op_name in ( + "TDivOp", + "TDivSOp", + "TExpOp", + "TLogOp", + "TRecipOp", + "TRowExpandDivOp", + "TRsqrtOp", + "TSqrtOp", + "TColExpandDivOp", + ): + op_cls = getattr(_pto_ops_gen, op_name, None) + if op_cls is None or getattr(op_cls, "_pto_default_precision_mode_compat", False): + continue + original_init = op_cls.__init__ + + @functools.wraps(original_init) + def compat_init(self, *args, __orig_init=original_init, precision_mode=None, **kwargs): + if precision_mode is None: + precision_mode = _default_precision_mode_attr(kwargs.get("loc")) + __orig_init(self, *args, precision_mode=precision_mode, **kwargs) + + op_cls.__init__ = compat_init + op_cls._pto_default_precision_mode_compat = True + + +_install_default_precision_mode_compat() + __all__ = [ # Dialect utilities "register_dialect", @@ -141,6 +197,8 @@ def get_op_result_or_value(value): "ReduceOpAttr", "RoundMode", "RoundModeAttr", + "PrecisionMode", + "PrecisionModeAttr", "SaturationMode", "SaturationModeAttr", "CmpMode", @@ -846,3 +904,875 @@ def _install_op_aliases(): __all__.extend(_install_op_aliases()) + +# ----------------------------------------------------------------------------- +# Experimental VPTO Python DSL (`@pto.vkernel`) +# ----------------------------------------------------------------------------- +import ast as _ast +import inspect as _inspect +import textwrap as _textwrap +from dataclasses import dataclass as _dataclass + + +class _VKernelType: + def render(self): + raise NotImplementedError + + +@_dataclass(frozen=True) +class _VKernelScalarType(_VKernelType): + name: str + + def render(self): + return self.name + + +@_dataclass(frozen=True) +class _VKernelPtrType(_VKernelType): + elem: _VKernelType + space: str + + def render(self): + return f"!pto.ptr<{self.elem.render()}, {self.space}>" + + +@_dataclass(frozen=True) +class _VKernelVRegType(_VKernelType): + lanes: int + elem: _VKernelType + + def render(self): + return f"!pto.vreg<{self.lanes}x{self.elem.render()}>" + + +@_dataclass(frozen=True) +class _VKernelConstBinding: + value: object + + +@_dataclass(frozen=True) +class _VKernelStructDef(_VKernelType): + name: str + fields: tuple + + def render(self): + raise _VKernelCompileError(f"{self.name} is a template-only surface type; use .jit(...) to specialize it") + + def __call__(self, **kwargs): + return _VKernelStructBinding(self, dict(kwargs)) + + +@_dataclass(frozen=True) +class _VKernelStructBinding: + schema: _VKernelStructDef + values: dict + + +@_dataclass(frozen=True) +class _VKStaticSequence: + values: tuple + + +@_dataclass(frozen=True) +class _VKStructValue: + schema: _VKernelStructDef + fields: dict + + +i1 = _VKernelScalarType("i1") +i8 = _VKernelScalarType("i8") +i16 = _VKernelScalarType("i16") +i32 = _VKernelScalarType("i32") +i64 = _VKernelScalarType("i64") +f16 = _VKernelScalarType("f16") +bf16 = _VKernelScalarType("bf16") +f32 = _VKernelScalarType("f32") +_vk_index = _VKernelScalarType("index") +mask = _VKernelScalarType("!pto.mask") +align = _VKernelScalarType("!pto.align") + + +def ptr(elem_type, space): + return _VKernelPtrType(elem_type, space) + + +def vreg(lanes, elem_type): + return _VKernelVRegType(lanes, elem_type) + + +def const(value): + return _VKernelConstBinding(value) + + +def struct(cls): + annotations = dict(getattr(cls, "__annotations__", {})) + if not annotations: + raise _VKernelCompileError("@pto.struct requires annotated fields") + fields = [] + for name, field_ty in annotations.items(): + if field_ty not in (ptr, const): + raise _VKernelCompileError( + f"unsupported field annotation for {cls.__name__}.{name}: {field_ty!r}" + ) + fields.append((name, field_ty)) + return _VKernelStructDef(cls.__name__, tuple(fields)) + + +@struct +class Tile: + ub_ptr: ptr + shape: const + + +tile = Tile + + +class _VKernelCompileError(Exception): + pass + + +@_dataclass +class _VKValue: + name: str | None = None + type: _VKernelType | None = None + literal: object | None = None + + def render_type(self): + if self.type is None: + raise _VKernelCompileError(f"unresolved type for {self.name}") + return self.type.render() + + +def _project_result(group, index, ty): + return _VKValue(f"{group.name}#{index}", ty) + + +def _load_standard_dialects(): + try: + from mlir.dialects import arith as _mlir_arith # noqa: F401 + from mlir.dialects import func as _mlir_func # noqa: F401 + from mlir.dialects import scf as _mlir_scf # noqa: F401 + except ImportError as exc: + raise RuntimeError("mlir standard dialect python bindings are required for vkernel parsing") from exc + + +class _VKernelContext: + def __init__(self): + self.ssa_counter = 0 + self.arg_counter = 0 + + def new_ssa(self): + name = f"%{self.ssa_counter}" + self.ssa_counter += 1 + return name + + def new_arg(self): + name = f"%arg{self.arg_counter}" + self.arg_counter += 1 + return name + + +def _type_key(ty): + return ty.render() if ty is not None else None + + +def _types_equal(lhs, rhs): + if lhs is None or rhs is None: + return lhs is rhs + return lhs.render() == rhs.render() + + +def _ensure_type(value, expected): + if value.type is None: + value.type = expected + return + if not _types_equal(value.type, expected): + raise _VKernelCompileError( + f"type mismatch for {value.name}: expected {expected.render()}, got {value.type.render()}" + ) + + +def _literal_text(value): + if isinstance(value, bool): + return "true" if value else "false" + return str(value) + + +def _coerce_surface_type(value): + if value is bool: + return i1 + if value is float: + return f32 + return value + + +def _ptr_elem_bytes(ptr_type): + if not isinstance(ptr_type, _VKernelPtrType): + raise _VKernelCompileError("elem_bytes requires a ptr type") + elem_name = ptr_type.elem.render() + table = { + "i8": 1, + "i16": 2, + "i32": 4, + "i64": 8, + "f16": 2, + "bf16": 2, + "f32": 4, + } + if elem_name not in table: + raise _VKernelCompileError(f"unsupported elem_bytes for {elem_name}") + return table[elem_name] + + +def _ptr_vector_lanes(ptr_type): + return 256 // _ptr_elem_bytes(ptr_type) + + +class _VKernelBuilder: + def __init__(self, py_fn, fn_def, target, kernel_name, specialization=None): + self.py_fn = py_fn + self.fn_def = fn_def + self.target = target + self.kernel_name = kernel_name + self.ctx = _VKernelContext() + self.specialization = specialization or {} + + def _emit(self, lines, indent, text): + lines.append(" " * indent + text) + + def _eval_type_expr(self, node): + expr = _ast.Expression(body=node) + globals_dict = dict(self.py_fn.__globals__) + globals_dict.update(globals()) + value = eval(compile(expr, self.py_fn.__code__.co_filename, "eval"), + globals_dict, {}) + value = _coerce_surface_type(value) + if not isinstance(value, _VKernelType): + raise _VKernelCompileError(f"unsupported vkernel type annotation: {value!r}") + return value + + def _new_value(self, ty=None): + return _VKValue(self.ctx.new_ssa(), ty) + + def _new_arg_value(self, ty=None): + return _VKValue(self.ctx.new_arg(), ty) + + def _materialize_value(self, value, lines, indent, expected_type=None): + if expected_type is not None: + _ensure_type(value, expected_type) + if value.name is not None: + return value + if value.literal is None: + raise _VKernelCompileError("value has no SSA name and cannot be materialized") + if value.type is None: + raise _VKernelCompileError("literal requires type context") + value.name = self.ctx.new_ssa() + lit = _literal_text(value.literal) + if isinstance(value.literal, bool): + self._emit(lines, indent, f"{value.name} = arith.constant {lit}") + else: + self._emit(lines, indent, f"{value.name} = arith.constant {lit} : {value.type.render()}") + return value + + def _literal_value(self, node, lines, indent, expected_type): + value = _VKValue(type=expected_type, literal=node.value) + if expected_type is None: + return value + return self._materialize_value(value, lines, indent) + + def _lower_attribute(self, node, env, lines, indent, expected_type=None): + if isinstance(node.value, _ast.Name): + if node.value.id not in env: + raise _VKernelCompileError(f"unknown name '{node.value.id}'") + base = env[node.value.id] + else: + base = self._lower_expr(node.value, env, lines, indent) + if isinstance(base, _VKStructValue): + if node.attr not in base.fields: + raise _VKernelCompileError(f"unsupported struct attribute '{node.attr}'") + field = base.fields[node.attr] + if isinstance(field, _VKValue): + return self._materialize_value(field, lines, indent, expected_type) + return field + if isinstance(base, _VKValue) and isinstance(base.type, _VKernelPtrType): + if node.attr == "elem_bytes": + return _VKValue(type=expected_type, literal=_ptr_elem_bytes(base.type)) + raise _VKernelCompileError(f"unsupported attribute access '{node.attr}'") + + def _lower_subscript(self, node, env, lines, indent, expected_type=None): + base = self._lower_expr(node.value, env, lines, indent) + if not isinstance(base, _VKStaticSequence): + raise _VKernelCompileError("subscript base must be a static sequence") + if not isinstance(node.slice, _ast.Constant) or not isinstance(node.slice.value, int): + raise _VKernelCompileError("only constant integer subscripts are supported") + index = node.slice.value + if index < 0 or index >= len(base.values): + raise _VKernelCompileError("subscript out of range") + value = base.values[index] + if not isinstance(value, _VKValue): + value = _VKValue(type=expected_type, literal=value) + return self._materialize_value(value, lines, indent, expected_type) if expected_type is not None else value + + def _lower_binop(self, node, env, lines, indent, expected_type=None): + lhs = self._lower_expr(node.left, env, lines, indent) + rhs = self._lower_expr(node.right, env, lines, indent) + if lhs.literal is not None and rhs.literal is not None: + if isinstance(node.op, _ast.Mult): + result = lhs.literal * rhs.literal + elif isinstance(node.op, _ast.FloorDiv): + result = lhs.literal // rhs.literal + else: + raise _VKernelCompileError(f"unsupported binary operator: {type(node.op).__name__}") + return _VKValue(type=expected_type, literal=result) + raise _VKernelCompileError("non-constant binary expressions are not supported yet") + + def _lower_expr(self, node, env, lines, indent, expected_type=None): + if isinstance(node, _ast.Name): + if node.id not in env: + raise _VKernelCompileError(f"unknown name '{node.id}'") + value = env[node.id] + if isinstance(value, (_VKStructValue, _VKStaticSequence)): + raise _VKernelCompileError(f"name '{node.id}' is not a scalar/SSA value") + if ( + isinstance(value, _VKValue) + and value.name is None + and value.literal is not None + and expected_type is not None + ): + return self._materialize_value( + _VKValue(type=expected_type, literal=value.literal), + lines, + indent, + ) + return self._materialize_value(value, lines, indent, expected_type) + if isinstance(node, _ast.Constant): + return self._literal_value(node, lines, indent, expected_type) + if isinstance(node, _ast.Attribute): + return self._lower_attribute(node, env, lines, indent, expected_type) + if isinstance(node, _ast.Subscript): + return self._lower_subscript(node, env, lines, indent, expected_type) + if isinstance(node, _ast.BinOp): + return self._lower_binop(node, env, lines, indent, expected_type) + if isinstance(node, _ast.Call): + results = self._lower_call(node, env, lines, indent, expected_types=[expected_type] if expected_type else None) + if len(results) != 1: + raise _VKernelCompileError("expression expected single result") + return results[0] + raise _VKernelCompileError(f"unsupported expression: {type(node).__name__}") + + def _lower_call_name(self, node): + if isinstance(node, _ast.Attribute) and isinstance(node.value, _ast.Name) and node.value.id == "pto": + return node.attr + raise _VKernelCompileError("only pto.* calls are supported") + + def _infer_expr_type(self, node, env): + if isinstance(node, _ast.Name): + if node.id not in env: + raise _VKernelCompileError(f"unknown name '{node.id}'") + value = env[node.id] + return value.type if isinstance(value, _VKValue) else None + if isinstance(node, _ast.Attribute): + try: + value = self._lower_attribute(node, env, [], 0) + except _VKernelCompileError: + return None + return value.type if isinstance(value, _VKValue) else None + if isinstance(node, _ast.Constant): + return None + return None + + def _format_typed_operands(self, values): + return ", ".join(v.name for v in values), ", ".join(v.render_type() for v in values) + + def _lower_call(self, node, env, lines, indent, expected_types=None): + opname = self._lower_call_name(node.func) + + if opname in ("set_loop_size_outtoub", "set_loop_size_ubtoout"): + ops = [self._lower_expr(arg, env, lines, indent, i64) for arg in node.args] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.{opname} {operands} : {types}") + return [] + + if opname == "castptr": + if len(node.args) != 2: + raise _VKernelCompileError("pto.castptr expects 2 arguments") + result_type = self._eval_type_expr(node.args[1]) + addr = self._lower_expr(node.args[0], env, lines, indent, i64) + result = self._new_value(result_type) + self._emit(lines, indent, f"{result.name} = pto.castptr {addr.name} : {addr.render_type()} -> {result.render_type()}") + return [result] + + if opname == "copy_gm_to_ubuf": + expected = [None, None, i64, i64, i64, i64, i64, i1, i64, i64, i64] + ops = [self._lower_expr(arg, env, lines, indent, expected[i]) for i, arg in enumerate(node.args)] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.copy_gm_to_ubuf {operands} : {types}") + return [] + + if opname == "copy_ubuf_to_gm": + expected = [None, None, i64, i64, i64, i64, i64, i64] + ops = [self._lower_expr(arg, env, lines, indent, expected[i]) for i, arg in enumerate(node.args)] + operands, types = self._format_typed_operands(ops) + self._emit(lines, indent, f"pto.copy_ubuf_to_gm {operands} : {types}") + return [] + + if opname in ("set_flag", "wait_flag"): + attrs = [] + for arg in node.args: + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError(f"pto.{opname} expects string literals") + attrs.append(arg.value) + self._emit(lines, indent, f'pto.{opname}["{attrs[0]}", "{attrs[1]}", "{attrs[2]}"]') + return [] + + if opname == "barrier": + arg = node.args[0] + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError("pto.barrier expects a string literal") + self._emit(lines, indent, f"pto.barrier #pto.pipe<{arg.value}>") + return [] + + if opname == "plt_b32": + src = self._lower_expr(node.args[0], env, lines, indent, i32) + res0 = self._new_value(mask) + res1 = self._new_value(i32) + self._emit(lines, indent, f"{res0.name}, {res1.name} = pto.plt_b32 {src.name} : i32 -> !pto.mask, i32") + return [res0, res1] + + if opname == "pset_b32": + arg = node.args[0] + if not isinstance(arg, _ast.Constant) or not isinstance(arg.value, str): + raise _VKernelCompileError("pto.pset_b32 expects a string literal") + res = self._new_value(mask) + self._emit(lines, indent, f'{res.name} = pto.pset_b32 "{arg.value}" : !pto.mask') + return [res] + + if opname == "vlds": + ptr_value = self._lower_expr(node.args[0], env, lines, indent) + if not isinstance(ptr_value.type, _VKernelPtrType): + raise _VKernelCompileError("pto.vlds expects a ptr operand") + offset = self._lower_expr(node.args[1], env, lines, indent, _vk_index) + result = self._new_value(vreg(_ptr_vector_lanes(ptr_value.type), ptr_value.type.elem)) + self._emit(lines, indent, + f"{result.name} = pto.vlds {ptr_value.name}[{offset.name}] : {ptr_value.render_type()} -> {result.render_type()}") + return [result] + + if opname == "vabs": + vec_value = self._lower_expr(node.args[0], env, lines, indent) + mask_value = self._lower_expr(node.args[1], env, lines, indent, mask) + result = self._new_value(vec_value.type) + self._emit(lines, indent, + f"{result.name} = pto.vabs {vec_value.name}, {mask_value.name} : {vec_value.render_type()}, {mask_value.render_type()} -> {result.render_type()}") + return [result] + + if opname == "vsts": + vec_value = self._lower_expr(node.args[0], env, lines, indent) + ptr_value = self._lower_expr(node.args[1], env, lines, indent) + offset = self._lower_expr(node.args[2], env, lines, indent, _vk_index) + mask_value = self._lower_expr(node.args[3], env, lines, indent, mask) + self._emit(lines, indent, + f"pto.vsts {vec_value.name}, {ptr_value.name}[{offset.name}], {mask_value.name} : {vec_value.render_type()}, {ptr_value.render_type()}, {mask_value.render_type()}") + return [] + + raise _VKernelCompileError(f"unsupported pto op in vkernel: {opname}") + + def _collect_assigned_names(self, statements): + names = set() + + class Visitor(_ast.NodeVisitor): + def visit_Assign(self, node): + for target in node.targets: + self._collect_target(target) + + def _collect_target(self, target): + if isinstance(target, _ast.Name): + names.add(target.id) + elif isinstance(target, _ast.Tuple): + for elt in target.elts: + self._collect_target(elt) + + visitor = Visitor() + for stmt in statements: + if isinstance(stmt, (_ast.With, _ast.For, _ast.If)): + continue + visitor.visit(stmt) + return names + + def _compile_block(self, statements, env, indent): + lines = [] + current_env = dict(env) + + for stmt in statements: + if isinstance(stmt, _ast.Assign): + if len(stmt.targets) != 1: + raise _VKernelCompileError("multiple assignment targets are not supported") + target = stmt.targets[0] + if isinstance(target, _ast.Name): + value = self._lower_expr(stmt.value, current_env, lines, indent) + current_env[target.id] = value + elif isinstance(target, _ast.Tuple): + results = self._lower_call(stmt.value, current_env, lines, indent) + if len(results) != len(target.elts): + raise _VKernelCompileError("tuple assignment arity mismatch") + for elt, value in zip(target.elts, results): + if not isinstance(elt, _ast.Name): + raise _VKernelCompileError("tuple assignment only supports names") + current_env[elt.id] = value + else: + raise _VKernelCompileError("unsupported assignment target") + continue + + if isinstance(stmt, _ast.AnnAssign): + if stmt.value is None: + raise _VKernelCompileError("annotation-only assignment is not supported") + if not isinstance(stmt.target, _ast.Name): + raise _VKernelCompileError("annotated assignment only supports names") + target_type = self._eval_type_expr(stmt.annotation) + value = self._lower_expr(stmt.value, current_env, lines, indent, target_type) + current_env[stmt.target.id] = value + continue + + if isinstance(stmt, _ast.Expr): + if isinstance(stmt.value, _ast.Call): + self._lower_call(stmt.value, current_env, lines, indent) + else: + self._lower_expr(stmt.value, current_env, lines, indent) + continue + + if isinstance(stmt, _ast.Return): + if stmt.value is not None: + raise _VKernelCompileError("only empty return is supported") + self._emit(lines, indent, "return") + continue + + if isinstance(stmt, _ast.With): + if len(stmt.items) != 1: + raise _VKernelCompileError("only single with item is supported") + item = stmt.items[0] + name = self._lower_call_name(item.context_expr.func) + if name not in ("strict_vecscope", "vecscope"): + raise _VKernelCompileError("unsupported with context") + if name == "strict_vecscope": + body_lines, body_result = self._compile_strict_vecscope(item, stmt.body, current_env, indent) + else: + body_lines, body_result = self._compile_vecscope(stmt.body, current_env, indent) + lines.extend(body_lines) + current_env.update(body_result) + continue + + if isinstance(stmt, _ast.For): + loop_lines, updated_env = self._compile_for(stmt, current_env, indent) + lines.extend(loop_lines) + current_env = updated_env + continue + + if isinstance(stmt, _ast.If): + if_lines, updated_env = self._compile_if(stmt, current_env, indent) + lines.extend(if_lines) + current_env = updated_env + continue + + raise _VKernelCompileError(f"unsupported statement: {type(stmt).__name__}") + + return lines, current_env + + def _compile_vecscope(self, body, outer_env, indent): + body_lines, _ = self._compile_block(body, dict(outer_env), indent + 1) + lines = [] + self._emit(lines, indent, "pto.vecscope {") + lines.extend(body_lines) + self._emit(lines, indent, "}") + return lines, {} + + def _compile_strict_vecscope(self, item, body, outer_env, indent): + if not isinstance(item.optional_vars, _ast.Tuple): + raise _VKernelCompileError("pto.strict_vecscope requires tuple binding in 'as'") + if len(item.context_expr.args) != len(item.optional_vars.elts): + raise _VKernelCompileError("strict_vecscope capture arity must match bound block arguments") + arg_names = [] + inner_env = {} + for elt in item.optional_vars.elts: + if not isinstance(elt, _ast.Name): + raise _VKernelCompileError("pto.strict_vecscope bindings must be names") + arg = self._new_arg_value() + arg_names.append((elt.id, arg)) + inner_env[elt.id] = arg + + for expr, (_, arg) in zip(item.context_expr.args, arg_names): + inferred_type = self._infer_expr_type(expr, outer_env) + if inferred_type is not None: + arg.type = inferred_type + + lines = [] + body_lines, body_env = self._compile_block(body, inner_env, indent + 1) + captures = [] + for name, arg in arg_names: + if arg.type is None and name in body_env and body_env[name].type is not None: + arg.type = body_env[name].type + for expr, (_, arg) in zip(item.context_expr.args, arg_names): + if arg.type is None: + raise _VKernelCompileError("strict_vecscope block argument type could not be inferred") + capture = self._lower_expr(expr, outer_env, lines, indent, expected_type=arg.type) + captures.append(capture) + capture_operands = ", ".join(value.name for value in captures) + block_args = ", ".join(f"{arg.name}: {arg.render_type()}" for _, arg in arg_names) + func_type = ", ".join(arg.render_type() for _, arg in arg_names) + + self._emit(lines, indent, f"pto.strict_vecscope({capture_operands}) {{") + self._emit(lines, indent, f"^bb0({block_args}):") + lines.extend(body_lines) + self._emit(lines, indent, f"}} : ({func_type}) -> ()") + return lines, {} + + def _compile_for(self, stmt, outer_env, indent): + if not isinstance(stmt.target, _ast.Name): + raise _VKernelCompileError("for target must be a single name") + if not isinstance(stmt.iter, _ast.Call) or not isinstance(stmt.iter.func, _ast.Name) or stmt.iter.func.id != "range": + raise _VKernelCompileError("only Python range(...) loops are supported") + if len(stmt.iter.args) != 3: + raise _VKernelCompileError("range expects exactly 3 arguments in vkernel") + + lines = [] + lb = self._lower_expr(stmt.iter.args[0], outer_env, lines, indent, _vk_index) + ub = self._lower_expr(stmt.iter.args[1], outer_env, lines, indent, _vk_index) + step = self._lower_expr(stmt.iter.args[2], outer_env, lines, indent, _vk_index) + + loop_env = dict(outer_env) + iv = self._new_arg_value(_vk_index) + loop_env[stmt.target.id] = iv + candidate_carried = [] + for name in self._collect_assigned_names(stmt.body): + if name in outer_env and name != stmt.target.id: + iter_arg = self._new_arg_value(outer_env[name].type) + loop_env[name] = iter_arg + candidate_carried.append((name, outer_env[name], iter_arg)) + + body_lines, body_env = self._compile_block(stmt.body, loop_env, indent + 1) + carried = [] + for name, before, iter_arg in candidate_carried: + after = body_env.get(name) + if after is not None and after is not iter_arg: + carried.append((name, before, after)) + + result_prefix = "" + yield_line = None + if carried: + results = [after.render_type() for _, _, after in carried] + result_value = self._new_value() + result_prefix = f"{result_value.name}:{len(carried)} = " + iter_arg_map = {name: iter_arg for name, _, iter_arg in candidate_carried} + carried_with_initials = [] + for name, before, after in carried: + before = self._materialize_value(before, lines, indent, after.type) + carried_with_initials.append((name, before, after)) + carried = carried_with_initials + iter_args = ", ".join( + f"{iter_arg_map[name].name} = {before.name}" for name, before, _ in carried + ) + self._emit( + lines, + indent, + f"{result_prefix}scf.for {iv.name} = {lb.name} to {ub.name} step {step.name} iter_args({iter_args}) -> ({', '.join(results)}) {{", + ) + yield_line = f"scf.yield {', '.join(after.name for _, _, after in carried)} : {', '.join(results)}" + else: + self._emit(lines, indent, f"scf.for {iv.name} = {lb.name} to {ub.name} step {step.name} {{") + lines.extend(body_lines) + if yield_line: + self._emit(lines, indent + 1, yield_line) + self._emit(lines, indent, "}") + + updated_env = dict(outer_env) + if carried: + for idx, (name, _, after) in enumerate(carried): + updated_env[name] = _project_result(result_value, idx, after.type) + return lines, updated_env + + def _compile_if(self, stmt, outer_env, indent): + lines = [] + cond = self._lower_expr(stmt.test, outer_env, lines, indent, i1) + then_lines, then_env = self._compile_block(stmt.body, dict(outer_env), indent + 1) + else_lines, else_env = self._compile_block(stmt.orelse, dict(outer_env), indent + 1) + updated = [] + for name, before in outer_env.items(): + then_val = then_env.get(name, before) + else_val = else_env.get(name, before) + if then_val is not before or else_val is not before: + if not _types_equal(then_val.type, else_val.type): + raise _VKernelCompileError(f"if merge type mismatch for '{name}'") + updated.append((name, then_val, else_val)) + + if updated: + result = self._new_value() + types = ", ".join(val.type.render() for _, val, _ in updated) + self._emit(lines, indent, f"{result.name}:{len(updated)} = scf.if {cond.name} -> ({types}) {{") + lines.extend(then_lines) + self._emit(lines, indent + 1, f"scf.yield {', '.join(val.name for _, val, _ in updated)} : {types}") + self._emit(lines, indent, "} else {") + lines.extend(else_lines) + self._emit(lines, indent + 1, f"scf.yield {', '.join(val.name for _, _, val in updated)} : {types}") + self._emit(lines, indent, "}") + updated_env = dict(outer_env) + for idx, (name, then_val, _) in enumerate(updated): + updated_env[name] = _project_result(result, idx, then_val.type) + return lines, updated_env + + self._emit(lines, indent, f"scf.if {cond.name} {{") + lines.extend(then_lines) + self._emit(lines, indent, "} else {") + lines.extend(else_lines) + self._emit(lines, indent, "}") + return lines, dict(outer_env) + + def build_text(self): + lines = [f'module attributes {{pto.target_arch = "{self.target}"}} {{'] + arg_types = [] + env = {} + for arg in self.fn_def.args.args: + arg_ty = _coerce_surface_type(self.py_fn.__annotations__.get(arg.arg)) + if arg_ty is None: + raise _VKernelCompileError(f"missing type annotation for argument '{arg.arg}'") + if not isinstance(arg_ty, _VKernelType): + raise _VKernelCompileError(f"unsupported type annotation for argument '{arg.arg}'") + if isinstance(arg_ty, _VKernelStructDef): + if arg.arg not in self.specialization: + raise _VKernelCompileError( + f"template argument '{arg.arg}: {arg_ty.name}' requires .jit(...) specialization" + ) + binding = self.specialization[arg.arg] + if not isinstance(binding, _VKernelStructBinding) or binding.schema != arg_ty: + raise _VKernelCompileError( + f"specialization for '{arg.arg}' must be a {arg_ty.name}(...) binding" + ) + struct_fields = {} + for field_name, field_kind in arg_ty.fields: + if field_name not in binding.values: + raise _VKernelCompileError( + f"missing field '{field_name}' in specialization for '{arg.arg}'" + ) + field_value = binding.values[field_name] + if field_kind is ptr: + if not isinstance(field_value, _VKernelPtrType): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must be a pto.ptr(...) type object" + ) + arg_val = self._new_arg_value(field_value) + arg_types.append(f"{arg_val.name}: {field_value.render()}") + struct_fields[field_name] = arg_val + continue + if field_kind is const: + if not isinstance(field_value, _VKernelConstBinding): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must use pto.const(...)" + ) + static_value = field_value.value + if not isinstance(static_value, (list, tuple)) or not all( + isinstance(v, int) for v in static_value + ): + raise _VKernelCompileError( + f"{arg_ty.name}.{field_name} must be a list/tuple of ints" + ) + struct_fields[field_name] = _VKStaticSequence( + tuple(_VKValue(literal=v) for v in static_value) + ) + continue + raise _VKernelCompileError( + f"unsupported struct field kind for {arg_ty.name}.{field_name}" + ) + env[arg.arg] = _VKStructValue(arg_ty, struct_fields) + continue + arg_val = self._new_arg_value(arg_ty) + arg_types.append(f"{arg_val.name}: {arg_ty.render()}") + env[arg.arg] = arg_val + self._emit(lines, 1, f"func.func @{self.kernel_name}({', '.join(arg_types)}) {{") + body_lines, _ = self._compile_block(self.fn_def.body, env, 2) + lines.extend(body_lines) + if not any(line.strip() == "return" for line in body_lines): + self._emit(lines, 2, "return") + self._emit(lines, 1, "}") + lines.append("}") + return "\n".join(lines) + "\n" + + +class VKernelHandle: + def __init__(self, py_fn, target="a5", name=None, verify=True, specialization=None): + self._py_fn = py_fn + self._target = target + self._name = name or py_fn.__name__ + self._verify = verify + self._specialization = specialization or {} + self._cached_text = None + + def _load_ast(self): + source = _textwrap.dedent(_inspect.getsource(self._py_fn)) + module = _ast.parse(source) + for node in module.body: + if isinstance(node, _ast.FunctionDef) and node.name == self._py_fn.__name__: + return node + raise _VKernelCompileError(f"failed to locate function AST for {self._py_fn.__name__}") + + def mlir_text(self): + if self._cached_text is None: + builder = _VKernelBuilder( + self._py_fn, + self._load_ast(), + self._target, + self._name, + specialization=self._specialization, + ) + self._cached_text = builder.build_text() + return self._cached_text + + def mlir_module(self): + with _ods_ir.Context() as ctx: + _load_standard_dialects() + register_dialect(ctx, load=True) + return _ods_ir.Module.parse(self.mlir_text(), ctx) + + def verify(self): + mod = self.mlir_module() + mod.operation.verify() + return True + + def dump(self): + print(self.mlir_text(), end="") + + def emit(self, path): + with open(path, "w", encoding="utf-8") as f: + f.write(self.mlir_text()) + + def jit(self, **kwargs): + return VKernelHandle( + self._py_fn, + target=self._target, + name=self._name, + verify=self._verify, + specialization=kwargs, + ) + + def __str__(self): + return self.mlir_text() + + +def vkernel(py_fn=None, *, target="a5", name=None, verify=True): + def wrap(fn): + return VKernelHandle(fn, target=target, name=name, verify=verify) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +__all__.extend([ + "vkernel", + "VKernelHandle", + "struct", + "Tile", + "tile", + "const", + "ptr", + "vreg", + "i1", "i8", "i16", "i32", "i64", + "f16", "bf16", "f32", + "mask", "align", +]) diff --git a/scripts/batch_compile_output_cpp.sh b/scripts/batch_compile_output_cpp.sh new file mode 100755 index 000000000..8de0fb114 --- /dev/null +++ b/scripts/batch_compile_output_cpp.sh @@ -0,0 +1,472 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +set -u + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" +ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" + +DEFAULT_SOURCE_DIR="${PTO_SOURCE_DIR:-${REPO_ROOT}}" +SRC_ROOT="${PTOAS_OUT_DIR:-${DEFAULT_SOURCE_DIR}/build/output}" +BUILD_ROOT="${DEFAULT_SOURCE_DIR}/build/output_asm" +LOG_DIR="${DEFAULT_SOURCE_DIR}/build/output_log" + +COMPILER="${COMPILER:-}" +PTO_ISA_PATH="${PTO_ISA_PATH:-${PTO_ISA_ROOT:-}}" +EXTRA_ARGS=() + +JOBS="${JOBS:-$(nproc)}" +AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}" +MEM_BASE_DEFINE="${MEM_BASE_DEFINE:-REGISTER_BASE}" +ENABLE_DEFAULT_ARGS=1 + +print_usage() { + cat <<'EOF' +批量编译 output 目录下所有 .cpp 文件为 .S,并汇总结果。 + +用法: + scripts/batch_compile_output_cpp.sh \ + [--compiler <编译器路径>] \ + [--pto-isa-path ] \ + [--compile-arg <单个参数>]... \ + [--jobs <并行数>] \ + [--aicore-arch ] \ + [--mem-base-define <宏名>] \ + [--src-root <源码目录>] \ + [--build-root <产物目录>] \ + [--log-dir <日志目录>] + +参数说明: + --compiler, -c 编译器路径。默认优先使用环境变量 COMPILER, + 其次使用 PATH 中的 bisheng 或 + ${ASCEND_HOME_PATH}/bin/bisheng + --pto-isa-path, -p PTO-ISA 根路径。默认优先使用环境变量 + PTO_ISA_PATH / PTO_ISA_ROOT。脚本会自动检测 include 目录: + 1) /include + 2) /tests/common (存在时自动加入) + 3) + --compile-arg 额外编译参数,可重复传入 + --jobs, -j 并行编译任务数,默认: nproc + --aicore-arch 默认: dav-c220-vec + --mem-base-define 默认: MEMORY_BASE (可改为 REGISTER_BASE) + --no-default-args 不使用脚本内置默认参数(仅使用 --compile-arg) + --src-root 要扫描的 .cpp 根目录,默认: $PTOAS_OUT_DIR + 或 $PTO_SOURCE_DIR/build/output + --build-root .S 产物目录,默认: $PTO_SOURCE_DIR/build/output_asm + --log-dir 编译日志目录,默认: /logs + --help, -h 显示帮助 + +推荐先执行: + source scripts/ptoas_env.sh + +默认编译参数来源: + 由 test/npu_validation/scripts/generate_testcase.py 中 + CMAKE_CCE_COMPILE_OPTIONS + target_compile_options() 提取: + -xcce -fenable-matrix --cce-aicore-enable-tl -fPIC -Xhost-start -Xhost-end + -mllvm -cce-aicore-function-stack-size=0x8000 + -mllvm -cce-aicore-record-overflow=true + -mllvm -cce-aicore-addr-transform + -mllvm -cce-aicore-dcci-insert-for-scalar=false + --cce-aicore-arch= -D -std=c++17 +EOF +} + +die() { + echo "[ERROR] $*" >&2 + exit 1 +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --compiler | -c) + [[ $# -ge 2 ]] || die "--compiler 缺少参数" + COMPILER="$2" + shift 2 + ;; + --pto-isa-path | -p) + [[ $# -ge 2 ]] || die "--pto-isa-path 缺少参数" + PTO_ISA_PATH="$2" + shift 2 + ;; + --compile-arg) + [[ $# -ge 2 ]] || die "--compile-arg 缺少参数" + EXTRA_ARGS+=("$2") + shift 2 + ;; + --jobs | -j) + [[ $# -ge 2 ]] || die "--jobs 缺少参数" + JOBS="$2" + shift 2 + ;; + --aicore-arch) + [[ $# -ge 2 ]] || die "--aicore-arch 缺少参数" + AICORE_ARCH="$2" + shift 2 + ;; + --mem-base-define) + [[ $# -ge 2 ]] || die "--mem-base-define 缺少参数" + MEM_BASE_DEFINE="$2" + shift 2 + ;; + --no-default-args) + ENABLE_DEFAULT_ARGS=0 + shift + ;; + --src-root) + [[ $# -ge 2 ]] || die "--src-root 缺少参数" + SRC_ROOT="$2" + shift 2 + ;; + --build-root) + [[ $# -ge 2 ]] || die "--build-root 缺少参数" + BUILD_ROOT="$2" + shift 2 + ;; + --log-dir) + [[ $# -ge 2 ]] || die "--log-dir 缺少参数" + LOG_DIR="$2" + shift 2 + ;; + --help | -h) + print_usage + exit 0 + ;; + *) + die "未知参数: $1 (使用 --help 查看用法)" + ;; + esac +done + +if [[ -z "${COMPILER}" ]]; then + if command -v bisheng >/dev/null 2>&1; then + COMPILER="$(command -v bisheng)" + elif [[ -n "${ASCEND_HOME_PATH:-}" && -x "${ASCEND_HOME_PATH}/bin/bisheng" ]]; then + COMPILER="${ASCEND_HOME_PATH}/bin/bisheng" + fi +elif [[ "${COMPILER}" != */* ]] && command -v "${COMPILER}" >/dev/null 2>&1; then + COMPILER="$(command -v "${COMPILER}")" +fi + +[[ -n "${COMPILER}" ]] || die "未找到编译器,请先 source scripts/ptoas_env.sh,或通过 --compiler/COMPILER 指定 bisheng 路径" +[[ -n "${PTO_ISA_PATH}" ]] || die "未找到 PTO-ISA 路径,请通过 --pto-isa-path、PTO_ISA_PATH 或 PTO_ISA_ROOT 指定" +[[ -x "${COMPILER}" ]] || die "编译器不可执行: ${COMPILER}" +[[ -d "${SRC_ROOT}" ]] || die "源码目录不存在: ${SRC_ROOT}" +[[ -d "${PTO_ISA_PATH}" ]] || die "PTO-ISA 路径不存在: ${PTO_ISA_PATH}" +[[ "${JOBS}" =~ ^[1-9][0-9]*$ ]] || die "--jobs 必须为正整数" + +if [[ -z "${LOG_DIR}" ]]; then + LOG_DIR="${BUILD_ROOT}/logs" +fi + +mkdir -p "${BUILD_ROOT}" "${LOG_DIR}" || die "创建目录失败" + +INCLUDE_DIRS=() +if [[ -f "${PTO_ISA_PATH}/include/pto/pto-inst.hpp" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}/include") +fi +if [[ -d "${PTO_ISA_PATH}/tests/common" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}/tests/common") +fi +if [[ -f "${PTO_ISA_PATH}/pto/pto-inst.hpp" ]]; then + INCLUDE_DIRS+=("${PTO_ISA_PATH}") +fi +[[ ${#INCLUDE_DIRS[@]} -gt 0 ]] || die "未找到 pto/pto-inst.hpp,请检查 --pto-isa-path" + +if [[ -n "${ASCEND_HOME_PATH:-}" && -d "${ASCEND_HOME_PATH}/include" ]]; then + INCLUDE_DIRS+=("${ASCEND_HOME_PATH}/include") +fi +ASCEND_DRIVER_PATH="${ASCEND_DRIVER_PATH:-/usr/local/Ascend/driver}" +if [[ -d "${ASCEND_DRIVER_PATH}/kernel/inc" ]]; then + INCLUDE_DIRS+=("${ASCEND_DRIVER_PATH}/kernel/inc") +fi + +DEFAULT_ARGS=() +if [[ ${ENABLE_DEFAULT_ARGS} -eq 1 ]]; then + DEFAULT_ARGS=( + "-xcce" + "-fenable-matrix" + "--cce-aicore-enable-tl" + "--cce-aicore-only" + "-fPIC" + "-Xhost-start" + "-Xhost-end" + "-mllvm" "-cce-aicore-stack-size=0x8000" + "-mllvm" "-cce-aicore-function-stack-size=0x8000" + "-mllvm" "-cce-aicore-record-overflow=true" + "-mllvm" "-cce-aicore-addr-transform" + "-mllvm" "-cce-aicore-dcci-insert-for-scalar=false" + "--cce-aicore-arch=${AICORE_ARCH}" + "-D${MEM_BASE_DEFINE}" + "-std=c++17" + ) + if [[ "${AICORE_ARCH}" == dav-l310* || "${AICORE_ARCH}" == dav-l311* ]]; then + FILTERED_DEFAULT_ARGS=() + i=0 + while [[ ${i} -lt ${#DEFAULT_ARGS[@]} ]]; do + if [[ "${DEFAULT_ARGS[${i}]}" == "-mllvm" ]] && [[ $((i + 1)) -lt ${#DEFAULT_ARGS[@]} ]] && + [[ "${DEFAULT_ARGS[$((i + 1))]}" == "-cce-aicore-stack-size=0x8000" ]]; then + i=$((i + 2)) + continue + fi + FILTERED_DEFAULT_ARGS+=("${DEFAULT_ARGS[${i}]}") + i=$((i + 1)) + done + DEFAULT_ARGS=("${FILTERED_DEFAULT_ARGS[@]}") + fi +fi + +declare -a CPP_FILES=() +while IFS= read -r -d '' file; do + CPP_FILES+=("${file}") +done < <(find "${SRC_ROOT}" -type f -name "*.cpp" -print0 | sort -z) + +TOTAL_COUNT=${#CPP_FILES[@]} +[[ ${TOTAL_COUNT} -gt 0 ]] || die "未在 ${SRC_ROOT} 下找到 .cpp 文件" + +STATUS_FILE="$(mktemp "${BUILD_ROOT}/compile_status.XXXXXX")" || die "创建状态文件失败" +trap 'rm -f "${STATUS_FILE}"' EXIT + +record_compile_status() { + local status="$1" + local rel_path="$2" + printf '%s\t%s\n' "${status}" "${rel_path}" >>"${STATUS_FILE}" +} + +cleanup_work_dir() { + local work_dir="$1" + [[ -n "${work_dir}" ]] && rm -rf -- "${work_dir}" +} + +get_log_failure_reason() { + local log_path="$1" + local excerpt + + excerpt="$(grep -E -i 'error:|fatal:|undefined reference|undefined symbol|undeclared identifier|exception|traceback|failed' "${log_path}" | tail -n 5 || true)" + if [[ -z "${excerpt}" ]]; then + excerpt="$(tail -n 10 "${log_path}" 2>/dev/null || true)" + fi + printf '%s' "${excerpt}" +} + +find_generated_output() { + local work_dir="$1" + local src_stem="$2" + local candidate + + for candidate in \ + "${work_dir}/${src_stem}.o" \ + "${work_dir}/${src_stem}.S" \ + "${work_dir}/${src_stem}.s"; do + if [[ -f "${candidate}" ]]; then + printf '%s\n' "${candidate}" + return 0 + fi + done + + find "${work_dir}" -maxdepth 1 -type f \( -name "*.o" -o -name "*.S" -o -name "*.s" \) | head -n 1 +} + +write_rebuild_cmd() { + local cmd_path="$1" + local asm_path="$2" + local src_stem="$3" + shift 3 + local -a cmd=("$@") + local cmd_text="" + local arg + + for arg in "${cmd[@]}"; do + printf -v cmd_text '%s %q' "${cmd_text}" "${arg}" + done + cmd_text="${cmd_text# }" + + { + echo "#!/usr/bin/env bash" + echo + echo "set -euo pipefail" + echo + printf 'ASM_PATH=%q\n' "${asm_path}" + printf 'SRC_STEM=%q\n' "${src_stem}" + printf 'WORK_ROOT=%q\n' "${BUILD_ROOT}" + echo + echo 'WORK_DIR="$(mktemp -d "${WORK_ROOT}/tmp_rebuild.XXXXXX")"' + echo 'trap '\''rm -rf -- "${WORK_DIR}"'\'' EXIT' + echo + echo 'cd "${WORK_DIR}"' + echo "${cmd_text}" + echo + echo 'GENERATED_FILE=""' + echo 'for candidate in "${WORK_DIR}/${SRC_STEM}.o" "${WORK_DIR}/${SRC_STEM}.S" "${WORK_DIR}/${SRC_STEM}.s"; do' + echo ' if [[ -f "${candidate}" ]]; then' + echo ' GENERATED_FILE="${candidate}"' + echo ' break' + echo ' fi' + echo 'done' + echo + echo 'if [[ -z "${GENERATED_FILE}" ]]; then' + echo ' GENERATED_FILE="$(find "${WORK_DIR}" -maxdepth 1 -type f \( -name "*.o" -o -name "*.S" -o -name "*.s" \) | head -n 1)"' + echo 'fi' + echo + echo 'if [[ -z "${GENERATED_FILE}" || ! -f "${GENERATED_FILE}" ]]; then' + echo ' echo "[ERROR] 编译成功但未找到输出文件,期望类型: .o/.S/.s" >&2' + echo ' exit 1' + echo 'fi' + echo + echo 'mkdir -p "$(dirname -- "${ASM_PATH}")"' + echo 'mv -f -- "${GENERATED_FILE}" "${ASM_PATH}"' + printf 'echo "已更新: %s"\n' "${asm_path}" + } >"${cmd_path}" || return 1 + + chmod +x "${cmd_path}" +} + +compile_one() { + local src="$1" + local rel_path asm_path log_path cmd_path src_base src_stem work_dir generated_file + local -a cmd=() + + rel_path="${src#"${SRC_ROOT}/"}" + asm_path="${BUILD_ROOT}/${rel_path%.cpp}.S" + log_path="${LOG_DIR}/${rel_path%.cpp}.log" + cmd_path="$(dirname -- "${log_path}")/cmd.sh" + src_base="$(basename -- "${src}")" + src_stem="${src_base%.cpp}" + + mkdir -p "$(dirname -- "${asm_path}")" "$(dirname -- "${log_path}")" || { + record_compile_status "FAIL" "${rel_path}" + return 0 + } + + cmd=("${COMPILER}") + if [[ ${#DEFAULT_ARGS[@]} -gt 0 ]]; then + cmd+=("${DEFAULT_ARGS[@]}") + fi + if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then + cmd+=("${EXTRA_ARGS[@]}") + fi + local inc + for inc in "${INCLUDE_DIRS[@]}"; do + cmd+=("-I${inc}") + done + cmd+=("-c" "${src}") + + if ! write_rebuild_cmd "${cmd_path}" "${asm_path}" "${src_stem}" "${cmd[@]}"; then + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + echo "[BUILD] ${rel_path}" + work_dir="$(mktemp -d "${BUILD_ROOT}/tmp_compile.XXXXXX")" || { + record_compile_status "FAIL" "${rel_path}" + return 0 + } + + if ! (cd "${work_dir}" && "${cmd[@]}") >"${log_path}" 2>&1; then + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + generated_file="$(find_generated_output "${work_dir}" "${src_stem}")" + + if [[ -z "${generated_file}" || ! -f "${generated_file}" ]]; then + { + echo + echo "[ERROR] 编译成功但未找到输出文件,期望类型: .o/.S/.s" + echo "[ERROR] 临时目录: ${work_dir}" + } >>"${log_path}" + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + return 0 + fi + + if mv -f -- "${generated_file}" "${asm_path}"; then + cleanup_work_dir "${work_dir}" + record_compile_status "OK" "${rel_path}" + else + { + echo + echo "[ERROR] 输出重命名失败: ${generated_file} -> ${asm_path}" + } >>"${log_path}" + cleanup_work_dir "${work_dir}" + record_compile_status "FAIL" "${rel_path}" + fi +} + +START_TIME="$(date +%s)" + +echo "[INFO] 编译器: ${COMPILER}" +echo "[INFO] 源目录: ${SRC_ROOT}" +echo "[INFO] 产物目录(.S): ${BUILD_ROOT}" +echo "[INFO] 日志目录: ${LOG_DIR}" +echo "[INFO] PTO-ISA: ${PTO_ISA_PATH}" +echo "[INFO] 并行度: ${JOBS}" +echo "[INFO] include: ${INCLUDE_DIRS[*]}" +if [[ ${ENABLE_DEFAULT_ARGS} -eq 1 ]]; then + echo "[INFO] 默认参数(来自 generate_testcase.py): ${DEFAULT_ARGS[*]}" +else + echo "[INFO] 默认参数: 已禁用 (--no-default-args)" +fi +if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then + echo "[INFO] 额外参数: ${EXTRA_ARGS[*]}" +fi +echo "[INFO] 文件总数: ${TOTAL_COUNT}" +echo + +running_jobs=0 +for src in "${CPP_FILES[@]}"; do + compile_one "${src}" & + running_jobs=$((running_jobs + 1)) + if [[ ${running_jobs} -ge ${JOBS} ]]; then + wait -n + running_jobs=$((running_jobs - 1)) + fi +done + +wait + +SUCCESS_COUNT="$(awk -F'\t' '$1=="OK"{c++} END{print c+0}' "${STATUS_FILE}")" +FAIL_COUNT="$(awk -F'\t' '$1=="FAIL"{c++} END{print c+0}' "${STATUS_FILE}")" + +declare -a FAILED_FILES=() +while IFS= read -r failed; do + [[ -n "${failed}" ]] && FAILED_FILES+=("${failed}") +done < <(awk -F'\t' '$1=="FAIL"{print $2}' "${STATUS_FILE}") + +END_TIME="$(date +%s)" +ELAPSED="$((END_TIME - START_TIME))" + +echo +echo "========== 编译汇总 ==========" +echo "总文件数 : ${TOTAL_COUNT}" +echo "成功数 : ${SUCCESS_COUNT}" +echo "失败数 : ${FAIL_COUNT}" +echo "耗时(秒) : ${ELAPSED}" + +if [[ ${FAIL_COUNT} -gt 0 ]]; then + failure_reason="" + echo + echo "失败文件列表:" + for f in "${FAILED_FILES[@]}"; do + echo " - ${f} (log: ${LOG_DIR}/${f%.cpp}.log)" + failure_reason="$(get_log_failure_reason "${LOG_DIR}/${f%.cpp}.log")" + if [[ -n "${failure_reason}" ]]; then + while IFS= read -r line; do + [[ -n "${line}" ]] || continue + echo " reason: ${line}" + done <<<"${failure_reason}" + fi + done + exit 1 +fi + +echo "[INFO] 全部编译成功" +exit 0 diff --git a/scripts/compile_pto_to_vpto_llvm.sh b/scripts/compile_pto_to_vpto_llvm.sh new file mode 100755 index 000000000..4b6c066c9 --- /dev/null +++ b/scripts/compile_pto_to_vpto_llvm.sh @@ -0,0 +1,124 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PTO_FILE="${1:-}" +OUT_DIR_ARG="${2:-}" + +PTOAS_BIN="${PTOAS_BIN:-${ROOT_DIR}/build/tools/ptoas/ptoas}" +PTOAS_FLAGS="${PTOAS_FLAGS:---pto-arch a5}" +VPTO_FLAGS="${VPTO_FLAGS:---pto-backend=vpto --vpto-emit-hivm-llvm}" +AICORE_ARCH="${AICORE_ARCH:-dav-c310-vec}" +ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" +BISHENG_BIN="" +BISHENG_FLAGS="${BISHENG_FLAGS:-}" +LLVM_IR="" +DEVICE_OBJ="" + +log() { + echo "[$(date +'%F %T')] $*" +} + +die() { + echo "ERROR: $*" >&2 + exit 1 +} + +on_error() { + local exit_code="$1" + if [[ -n "${LLVM_IR}" && -f "${LLVM_IR}" ]]; then + echo "Retained LLVM IR: ${LLVM_IR}" >&2 + fi + if [[ -n "${DEVICE_OBJ}" ]]; then + echo "Expected device object: ${DEVICE_OBJ}" >&2 + fi + exit "${exit_code}" +} + +trap 'on_error $?' ERR + +usage() { + cat < [output_dir] + +Environment overrides: + PTOAS_BIN path to ptoas + PTOAS_FLAGS default: --pto-arch a5 + VPTO_FLAGS default: --pto-backend=vpto --vpto-emit-hivm-llvm + ASCEND_HOME_PATH default: \$HOME/cann + BISHENG_BIN + BISHENG_FLAGS extra flags passed to bisheng when compiling .ll to .o + AICORE_ARCH default: dav-c310-vec + +Example: + $(basename "$0") test/samples/PyPTOIRParser/paged_attention_example_kernel_online_update.pto +EOF +} + +[[ -n "${PTO_FILE}" ]] || { + usage + exit 1 +} + +[[ "${PTO_FILE}" == *.pto ]] || die "input must be a .pto file: ${PTO_FILE}" +[[ -f "${PTO_FILE}" ]] || die "missing input file: ${PTO_FILE}" + +set +u +source "${ROOT_DIR}/scripts/ptoas_env.sh" +set -u + +if [[ -n "${ASCEND_HOME_PATH}" && -f "${ASCEND_HOME_PATH}/set_env.sh" ]]; then + set +u + source "${ASCEND_HOME_PATH}/set_env.sh" >/dev/null 2>&1 + set -u +fi + +BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}" + +[[ -x "${PTOAS_BIN}" ]] || die "PTOAS_BIN is not executable: ${PTOAS_BIN}" +command -v "${BISHENG_BIN}" >/dev/null 2>&1 || die "bisheng not found: ${BISHENG_BIN}" + +pto_abs="$(cd "$(dirname "${PTO_FILE}")" && pwd)/$(basename "${PTO_FILE}")" +pto_base="$(basename "${PTO_FILE}" .pto)" + +if [[ -n "${OUT_DIR_ARG}" ]]; then + OUT_DIR="${OUT_DIR_ARG}" +else + OUT_DIR="${ROOT_DIR}/build/vpto_quick/${pto_base}" +fi + +mkdir -p "${OUT_DIR}" +OUT_DIR="$(cd "${OUT_DIR}" && pwd)" + +LLVM_IR="${OUT_DIR}/${pto_base}.ll" +DEVICE_OBJ="${OUT_DIR}/${pto_base}.o" + +log "step 1/2: lower PTO to VPTO LLVM IR" +"${PTOAS_BIN}" ${PTOAS_FLAGS} ${VPTO_FLAGS} \ + "${pto_abs}" \ + -o "${LLVM_IR}" + +log "step 2/2: compile LLVM IR to device object" +"${BISHENG_BIN}" \ + --target=hiipu64-hisilicon-cce \ + -march="${AICORE_ARCH}" \ + --cce-aicore-arch="${AICORE_ARCH}" \ + --cce-aicore-only \ + ${BISHENG_FLAGS} \ + -c -x ir "${LLVM_IR}" \ + -o "${DEVICE_OBJ}" + +log "done" +echo "LLVM IR: ${LLVM_IR}" +echo "Device object: ${DEVICE_OBJ}" diff --git a/scripts/ptoas_env.sh b/scripts/ptoas_env.sh new file mode 100644 index 000000000..15ed7466a --- /dev/null +++ b/scripts/ptoas_env.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# PTOAS runtime environment bootstrap. +# Usage: +# source scripts/ptoas_env.sh +# +# Optional overrides before sourcing: +# export WORKSPACE_DIR=/path/to/workspace +# export LLVM_BUILD_DIR=/path/to/llvm-project/build-shared +# export PTO_SOURCE_DIR=/path/to/PTOAS +# export PTO_INSTALL_DIR=/path/to/PTOAS/install +# export PTO_PYTHON_BIN=/path/to/python3 + +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + echo "This script must be sourced: source scripts/ptoas_env.sh" + exit 1 +fi + +_PTOAS_ENV_SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +_PTOAS_REPO_DIR="$(cd -- "${_PTOAS_ENV_SCRIPT_DIR}/.." && pwd)" + +# Default layout: +# / +# ├── PTOAS/ +# └── llvm-project/ +export PTO_SOURCE_DIR="${PTO_SOURCE_DIR:-${_PTOAS_REPO_DIR}}" +export WORKSPACE_DIR="${WORKSPACE_DIR:-$(cd -- "${PTO_SOURCE_DIR}/.." && pwd)}" +export LLVM_SOURCE_DIR="${LLVM_SOURCE_DIR:-${WORKSPACE_DIR}/llvm-project}" +export LLVM_BUILD_DIR="${LLVM_BUILD_DIR:-${LLVM_SOURCE_DIR}/build-shared}" +export PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" +export PTO_ISA_PATH="${PTO_ISA_PATH:-${WORKSPACE_DIR}/pto-isa}" +export ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-${HOME}/cann}" + +export MLIR_PYTHON_ROOT="${MLIR_PYTHON_ROOT:-${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core}" +export PTO_PYTHON_ROOT="${PTO_PYTHON_ROOT:-${PTO_INSTALL_DIR}}" +export PTO_PYTHON_BUILD_ROOT="${PTO_PYTHON_BUILD_ROOT:-${PTO_SOURCE_DIR}/build/python}" +export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) +export PTOAS_FLAGS="${PTOAS_FLAGS:-}" +export PTOAS_OUT_DIR=$PTO_SOURCE_DIR/build/output + +_ptoas_prepend_path() { + local var_name="$1" + local value="$2" + local current="${!var_name:-}" + if [[ -z "${value}" ]]; then + return 0 + fi + if [[ ! -e "${value}" ]]; then + return 0 + fi + if [[ ":${current}:" == *":${value}:"* ]]; then + return 0 + fi + if [[ -z "${current}" ]]; then + printf -v "${var_name}" '%s' "${value}" + else + printf -v "${var_name}" '%s:%s' "${value}" "${current}" + fi + export "${var_name}" +} + +_ptoas_prepend_path PYTHONPATH "${MLIR_PYTHON_ROOT}" +_ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_ROOT}" +_ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_BUILD_ROOT}" + +_ptoas_prepend_path LD_LIBRARY_PATH "${LLVM_BUILD_DIR}/lib" +_ptoas_prepend_path LD_LIBRARY_PATH "${PTO_INSTALL_DIR}/lib" +_ptoas_prepend_path LD_LIBRARY_PATH "${PTO_SOURCE_DIR}/build/lib" + +_ptoas_prepend_path PATH "${PTO_SOURCE_DIR}/build/tools/ptoas" + +if [[ -n "${PTO_PYTHON_BIN:-}" && -x "${PTO_PYTHON_BIN}" ]]; then + alias ptoas-python="${PTO_PYTHON_BIN}" +fi + +echo "[ptoas_env] PTO_SOURCE_DIR=${PTO_SOURCE_DIR}" +echo "[ptoas_env] LLVM_BUILD_DIR=${LLVM_BUILD_DIR}" +echo "[ptoas_env] PTO_INSTALL_DIR=${PTO_INSTALL_DIR}" +echo "[ptoas_env] PTO_ISA_PATH=${PTO_ISA_PATH}" +echo "[ptoas_env] ASCEND_HOME_PATH=${ASCEND_HOME_PATH}" +echo "[ptoas_env] PATH/PYTHONPATH/LD_LIBRARY_PATH updated" + +unset _PTOAS_ENV_SCRIPT_DIR +unset _PTOAS_REPO_DIR diff --git a/test/dsl/abs.py b/test/dsl/abs.py new file mode 100644 index 000000000..1917a5725 --- /dev/null +++ b/test/dsl/abs.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="abs_kernel_2d") +def abs_kernel_2d(inp: pto.ptr(pto.f32, "gm"), out: pto.ptr(pto.f32, "gm")): + ub_in = pto.castptr(0, pto.ptr(pto.f32, "ub")) + ub_out = pto.castptr(4096, pto.ptr(pto.f32, "ub")) + + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(inp, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + pto.wait_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + + with pto.vecscope(): + remaining: pto.i32 = 1024 + for offset in range(0, 1024, 64): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(ub_in, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, ub_out, offset, mask) + + pto.set_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + pto.wait_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_gm(ub_out, out, 0, 32, 128, 0, 128, 128) + pto.barrier("PIPE_ALL") + + return + + +if __name__ == "__main__": + print(abs_kernel_2d.mlir_text(), end="") diff --git a/test/dsl/expand_tile_op_tilelang_tadds.pto b/test/dsl/expand_tile_op_tilelang_tadds.pto new file mode 100644 index 000000000..55c206856 --- /dev/null +++ b/test/dsl/expand_tile_op_tilelang_tadds.pto @@ -0,0 +1,37 @@ +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test ExpandTileOp expansion for pto.tadds in the VPTO pipeline. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --emit-vpto %s -o - | FileCheck %s + +// CHECK: func.func @TADDS() +// CHECK: pto.vecscope +// CHECK: pto.addptr +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts +// CHECK-NOT: memref.cast +// CHECK-NOT: builtin.unrealized_conversion_cast + +module { + func.func @TADDS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/dsl/strict_vecscope.py b/test/dsl/strict_vecscope.py new file mode 100644 index 000000000..7badac2ae --- /dev/null +++ b/test/dsl/strict_vecscope.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="abs_strict_vecscope_kernel_2d") +def abs_strict_vecscope_kernel_2d( + inp: pto.ptr(pto.f32, "gm"), out: pto.ptr(pto.f32, "gm") +): + ub_in = pto.castptr(0, pto.ptr(pto.f32, "ub")) + ub_out = pto.castptr(4096, pto.ptr(pto.f32, "ub")) + + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(inp, ub_in, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + pto.wait_flag("PIPE_MTE2", "PIPE_V", "EVENT_ID0") + + with pto.strict_vecscope(ub_in, ub_out, 0, 1024, 64, 1024) as ( + src, + dst, + lb, + ub, + step, + remaining, + ): + for offset in range(lb, ub, step): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(src, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, dst, offset, mask) + + pto.set_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + pto.wait_flag("PIPE_V", "PIPE_MTE3", "EVENT_ID0") + + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_gm(ub_out, out, 0, 32, 128, 0, 128, 128) + pto.barrier("PIPE_ALL") + + return + + +if __name__ == "__main__": + print(abs_strict_vecscope_kernel_2d.mlir_text(), end="") diff --git a/test/dsl/template_abs.py b/test/dsl/template_abs.py new file mode 100644 index 000000000..dc674adee --- /dev/null +++ b/test/dsl/template_abs.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import mlir.dialects.pto as pto + + +@pto.vkernel(target="a5", name="template_abs_kernel") +def template_abs_kernel(src: pto.Tile, dst: pto.Tile): + total = src.shape[0] * src.shape[1] + step = 256 // src.ub_ptr.elem_bytes + + with pto.strict_vecscope(src.ub_ptr, dst.ub_ptr, 0, total, step, total) as ( + vin, + vout, + lb, + ub, + vec_step, + remaining, + ): + for offset in range(lb, ub, vec_step): + mask, remaining = pto.plt_b32(remaining) + vec_in = pto.vlds(vin, offset) + vec_out = pto.vabs(vec_in, mask) + pto.vsts(vec_out, vout, offset, mask) + + +template_abs_kernel_f32 = template_abs_kernel.jit( + src=pto.Tile( + ub_ptr=pto.ptr(pto.f32, "ub"), + shape=pto.const([32, 32]), + ), + dst=pto.Tile( + ub_ptr=pto.ptr(pto.f32, "ub"), + shape=pto.const([32, 32]), + ), +) + +template_abs_kernel_f16 = template_abs_kernel.jit( + src=pto.Tile( + ub_ptr=pto.ptr(pto.f16, "ub"), + shape=pto.const([32, 32]), + ), + dst=pto.Tile( + ub_ptr=pto.ptr(pto.f16, "ub"), + shape=pto.const([32, 32]), + ), +) + + +if __name__ == "__main__": + print(template_abs_kernel_f32.mlir_text(), end="") diff --git a/test/lit.cfg.py b/test/lit.cfg.py new file mode 100644 index 000000000..ab7a8e01c --- /dev/null +++ b/test/lit.cfg.py @@ -0,0 +1,93 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import lit.formats + +config.name = "PTOAS" +config.test_format = lit.formats.ShTest(execute_external=True) + +# Keep discovery focused on lit-style tests. +config.suffixes = [".mlir", ".pto"] +config.excludes = [ + "CMakeLists.txt", + "README.md", + "lit.cfg.py", + "resources", +] + +config.test_source_root = os.path.dirname(__file__) + + +def _resolve_build_root(): + env_build_dir = os.environ.get("PTOAS_BUILD_DIR") + if env_build_dir: + return os.path.abspath(env_build_dir) + + repo_root = os.path.abspath(os.path.join(config.test_source_root, "..")) + return os.path.join(repo_root, "build") + + +build_root = _resolve_build_root() +config.test_exec_root = os.path.join(build_root, "test") +os.makedirs(config.test_exec_root, exist_ok=True) + + +def _resolve_llvm_bin_dir(): + env_build_dir = os.environ.get("LLVM_BUILD_DIR") + candidates = [] + if env_build_dir: + candidates.append(os.path.join(os.path.abspath(env_build_dir), "bin")) + + repo_root = os.path.abspath(os.path.join(config.test_source_root, "..")) + candidates.append( + os.path.abspath( + os.path.join(repo_root, "..", "llvm-project", "build-shared", "bin") + ) + ) + + for candidate in candidates: + if os.path.isdir(candidate): + return candidate + return "" + + +def _resolve_ptoas_bin(): + env_bin = os.environ.get("PTOAS_BIN") + if env_bin: + return env_bin + + candidate = os.path.join(build_root, "tools", "ptoas", "ptoas") + if os.path.isfile(candidate) and os.access(candidate, os.X_OK): + return candidate + + return "ptoas" + + +def _prepend_path(path_var, entry): + if not entry: + return path_var + if not path_var: + return entry + return entry + os.pathsep + path_var + + +ptoas_bin = _resolve_ptoas_bin() +ptoas_dir = os.path.dirname(ptoas_bin) if os.path.isabs(ptoas_bin) else "" +llvm_bin_dir = _resolve_llvm_bin_dir() + +path_env = config.environment.get("PATH", os.environ.get("PATH", "")) +if llvm_bin_dir: + path_env = _prepend_path(path_env, llvm_bin_dir) +if ptoas_dir: + path_env = _prepend_path(path_env, ptoas_dir) +config.environment["PATH"] = path_env + +# Keep RUN lines using bare `ptoas` stable regardless of shell cwd. +if os.path.isabs(ptoas_bin): + config.substitutions.append(("ptoas", ptoas_bin)) diff --git a/test/lit/pto/expand_tile_op_trandom_tilelang.pto b/test/lit/pto/expand_tile_op_trandom_tilelang.pto new file mode 100644 index 000000000..bdf0bddd2 --- /dev/null +++ b/test/lit/pto/expand_tile_op_trandom_tilelang.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --emit-pto-ir %s -o - | FileCheck %s +// CHECK-LABEL: func.func @TRandom_test + +// Test template for pto.trandom operator +// +// NOTE: pto.trandom uses Philox random number generation. +// It requires special operations (vci, vmull, vbr) that may not be +// fully supported in current TileLang DSL without extension. +// +// This test file demonstrates expected IR structure. + +module { + func.func @TRandom_test() { + %dst = pto.alloc_tile + : !pto.tile_buf + %seed = pto.alloc_tile + : !pto.tile_buf + %counter = pto.alloc_tile + : !pto.tile_buf + + // Note: pto.trandom may not be defined in current PTO IR + // This shows expected usage pattern when available + + return + } +} diff --git a/test/lit/pto/ptr_int_cast.pto b/test/lit/pto/ptr_int_cast.pto index bfbff8da0..1bf2b3038 100644 --- a/test/lit/pto/ptr_int_cast.pto +++ b/test/lit/pto/ptr_int_cast.pto @@ -24,12 +24,12 @@ module { } // IR-LABEL: func.func @ptr_int_cast_kernel -// IR: pto.ptrtoint {{.*}} : memref -> i64 +// IR: pto.ptrtoint {{.*}} : memref> -> i64 // IR: arith.index_cast {{.*}} : index to i64 // IR: arith.constant 8 : i64 // IR: arith.muli {{.*}} : i64 // IR: arith.addi {{.*}} : i64 -// IR: pto.inttoptr {{.*}} : i64 -> memref +// IR: pto.inttoptr {{.*}} : i64 -> memref> // IR-NOT: !pto.ptr // CPP-LABEL: AICORE void ptr_int_cast_kernel @@ -40,8 +40,8 @@ module { // CPP: [[DST]][[[IDX]]] = [[VAL]]; // IR-LABEL: func.func @ptrtoint_addptr_multi_consumer -// IR: pto.load_scalar {{.*}} : memref -> ui64 -// IR: pto.ptrtoint {{.*}} : memref -> i64 +// IR: pto.load_scalar {{.*}} : memref> -> ui64 +// IR: pto.ptrtoint {{.*}} : memref> -> i64 // IR: arith.index_cast {{.*}} : index to i64 // IR: arith.constant 8 : i64 // IR: arith.muli {{.*}} : i64 diff --git a/test/lit/pto/subview_col_major_row_plus_one_stride_offset.pto b/test/lit/pto/subview_col_major_row_plus_one_stride_offset.pto index 5b810f6a6..b83bb4bdf 100644 --- a/test/lit/pto/subview_col_major_row_plus_one_stride_offset.pto +++ b/test/lit/pto/subview_col_major_row_plus_one_stride_offset.pto @@ -1,4 +1,4 @@ -// RUN: ptoas --mlir-print-ir-after=pto-view-to-memref %s 2>&1 | FileCheck %s +// RUN: ptoas --emit-pto-ir %s 2>&1 | FileCheck %s module { func.func @subview_col_major_row_plus_one_stride_offset( @@ -21,6 +21,6 @@ module { } // ColMajor + RowPlusOne: major stride should be 17 (not 16). -// CHECK-DAG: memref.alloc() : memref<16x16xf32, strided<[1, 17]>, #pto.address_space> -// CHECK-DAG: memref.subview {{.*}} to memref<8x8xf32, strided<[1, 17], offset: ?>, #pto.address_space> -// CHECK-DAG: Tile +// CHECK: pto.pointer_cast(%c0_i64) {{.*}} : memref<16x16xf32, strided<[1, 17]>, #pto.address_space> +// CHECK: compact=#pto.compact_mode +// CHECK: memref.subview {{.*}} to memref<8x8xf32, strided<[1, 17], offset: ?>, #pto.address_space> diff --git a/test/lit/pto/subview_dynamic_offset_static_valid_regression.pto b/test/lit/pto/subview_dynamic_offset_static_valid_regression.pto index aa77d25c1..9826fc294 100644 --- a/test/lit/pto/subview_dynamic_offset_static_valid_regression.pto +++ b/test/lit/pto/subview_dynamic_offset_static_valid_regression.pto @@ -1,4 +1,4 @@ -// RUN: ptoas --mlir-print-ir-after=pto-plan-memory %s 2>&1 | FileCheck %s +// RUN: ptoas --emit-pto-ir %s -o - | FileCheck %s module { func.func @subview_dynamic_offset_static_valid_regression( @@ -24,4 +24,5 @@ module { } // CHECK: func.func @subview_dynamic_offset_static_valid_regression -// CHECK: Tile +// CHECK: memref.subview %{{.*}}[0, %{{.*}}] [1, 64] [1, 1] +// CHECK: pto.bind_tile %{{.*}}, %c1, %c64 diff --git a/test/lit/pto/subview_row_plus_one_stride_offset.pto b/test/lit/pto/subview_row_plus_one_stride_offset.pto index a1f5bc294..c2587b3a6 100644 --- a/test/lit/pto/subview_row_plus_one_stride_offset.pto +++ b/test/lit/pto/subview_row_plus_one_stride_offset.pto @@ -1,4 +1,4 @@ -// RUN: ptoas %s 2>&1 --mlir-print-ir-after=pto-view-to-memref | FileCheck %s +// RUN: ptoas %s 2>&1 --emit-pto-ir | FileCheck %s module { func.func @subview_row_plus_one_stride_offset( @@ -21,6 +21,6 @@ module { } // RowPlusOne: major stride should be 17 (not 16). -// CHECK-DAG: memref.alloc() : memref<16x16xf32, strided<[17, 1]>, #pto.address_space> -// CHECK-DAG: memref.subview {{.*}} to memref<8x8xf32, strided<[17, 1], offset: ?>, #pto.address_space> -// CHECK-DAG: Tile +// CHECK: pto.pointer_cast(%c0_i64) {{.*}} : memref<16x16xf32, strided<[17, 1]>, #pto.address_space> +// CHECK: compact=#pto.compact_mode +// CHECK: memref.subview {{.*}} to memref<8x8xf32, strided<[17, 1], offset: ?>, #pto.address_space> diff --git a/test/lit/pto/tadds_validrow_mismatch.pto b/test/lit/pto/tadds_validrow_mismatch.pto new file mode 100644 index 000000000..304291461 --- /dev/null +++ b/test/lit/pto/tadds_validrow_mismatch.pto @@ -0,0 +1,60 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s +// Test: TADDS ValidRow consistency - more test cases + +module { + // Case 1: src ValidRow=1, dst ValidRow=4 - should fail + func.func @tadds_row_1_to_4_fail() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + return + } + // CHECK: expects src and dst to have the same valid_shape[0] + + // Case 2: src ValidRow=3, dst ValidRow=2 - should fail + func.func @tadds_row_3_to_2_fail() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + return + } + + // Case 3: ValidCol mismatch - should fail + func.func @tadds_col_mismatch_fail() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + return + } + + // Valid case 1: ValidRow=4, ValidCol=64 + func.func @tadds_row4_valid() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + return + } + + // Valid case 2: f16 type + func.func @tadds_f16_valid() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %scalar = arith.constant 1.0 : f16 + + pto.tadds ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/pto/tilebuf_invalid_fractal_size.pto b/test/lit/pto/tilebuf_invalid_fractal_size.pto new file mode 100644 index 000000000..dc666e802 --- /dev/null +++ b/test/lit/pto/tilebuf_invalid_fractal_size.pto @@ -0,0 +1,8 @@ +// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s + +// CHECK: error: unsupported s_fractal_size: 16, must be one of {32, 512, 1024} + +func.func @test_fractal_invalid() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb = pto.alloc_tile : !pto.tile_buf + return +} \ No newline at end of file diff --git a/test/lit/pto/tilebuf_invalid_valid_exceeds_shape.pto b/test/lit/pto/tilebuf_invalid_valid_exceeds_shape.pto new file mode 100644 index 000000000..5aeb585ca --- /dev/null +++ b/test/lit/pto/tilebuf_invalid_valid_exceeds_shape.pto @@ -0,0 +1,8 @@ +// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s + +// CHECK: error: tile_buf valid_row (20) exceeds row (16) + +func.func @test_valid_exceeds_rows() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb = pto.alloc_tile : !pto.tile_buf + return +} diff --git a/test/lit/pto/tilebuf_invalid_zero_shape.pto b/test/lit/pto/tilebuf_invalid_zero_shape.pto new file mode 100644 index 000000000..feac782e3 --- /dev/null +++ b/test/lit/pto/tilebuf_invalid_zero_shape.pto @@ -0,0 +1,8 @@ +// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s + +// CHECK: error: tile_buf rows/cols must be positive + +func.func @test_zero_rows() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb = pto.alloc_tile : !pto.tile_buf + return +} diff --git a/test/lit/pto/tilebuf_valid_positive.pto b/test/lit/pto/tilebuf_valid_positive.pto new file mode 100644 index 000000000..c94cae28a --- /dev/null +++ b/test/lit/pto/tilebuf_valid_positive.pto @@ -0,0 +1,39 @@ +// RUN: ptoas --pto-arch=a3 --emit-pto-ir %s -o - 2>&1 | FileCheck %s + +// CHECK: module attributes +// CHECK-LABEL: func.func @test_vec_unboxed_rowmajor +// CHECK-LABEL: func.func @test_vec_unboxed_colmajor +// CHECK-LABEL: func.func @test_left_boxed_fractal512 +// CHECK-LABEL: func.func @test_acc_fractal1024 +// CHECK-LABEL: func.func @test_fractal32 +// CHECK-LABEL: func.func @test_rows_one_exempt + +func.func @test_vec_unboxed_rowmajor() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb1 = pto.alloc_tile : !pto.tile_buf + return +} + +func.func @test_vec_unboxed_colmajor() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb2 = pto.alloc_tile : !pto.tile_buf + return +} + +func.func @test_left_boxed_fractal512() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb3 = pto.alloc_tile : !pto.tile_buf + return +} + +func.func @test_acc_fractal1024() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb4 = pto.alloc_tile : !pto.tile_buf + return +} + +func.func @test_fractal32() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb5 = pto.alloc_tile : !pto.tile_buf + return +} + +func.func @test_rows_one_exempt() attributes {pto.kernel_kind = #pto.kernel_kind} { + %tb6 = pto.alloc_tile : !pto.tile_buf + return +} diff --git a/test/lit/pto/tload_vec_layout_mismatch.pto b/test/lit/pto/tload_vec_layout_mismatch.pto new file mode 100644 index 000000000..259f921d3 --- /dev/null +++ b/test/lit/pto/tload_vec_layout_mismatch.pto @@ -0,0 +1,52 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s +// Test: TLOAD.VEC layout validation for A5 + +module { + // Case 1: DN layout (valid for tload) + func.func @tload_dn_valid(%ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %tv = pto.make_tensor_view %ptr, shape = [%c1, %c1, %c1, %c1, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x1x64xf32> + %pv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c64] : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%pv : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%dst : !pto.tile_buf) + return + } + + // Case 2: NZ layout (valid for tload) + func.func @tload_nz_valid(%ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + + %tv = pto.make_tensor_view %ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128, %c128, %c128, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xf32> + %pv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%pv : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%dst : !pto.tile_buf) + return + } + + // Case 3: Invalid layout combination (row_major + row_major) - should fail + func.func @tload_invalid_layout(%ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %tv = pto.make_tensor_view %ptr, shape = [%c1, %c1, %c1, %c1, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x1x64xf32> + %pv = pto.partition_view %tv, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c64] : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%pv : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%dst : !pto.tile_buf) + return + } + // CHECK: expects A5 tload vec dst layout to be ND, DN, or NZ +} \ No newline at end of file diff --git a/test/lit/pto/tshl_dtype_mismatch.pto b/test/lit/pto/tshl_dtype_mismatch.pto new file mode 100644 index 000000000..1c2e4e85b --- /dev/null +++ b/test/lit/pto/tshl_dtype_mismatch.pto @@ -0,0 +1,60 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s +// Test: TSHL dst data type consistency - more test cases + +module { + // Case 1: src0/src1=i16, dst=i32 - should fail + func.func @tshl_i16_to_i32_fail() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshl ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } + // CHECK: expects src0, src1, and dst to have the same element type + + // Case 2: src0/src1=i32, dst=f32 - should fail + func.func @tshl_i32_to_f32_fail() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshl ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } + + // Case 3: src0=i8, src1=i16 (mismatch) - should fail + func.func @tshl_src_mismatch_fail() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshl ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } + + // Valid case 1: all i8 + func.func @tshl_i8_valid() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshl ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } + + // Valid case 2: all i16 + func.func @tshl_i16_valid() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshl ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/pto/tshr_dtype_mismatch.pto b/test/lit/pto/tshr_dtype_mismatch.pto new file mode 100644 index 000000000..3215010b9 --- /dev/null +++ b/test/lit/pto/tshr_dtype_mismatch.pto @@ -0,0 +1,36 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s +// Test: TSHR dtype mismatch - dst element type must match src0/src1 + +// Case 1: src0/src1=i32, dst=i16 - should fail +func.func @tshr_dtype_mismatch_case1() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshr ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return +} +// CHECK: error: 'pto.tshr' op expects src0, src1, and dst to have the same element type + +// Case 2: src0/src1=i8, dst=i32 - should fail +func.func @tshr_dtype_mismatch_case2() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshr ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return +} + +// Valid case: all i32 - should pass +func.func @tshr_dtype_valid() { + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tshr ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return +} \ No newline at end of file diff --git a/test/lit/vpto/acc_store.pto b/test/lit/vpto/acc_store.pto new file mode 100644 index 000000000..8f5338dcb --- /dev/null +++ b/test/lit/vpto/acc_store.pto @@ -0,0 +1,39 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {"pto.target_arch" = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %c025_f32 = arith.constant 2.500000e-01 : f32 + + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, + unit_flag(check_only), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c0_i64, %c32_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, f32, f32, + i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_probe( +// ROUNDTRIP: pto.mte_l0c_l1 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, unit_flag(check_only), pre_quant(%{{.*}}, mode = qf322f16_pre_scalar), pre_relu(%{{.*}}, mode = scalar_relu), nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}), sat + +// EXPAND-LABEL: func.func @acc_store_probe( +// EXPAND: pto.set_quant_pre +// EXPAND: pto.set_relu_alpha +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_cbuf +// EXPAND: pto.set_ctrl diff --git a/test/lit/vpto/acc_store_gm.pto b/test/lit/vpto/acc_store_gm.pto new file mode 100644 index 000000000..16ad44b47 --- /dev/null +++ b/test/lit/vpto/acc_store_gm.pto @@ -0,0 +1,42 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {"pto.target_arch" = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_gm_probe(%src: !pto.ptr, + %dst: !pto.ptr, + %qfb: !pto.ptr, + %rfb: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c3_i64 = arith.constant 3 : i64 + %c5_i64 = arith.constant 5 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_gm %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, + %c3_i64, %c5_i64, + unit_flag(check_and_clear), + pre_quant(%qfb, mode = qf322f16_pre_vec), + pre_relu(%rfb, mode = vector_relu), + nz2dn(%c1_i64), + atomic(type = f32, op = max) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + !pto.ptr, !pto.ptr, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_gm_probe( +// ROUNDTRIP: pto.mte_l0c_gm %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, unit_flag(check_and_clear), pre_quant(%{{.*}}, mode = qf322f16_pre_vec), pre_relu(%{{.*}}, mode = vector_relu), nz2dn(%{{.*}}), atomic(type = f32, op = max) + +// EXPAND-LABEL: func.func @acc_store_gm_probe( +// EXPAND: %[[QFB:.*]] = pto.castptr %{{.*}} : !pto.ptr -> i64 +// EXPAND: %[[RFB:.*]] = pto.castptr %{{.*}} : !pto.ptr -> i64 +// EXPAND: pto.set_fpc +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_gm +// EXPAND: pto.set_ctrl diff --git a/test/lit/vpto/acc_store_gm_sat.pto b/test/lit/vpto/acc_store_gm_sat.pto new file mode 100644 index 000000000..dae7d153c --- /dev/null +++ b/test/lit/vpto/acc_store_gm_sat.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_gm_sat_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_gm %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_gm_sat_probe( +// ROUNDTRIP: pto.mte_l0c_gm %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}), sat + +// EXPAND-LABEL: func.func @acc_store_gm_sat_probe( +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_gm +// EXPAND: pto.set_ctrl diff --git a/test/lit/vpto/acc_store_gm_unit_flag.pto b/test/lit/vpto/acc_store_gm_unit_flag.pto new file mode 100644 index 000000000..ddbab8b9d --- /dev/null +++ b/test/lit/vpto/acc_store_gm_unit_flag.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_gm_unit_flag_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_gm %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + unit_flag(check_and_clear), + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_gm_unit_flag_probe( +// ROUNDTRIP: pto.mte_l0c_gm %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, unit_flag(check_and_clear), nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}) + +// EXPAND-LABEL: func.func @acc_store_gm_unit_flag_probe( +// EXPAND: %[[XT:.*]] = arith.constant 8808977924112 : i64 +// EXPAND: %[[XM:.*]] = arith.constant 68720525568 : i64 +// EXPAND: pto.copy_matrix_cc_to_gm %{{.*}}, %{{.*}}, %[[XM]], %[[XT]] diff --git a/test/lit/vpto/acc_store_gm_verify_invalid_duplicate_sat.pto b/test/lit/vpto/acc_store_gm_verify_invalid_duplicate_sat.pto new file mode 100644 index 000000000..b9eb3d7f0 --- /dev/null +++ b/test/lit/vpto/acc_store_gm_verify_invalid_duplicate_sat.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_gm_verify_invalid_duplicate_sat(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_gm %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd, + sat, + nosat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: duplicate sat/nosat clause diff --git a/test/lit/vpto/acc_store_pre_relu_clip.pto b/test/lit/vpto/acc_store_pre_relu_clip.pto new file mode 100644 index 000000000..047dd4a45 --- /dev/null +++ b/test/lit/vpto/acc_store_pre_relu_clip.pto @@ -0,0 +1,36 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_pre_relu_clip_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %c025_f32 = arith.constant 2.500000e-01 : f32 + %c8_f16 = arith.constant 8.000000e+00 : f16 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c0_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f32, f16, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_pre_relu_clip_probe( +// ROUNDTRIP: pto.mte_l0c_ub %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, dst_mode(%{{.*}}), pre_quant(%{{.*}}, mode = qf322f16_pre_scalar), pre_relu(%{{.*}}, mode = scalar_relu, clip = %{{.*}}), nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}) + +// EXPAND-LABEL: func.func @acc_store_pre_relu_clip_probe( +// EXPAND: pto.set_quant_pre +// EXPAND: pto.set_relu_alpha +// EXPAND: pto.set_fix_clip_relu +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_ub diff --git a/test/lit/vpto/acc_store_sat.pto b/test/lit/vpto/acc_store_sat.pto new file mode 100644 index 000000000..34b8a7368 --- /dev/null +++ b/test/lit/vpto/acc_store_sat.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_sat_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64 + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64), + sat(preserve_nan) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64 + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64), + nosat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_sat_probe( +// ROUNDTRIP: pto.mte_l0c_l1 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}), sat +// ROUNDTRIP: pto.mte_l0c_l1 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}), sat(preserve_nan) +// ROUNDTRIP: pto.mte_l0c_l1 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}), nosat + +// EXPAND-LABEL: func.func @acc_store_sat_probe( +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_cbuf +// EXPAND: pto.set_ctrl +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_cbuf +// EXPAND: pto.set_ctrl +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_cbuf +// EXPAND: pto.set_ctrl diff --git a/test/lit/vpto/acc_store_ub.pto b/test/lit/vpto/acc_store_ub.pto new file mode 100644 index 000000000..3f5b6603e --- /dev/null +++ b/test/lit/vpto/acc_store_ub.pto @@ -0,0 +1,33 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {"pto.target_arch" = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_ub_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_relu(mode = normal_relu), + nz2nz(%c1_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_ub_probe( +// ROUNDTRIP: pto.mte_l0c_ub %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, dst_mode(%{{.*}}), pre_relu(mode = normal_relu), nz2nz(%{{.*}}), sat + +// EXPAND-LABEL: func.func @acc_store_ub_probe( +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_ub +// EXPAND: pto.set_ctrl diff --git a/test/lit/vpto/acc_store_ub_clip.pto b/test/lit/vpto/acc_store_ub_clip.pto new file mode 100644 index 000000000..643024a5f --- /dev/null +++ b/test/lit/vpto/acc_store_ub_clip.pto @@ -0,0 +1,32 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_ub_clip_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %c8_f16 = arith.constant 8.000000e+00 : f16 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f16 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_ub_clip_probe( +// ROUNDTRIP: pto.mte_l0c_ub %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, dst_mode(%{{.*}}), pre_quant(%{{.*}}, mode = qf322f16_pre_scalar), pre_relu(mode = no_relu, clip = %{{.*}}), nz2nd + +// EXPAND-LABEL: func.func @acc_store_ub_clip_probe( +// EXPAND: pto.set_quant_pre +// EXPAND: pto.set_fix_clip_relu +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_ub diff --git a/test/lit/vpto/acc_store_ub_clip_ui8.pto b/test/lit/vpto/acc_store_ub_clip_ui8.pto new file mode 100644 index 000000000..00768c6bd --- /dev/null +++ b/test/lit/vpto/acc_store_ub_clip_ui8.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_ub_clip_ui8_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %c255_i16 = arith.constant 255 : i16 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322b8_pre_scalar), + pre_relu(mode = no_relu, clip = %c255_i16), + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, i16 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_ub_clip_ui8_probe( +// ROUNDTRIP: pto.mte_l0c_ub +// ROUNDTRIP-SAME: pre_quant(%{{.*}}, mode = qf322b8_pre_scalar) +// ROUNDTRIP-SAME: pre_relu(mode = no_relu, clip = %{{.*}}) + +// EXPAND-LABEL: func.func @acc_store_ub_clip_ui8_probe( +// EXPAND: pto.set_quant_pre +// EXPAND: pto.set_fix_clip_relu +// EXPAND: pto.copy_matrix_cc_to_ub diff --git a/test/lit/vpto/acc_store_ub_sat.pto b/test/lit/vpto/acc_store_ub_sat.pto new file mode 100644 index 000000000..eceefa0f6 --- /dev/null +++ b/test/lit/vpto/acc_store_ub_sat.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_ub_sat_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, dst_mode(%c0_i64), + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_ub_sat_probe( +// ROUNDTRIP: pto.mte_l0c_ub %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, dst_mode(%{{.*}}), nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}), sat + +// EXPAND-LABEL: func.func @acc_store_ub_sat_probe( +// EXPAND: pto.get_ctrl +// EXPAND: pto.set_ctrl +// EXPAND: pto.set_loop3_para +// EXPAND: pto.set_channel_para +// EXPAND: pto.copy_matrix_cc_to_ub +// EXPAND: pto.set_ctrl diff --git a/test/lit/vpto/acc_store_ub_unit_flag.pto b/test/lit/vpto/acc_store_ub_unit_flag.pto new file mode 100644 index 000000000..0dddc8d43 --- /dev/null +++ b/test/lit/vpto/acc_store_ub_unit_flag.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_ub_unit_flag_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, dst_mode(%c0_i64), + unit_flag(check_only), + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_ub_unit_flag_probe( +// ROUNDTRIP: pto.mte_l0c_ub %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, dst_mode(%{{.*}}), unit_flag(check_only), nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}) + +// EXPAND-LABEL: func.func @acc_store_ub_unit_flag_probe( +// EXPAND: %[[XT:.*]] = arith.constant 8804682956816 : i64 +// EXPAND: %[[XM:.*]] = arith.constant 68720525568 : i64 +// EXPAND: pto.copy_matrix_cc_to_ub %{{.*}}, %{{.*}}, %[[XM]], %[[XT]] diff --git a/test/lit/vpto/acc_store_ub_verify_invalid_atomic.pto b/test/lit/vpto/acc_store_ub_verify_invalid_atomic.pto new file mode 100644 index 000000000..e9958f6cb --- /dev/null +++ b/test/lit/vpto/acc_store_ub_verify_invalid_atomic.pto @@ -0,0 +1,19 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_ub_verify_invalid_atomic(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + atomic(type = f32, op = add) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: atomic is only supported for mte_l0c_gm diff --git a/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_mode.pto b/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_mode.pto new file mode 100644 index 000000000..d37b0906a --- /dev/null +++ b/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_mode.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_ub_verify_invalid_dual_dst_mode(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c3_i64 = arith.constant 3 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, dst_mode(bad_mode), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: expected dst_mode(%sub_blockid), dst_mode(split_m), or dst_mode(split_n) diff --git a/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_shape.pto b/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_shape.pto new file mode 100644 index 000000000..d02412ca0 --- /dev/null +++ b/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_shape.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_ub_verify_invalid_dual_dst_shape(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c15_i64 = arith.constant 15 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_ub %src, %dst, %c15_i64, %c16_i64, %c16_i64, %c16_i64, dst_mode(split_m), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + return + } +} + +// CHECK: split-M dual destination requires m to be even diff --git a/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_transform.pto b/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_transform.pto new file mode 100644 index 000000000..d6ceed9f3 --- /dev/null +++ b/test/lit/vpto/acc_store_ub_verify_invalid_dual_dst_transform.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_ub_verify_invalid_dual_dst_transform(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, dst_mode(split_m), + pre_relu(mode = normal_relu) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + return + } +} + +// CHECK: dual destination mode cannot be combined with pre_quant, pre_relu, clip, nz2dn, nz2nz, or loop3 diff --git a/test/lit/vpto/acc_store_ub_verify_invalid_duplicate_sat.pto b/test/lit/vpto/acc_store_ub_verify_invalid_duplicate_sat.pto new file mode 100644 index 000000000..e58b61f81 --- /dev/null +++ b/test/lit/vpto/acc_store_ub_verify_invalid_duplicate_sat.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_ub_verify_invalid_duplicate_sat(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, dst_mode(%c0_i64), + nz2nd, + sat, + nosat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: duplicate sat/nosat clause diff --git a/test/lit/vpto/acc_store_ub_verify_invalid_unit_flag_nz2dn.pto b/test/lit/vpto/acc_store_ub_verify_invalid_unit_flag_nz2dn.pto new file mode 100644 index 000000000..012d35b4a --- /dev/null +++ b/test/lit/vpto/acc_store_ub_verify_invalid_unit_flag_nz2dn.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_ub_verify_invalid_unit_flag_nz2dn(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, dst_mode(%c0_i64), + unit_flag(check_only), + nz2dn(%c2_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + i64 + + return + } +} + +// CHECK: unit_flag must be off when nz2dn loop0_src_stride is not 1 diff --git a/test/lit/vpto/acc_store_unit_flag.pto b/test/lit/vpto/acc_store_unit_flag.pto new file mode 100644 index 000000000..2446cbfcc --- /dev/null +++ b/test/lit/vpto/acc_store_unit_flag.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @acc_store_unit_flag_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + unit_flag(check_only), + nz2nd, + loop3(%c1_i64, %c0_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @acc_store_unit_flag_probe( +// ROUNDTRIP: pto.mte_l0c_l1 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, unit_flag(check_only), nz2nd, loop3(%{{.*}}, %{{.*}}, %{{.*}}) + +// EXPAND-LABEL: func.func @acc_store_unit_flag_probe( +// EXPAND: %[[XT:.*]] = arith.constant 8804682956816 : i64 +// EXPAND: %[[XM:.*]] = arith.constant 68720525568 : i64 +// EXPAND: pto.copy_matrix_cc_to_cbuf %{{.*}}, %{{.*}}, %[[XM]], %[[XT]] diff --git a/test/lit/vpto/acc_store_verify_invalid_atomic.pto b/test/lit/vpto/acc_store_verify_invalid_atomic.pto new file mode 100644 index 000000000..29ff65817 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_atomic.pto @@ -0,0 +1,17 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_atomic(%src: !pto.ptr, + %dst: !pto.ptr) { + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, + atomic(type = f32, op = add) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + return + } +} + +// CHECK: atomic is only supported for mte_l0c_gm diff --git a/test/lit/vpto/acc_store_verify_invalid_clip_payload_ui8.pto b/test/lit/vpto/acc_store_verify_invalid_clip_payload_ui8.pto new file mode 100644 index 000000000..2fe8df4a9 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_clip_payload_ui8.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_clip_payload_ui8(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c7_i8 = arith.constant 7 : i8 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_relu(mode = no_relu, clip = %c7_i8) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + i8 + + return + } +} + +// CHECK: clip for ui8 destination requires ui16/signless i16 payload diff --git a/test/lit/vpto/acc_store_verify_invalid_clip_type.pto b/test/lit/vpto/acc_store_verify_invalid_clip_type.pto new file mode 100644 index 000000000..71ac0083b --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_clip_type.pto @@ -0,0 +1,19 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_clip_type(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_relu(mode = no_relu, clip = %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + i64 + + return + } +} + +// CHECK: clip requires destination element type to be f16, ui8, or signed 4/8/16-bit integer diff --git a/test/lit/vpto/acc_store_verify_invalid_duplicate_sat.pto b/test/lit/vpto/acc_store_verify_invalid_duplicate_sat.pto new file mode 100644 index 000000000..6af2876df --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_duplicate_sat.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_duplicate_sat(%src: !pto.ptr, + %dst: !pto.ptr) { + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + nz2nd, + sat, + nosat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + return + } +} + +// CHECK: duplicate sat/nosat clause diff --git a/test/lit/vpto/acc_store_verify_invalid_nz2dn.pto b/test/lit/vpto/acc_store_verify_invalid_nz2dn.pto new file mode 100644 index 000000000..2a77f92f8 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_nz2dn.pto @@ -0,0 +1,18 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_nz2dn(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_gm %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, %c0_i64, %c0_i64, + nz2dn + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: nz2dn requires loop0_src_stride diff --git a/test/lit/vpto/acc_store_verify_invalid_nz2nz.pto b/test/lit/vpto/acc_store_verify_invalid_nz2nz.pto new file mode 100644 index 000000000..950b280ff --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_nz2nz.pto @@ -0,0 +1,17 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_nz2nz(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), nz2nz + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: nz2nz requires destination element type to be f32 diff --git a/test/lit/vpto/acc_store_verify_invalid_pre_quant_vec_payload_type.pto b/test/lit/vpto/acc_store_verify_invalid_pre_quant_vec_payload_type.pto new file mode 100644 index 000000000..6431bea65 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_pre_quant_vec_payload_type.pto @@ -0,0 +1,21 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_pre_quant_vec_payload_type( + %src: !pto.ptr, %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %payload = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, + dst_mode(%c0_i64), + pre_quant(%payload, mode = qf322f16_pre_vec) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr + + return + } +} + +// CHECK: vector pre_quant mode requires scaling pointer element type to be f16, bf16, or f32 diff --git a/test/lit/vpto/acc_store_verify_invalid_pwl.pto b/test/lit/vpto/acc_store_verify_invalid_pwl.pto new file mode 100644 index 000000000..24fdeced5 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_pwl.pto @@ -0,0 +1,17 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_pwl(%src: !pto.ptr, + %dst: !pto.ptr) { + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, + pre_relu(mode = pwl) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + return + } +} + +// CHECK: pwl is not supported for target_profile mte_l0c_l1 diff --git a/test/lit/vpto/acc_store_verify_invalid_req8_vec_f32_to_f16.pto b/test/lit/vpto/acc_store_verify_invalid_req8_vec_f32_to_f16.pto new file mode 100644 index 000000000..9a6036450 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_req8_vec_f32_to_f16.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_req8_vec_f32_to_f16( + %src: !pto.ptr, %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %fp = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_quant(%fp, mode = req8_vec), + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr + + return + } +} + +// CHECK: pre_quant mode req8_vec is incompatible with source element type 'f32' and destination element type 'f16' diff --git a/test/lit/vpto/acc_store_verify_invalid_scalar_relu_payload.pto b/test/lit/vpto/acc_store_verify_invalid_scalar_relu_payload.pto new file mode 100644 index 000000000..5d48979a9 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_scalar_relu_payload.pto @@ -0,0 +1,18 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_scalar_relu_payload(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, dst_mode(%c0_i64), + pre_relu(mode = scalar_relu) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: scalar_relu requires payload diff --git a/test/lit/vpto/acc_store_verify_invalid_unit_flag_nz2dn.pto b/test/lit/vpto/acc_store_verify_invalid_unit_flag_nz2dn.pto new file mode 100644 index 000000000..1e19abb92 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_unit_flag_nz2dn.pto @@ -0,0 +1,21 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_unit_flag_nz2dn(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + + pto.mte_l0c_gm %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, %c0_i64, %c0_i64, + unit_flag(check_only), + nz2dn(%c2_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + i64 + + return + } +} + +// CHECK: unit_flag must be off when nz2dn loop0_src_stride is not 1 diff --git a/test/lit/vpto/acc_store_verify_invalid_unit_flag_nz2dn_l1.pto b/test/lit/vpto/acc_store_verify_invalid_unit_flag_nz2dn_l1.pto new file mode 100644 index 000000000..68c26d052 --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_unit_flag_nz2dn_l1.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_unit_flag_nz2dn_l1(%src: !pto.ptr, + %dst: !pto.ptr) { + %c2_i64 = arith.constant 2 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l0c_l1 %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + unit_flag(check_only), + nz2dn(%c2_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + return + } +} + +// CHECK: unit_flag must be off when nz2dn loop0_src_stride is not 1 diff --git a/test/lit/vpto/acc_store_verify_invalid_vector_relu_payload_type.pto b/test/lit/vpto/acc_store_verify_invalid_vector_relu_payload_type.pto new file mode 100644 index 000000000..51da2c2ac --- /dev/null +++ b/test/lit/vpto/acc_store_verify_invalid_vector_relu_payload_type.pto @@ -0,0 +1,21 @@ +// RUN: not ptoas --pto-arch=a5 %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @acc_store_verify_invalid_vector_relu_payload_type( + %src: !pto.ptr, %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %payload = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_l0c_ub %src, %dst, %c16_i64, %c16_i64, %c16_i64, %c32_i64, + dst_mode(%c0_i64), + pre_relu(%payload, mode = vector_relu) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr + + return + } +} + +// CHECK: vector_relu requires scaling pointer element type to be f16, bf16, or f32 diff --git a/test/lit/vpto/auto_vecscope_infer_boundary.pto b/test/lit/vpto/auto_vecscope_infer_boundary.pto new file mode 100644 index 000000000..e75e21916 --- /dev/null +++ b/test/lit/vpto/auto_vecscope_infer_boundary.pto @@ -0,0 +1,57 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @auto_vecscope_infer_boundary() { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + %mask0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %vec0 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %sum0 = pto.vadds %vec0, %cst, %mask0 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum0, %ub[%c0], %mask0 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + pto.barrier #pto.pipe + + %mask1, %tail = pto.plt_b32 %c16_i32 : i32 -> !pto.mask, i32 + %vec1 = pto.vlds %ub[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %sum1 = pto.vadds %vec1, %cst, %mask1 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum1, %ub[%c64], %mask1 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return + } + + func.func @auto_vecscope_keeps_membar_inside() { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %vec = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> + pto.mem_bar "VST_VLD" + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return + } + } +} + +// CHECK-LABEL: func.func @auto_vecscope_infer_boundary +// CHECK: pto.vecscope +// CHECK: pto.vsts +// CHECK: pto.barrier +// CHECK-NEXT: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts + +// CHECK-LABEL: func.func @auto_vecscope_keeps_membar_inside +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK-NEXT: pto.mem_bar "VST_VLD" +// CHECK-NEXT: pto.vadds +// CHECK: pto.vsts diff --git a/test/lit/vpto/auto_vecscope_infer_escape_error.pto b/test/lit/vpto/auto_vecscope_infer_escape_error.pto new file mode 100644 index 000000000..37fdc835c --- /dev/null +++ b/test/lit/vpto/auto_vecscope_infer_escape_error.pto @@ -0,0 +1,17 @@ +// RUN: ! ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @auto_vecscope_infer_escape_error() { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + %vec = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> + pto.barrier #pto.pipe + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + pto.vsts %vec, %ub[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return + } +} + +// CHECK: error: 'pto.vlds' op cannot infer resultless pto.vecscope because VPTO vector-scope data cannot have external users; escaping value type is '!pto.vreg<64xf32>' diff --git a/test/lit/vpto/auto_vecscope_infer_nested_control_flow.pto b/test/lit/vpto/auto_vecscope_infer_nested_control_flow.pto new file mode 100644 index 000000000..f118f44b0 --- /dev/null +++ b/test/lit/vpto/auto_vecscope_infer_nested_control_flow.pto @@ -0,0 +1,85 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @auto_vecscope_atomic_control_flow(%cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + %mask0, %tail0 = pto.plt_b32 %c16_i32 : i32 -> !pto.mask, i32 + scf.for %i = %c0 to %c2 step %c1 { + %vec0 = pto.vlds %ub[%i] : !pto.ptr -> !pto.vreg<64xf32> + %sum0 = pto.vadds %vec0, %cst, %mask0 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum0, %ub[%i], %mask0 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + scf.if %cond { + %vec1 = pto.vlds %ub[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %sum1 = pto.vadds %vec1, %cst, %mask0 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum1, %ub[%c64], %mask0 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + %vec2 = pto.vlds %ub[%c128] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec2, %ub[%c128], %mask0 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return + } + + func.func @auto_vecscope_recursive_control_flow_fallback(%cond: i1) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + scf.if %cond { + %mask0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %vec0 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec0, %ub[%c0], %mask0 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + pto.barrier #pto.pipe + + %mask1, %tail = pto.plt_b32 %c16_i32 : i32 -> !pto.mask, i32 + %vec1 = pto.vlds %ub[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %sum1 = pto.vadds %vec1, %cst, %mask1 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum1, %ub[%c64], %mask1 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } + } +} + +// CHECK-LABEL: func.func @auto_vecscope_atomic_control_flow +// CHECK: pto.vecscope { +// CHECK: pto.plt_b32 +// CHECK: scf.for +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts +// CHECK: scf.if +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts +// CHECK: pto.vlds +// CHECK: pto.vsts +// CHECK-NEXT: } + +// CHECK-LABEL: func.func @auto_vecscope_recursive_control_flow_fallback +// CHECK: scf.if +// CHECK-NEXT: pto.vecscope { +// CHECK: pto.pset_b32 +// CHECK: pto.vlds +// CHECK: pto.vsts +// CHECK-NEXT: } +// CHECK-NEXT: pto.barrier +// CHECK-NEXT: pto.vecscope { +// CHECK: pto.plt_b32 +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts +// CHECK-NEXT: } diff --git a/test/lit/vpto/auto_vecscope_infer_safe_scalar.pto b/test/lit/vpto/auto_vecscope_infer_safe_scalar.pto new file mode 100644 index 000000000..5341a502e --- /dev/null +++ b/test/lit/vpto/auto_vecscope_infer_safe_scalar.pto @@ -0,0 +1,28 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @auto_vecscope_infer_safe_scalar(%base_remaining: i32) { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + %remaining = arith.addi %base_remaining, %c1_i32 : i32 + %mask, %tail = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return + } + } +} + +// CHECK-LABEL: func.func @auto_vecscope_infer_safe_scalar +// CHECK: pto.vecscope +// CHECK: arith.addi +// CHECK-NEXT: pto.plt_b32 +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts diff --git a/test/lit/vpto/auto_vecscope_infer_simple.pto b/test/lit/vpto/auto_vecscope_infer_simple.pto new file mode 100644 index 000000000..7373149e9 --- /dev/null +++ b/test/lit/vpto/auto_vecscope_infer_simple.pto @@ -0,0 +1,28 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @auto_vecscope_infer_simple() { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + %store_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %active_mask, %remaining = pto.plt_b32 %c16_i32 : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %active_mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub[%c0], %store_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return + } + } +} + +// CHECK-LABEL: func.func @auto_vecscope_infer_simple +// CHECK: pto.vecscope +// CHECK: pto.pset_b32 +// CHECK: pto.plt_b32 +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts diff --git a/test/lit/vpto/auto_vecscope_preserve_existing.pto b/test/lit/vpto/auto_vecscope_preserve_existing.pto new file mode 100644 index 000000000..deb74f062 --- /dev/null +++ b/test/lit/vpto/auto_vecscope_preserve_existing.pto @@ -0,0 +1,64 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @auto_vecscope_preserve_existing() { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %c0_i64 = arith.constant 0 : i64 + %c16_i32 = arith.constant 16 : i32 + %c17_i32 = arith.constant 17 : i32 + %c18_i32 = arith.constant 18 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %ub = pto.castptr %c0_i64 : i64 -> !pto.ptr + + %mask0, %tail0 = pto.plt_b32 %c17_i32 : i32 -> !pto.mask, i32 + %vec0 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %sum0 = pto.vadds %vec0, %cst, %mask0 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum0, %ub[%c0], %mask0 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + pto.vecscope { + %mask1, %tail1 = pto.plt_b32 %c18_i32 : i32 -> !pto.mask, i32 + %vec1 = pto.vlds %ub[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %sum1 = pto.vadds %vec1, %cst, %mask1 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum1, %ub[%c64], %mask1 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + + pto.strict_vecscope(%ub, %c128, %cst) { + ^bb0(%scope_ub: !pto.ptr, %scope_offset: index, %scope_scalar: f32): + %mask2 = pto.pset_b32 "PAT_ALL" : !pto.mask + %vec2 = pto.vlds %scope_ub[%scope_offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum2 = pto.vadds %vec2, %scope_scalar, %mask2 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum2, %scope_ub[%scope_offset], %mask2 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } : (!pto.ptr, index, f32) -> () + + %mask3, %tail = pto.plt_b32 %c16_i32 : i32 -> !pto.mask, i32 + %vec3 = pto.vlds %ub[%c192] : !pto.ptr -> !pto.vreg<64xf32> + %sum3 = pto.vadds %vec3, %cst, %mask3 : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum3, %ub[%c192], %mask3 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return + } + } +} + +// CHECK-LABEL: func.func @auto_vecscope_preserve_existing +// CHECK: pto.vecscope { +// CHECK: pto.vsts +// CHECK-NEXT: } +// CHECK-NEXT: pto.vecscope { +// CHECK-NOT: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vsts +// CHECK-NEXT: } +// CHECK-NEXT: pto.strict_vecscope +// CHECK-NOT: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vsts +// CHECK: } : (!pto.ptr, index, f32) -> () +// CHECK-NEXT: pto.vecscope { +// CHECK: pto.plt_b32 +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts diff --git a/test/lit/vpto/bias_load_vpto_llvm.pto b/test/lit/vpto/bias_load_vpto_llvm.pto new file mode 100644 index 000000000..6c069dd60 --- /dev/null +++ b/test/lit/vpto/bias_load_vpto_llvm.pto @@ -0,0 +1,18 @@ +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @bias_load_f16_to_f32() attributes {pto.kernel} { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %c4 = arith.constant 4 : i64 + %src = pto.castptr %c0 : i64 -> !pto.ptr + %dst = pto.castptr %c0 : i64 -> !pto.ptr + + pto.mte_l1_bt %src, %dst, %c4 nburst(%c2, %c0, %c0) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + return + } +} + +// CHECK-LABEL: llvm.func @bias_load_f16_to_f32_mix_aic +// CHECK: llvm.call @llvm.hivm.MOV.L1.TO.BT.f16 diff --git a/test/lit/vpto/copy_cbuf_to_fbuf_vpto_llvm.pto b/test/lit/vpto/copy_cbuf_to_fbuf_vpto_llvm.pto new file mode 100644 index 000000000..a006e2520 --- /dev/null +++ b/test/lit/vpto/copy_cbuf_to_fbuf_vpto_llvm.pto @@ -0,0 +1,18 @@ +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @copy_cbuf_to_fbuf_basic() attributes {pto.kernel} { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %c4 = arith.constant 4 : i64 + %src = pto.castptr %c0 : i64 -> !pto.ptr + %dst = pto.castptr %c0 : i64 -> !pto.ptr + + pto.copy_cbuf_to_fbuf %src, %dst, %c2, %c4, %c0, %c0 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + return + } +} + +// CHECK-LABEL: llvm.func @copy_cbuf_to_fbuf_basic_mix_aic +// CHECK: llvm.call @llvm.hivm.MOV.L1.TO.FB.v220 diff --git a/test/lit/vpto/copy_matrix_cc_to_ub_vpto_llvm.pto b/test/lit/vpto/copy_matrix_cc_to_ub_vpto_llvm.pto new file mode 100644 index 000000000..0598145e5 --- /dev/null +++ b/test/lit/vpto/copy_matrix_cc_to_ub_vpto_llvm.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @copy_matrix_cc_to_ub_probe() attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %c68720525568_i64 = arith.constant 68720525568 : i64 + + %src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %dst = pto.castptr %c0_i64 : i64 -> !pto.ptr + pto.copy_matrix_cc_to_ub %src, %dst, %c68720525568_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + return + } +} + +// CHECK-LABEL: llvm.func @copy_matrix_cc_to_ub_probe_mix_aic +// CHECK: llvm.call @llvm.hivm.FIX.L0C.TO.UB.f32.EXT diff --git a/test/lit/vpto/ctrl_ops.pto b/test/lit/vpto/ctrl_ops.pto new file mode 100644 index 000000000..91f5f1233 --- /dev/null +++ b/test/lit/vpto/ctrl_ops.pto @@ -0,0 +1,18 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel() { + %ctrl = pto.get_ctrl : i64 + %bit0 = arith.constant 0 : i64 + %bit45 = arith.constant 45 : i64 + %ctrl0 = pto.sbitset0 %ctrl, %bit0 : i64, i64 -> i64 + %ctrl1 = pto.sbitset1 %ctrl0, %bit45 : i64, i64 -> i64 + pto.set_ctrl %ctrl1 : i64 + return + } +} + +// CHECK: pto.get_ctrl +// CHECK: pto.sbitset0 +// CHECK: pto.sbitset1 +// CHECK: pto.set_ctrl diff --git a/test/lit/vpto/cube/cube_bridge_load_buffer_like.pto b/test/lit/vpto/cube/cube_bridge_load_buffer_like.pto new file mode 100644 index 000000000..8537da6ae --- /dev/null +++ b/test/lit/vpto/cube/cube_bridge_load_buffer_like.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guard: buffer-like cube bridge loads must round-trip from memref-like +// authoring operands back to !pto.ptr forms before bridge expansion, and then +// lower to the expected VPTO cube load bridge ops. +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @cube_bridge_load_buffer_like( + %src_mat: memref>, + %dst_left: memref>, + %dst_right: memref>, + %dst_left_mx: memref>, + %dst_right_mx: memref>, + %dst_bias: memref>) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c4_i64 = arith.constant 4 : i64 + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l1_l0a %src_mat, %dst_left, %c16_i64, %c16_i64 + : memref>, memref>, i64, i64 + pto.mte_l1_l0b %src_mat, %dst_right, %c16_i64, %c16_i64 {transpose = true} + : memref>, memref>, i64, i64 + pto.mte_l1_l0a_mx %src_mat, %dst_left_mx, %c16_i64, %c16_i64 + : memref>, memref>, i64, i64 + pto.mte_l1_l0b_mx %src_mat, %dst_right_mx, %c16_i64, %c16_i64 + : memref>, memref>, i64, i64 + pto.mte_l1_bt %src_mat, %dst_bias, %c4_i64 nburst(%c2_i64, %c0_i64, %c0_i64) + : memref>, memref>, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @cube_bridge_load_buffer_like( +// ROUNDTRIP-SAME: %{{.*}}: !pto.ptr, %{{.*}}: !pto.ptr, %{{.*}}: !pto.ptr, %{{.*}}: !pto.ptr, %{{.*}}: !pto.ptr, %{{.*}}: !pto.ptr) attributes {pto.kernel} { +// ROUNDTRIP: pto.mte_l1_l0a %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} +// ROUNDTRIP: pto.mte_l1_l0b %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {transpose = true} +// ROUNDTRIP: pto.mte_l1_l0a_mx %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} +// ROUNDTRIP: pto.mte_l1_l0b_mx %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} +// ROUNDTRIP: pto.mte_l1_bt %{{.*}}, %{{.*}}, %{{.*}} nburst(%{{.*}}, %{{.*}}, %{{.*}}) + +// EXPAND-LABEL: func.func @cube_bridge_load_buffer_like( +// EXPAND: pto.load_cbuf_to_ca +// EXPAND: pto.load_cbuf_to_cb +// EXPAND: pto.load_cbuf_to_ca_mx +// EXPAND: pto.load_cbuf_to_cb_mx +// EXPAND: pto.copy_cbuf_to_bt diff --git a/test/lit/vpto/cube/cube_bridge_load_verify_invalid.pto b/test/lit/vpto/cube/cube_bridge_load_verify_invalid.pto new file mode 100644 index 000000000..f41854843 --- /dev/null +++ b/test/lit/vpto/cube/cube_bridge_load_verify_invalid.pto @@ -0,0 +1,18 @@ +// Guard: cube bridge loads must reject non-MAT sources at verification time +// before any VPTO bridge expansion runs. +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s -o - 2>&1 | FileCheck %s + +module attributes {"pto.target_arch" = "a5"} { + func.func @cube_bridge_load_verify_invalid( + %src_gm: memref>, + %dst_left: memref>) { + %c16_i64 = arith.constant 16 : i64 + + pto.mte_l1_l0a %src_gm, %dst_left, %c16_i64, %c16_i64 + : memref>, memref>, i64, i64 + + return + } +} + +// CHECK: error: 'pto.mte_l1_l0a' op requires MAT source diff --git a/test/lit/vpto/cube/expand_tile_op_tilelang_tmatmul.pto b/test/lit/vpto/cube/expand_tile_op_tilelang_tmatmul.pto new file mode 100644 index 000000000..6c5fe2e23 --- /dev/null +++ b/test/lit/vpto/cube/expand_tile_op_tilelang_tmatmul.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guard: TileLang cube tmatmul expansion on the dav-c310-cube VPTO path must +// inline the template body and lower the tile op to cube MAD instructions. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-tile-op-expand --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// CHECK-LABEL: func.func @TMATMUL +// CHECK-NOT: pto.tmatmul ins +// CHECK: pto.mad + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @TMATMUL() attributes {pto.kernel} { + %lhs = pto.alloc_tile + : !pto.tile_buf + %rhs = pto.alloc_tile + : !pto.tile_buf + %acc = pto.alloc_tile + : !pto.tile_buf + + pto.tmatmul ins(%lhs, %rhs : !pto.tile_buf, + !pto.tile_buf) + outs(%acc : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/cube/mad_mx_semantic_to_raw.pto b/test/lit/vpto/cube/mad_mx_semantic_to_raw.pto new file mode 100644 index 000000000..9c3314542 --- /dev/null +++ b/test/lit/vpto/cube/mad_mx_semantic_to_raw.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=SEMANTIC +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_mx_semantic_to_raw( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr, + %bias: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad_mx %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_only) disable_gemv sat n_dir + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_mx_acc %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_and_set) nosat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_mx_bias %lhs, %rhs, %dst, %bias, %c16, %c16, %c16 + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// SEMANTIC-LABEL: func.func @mad_mx_semantic_to_raw( +// SEMANTIC: pto.mad_mx {{.*}} unit_flag(check_only) disable_gemv sat n_dir +// SEMANTIC: pto.mad_mx_acc {{.*}} unit_flag(check_and_set) nosat +// SEMANTIC: pto.mad_mx_bias +// SEMANTIC-NOT: tf32_mode + +// EXPAND-LABEL: func.func @mad_mx_semantic_to_raw( +// EXPAND-DAG: %[[BIT45:.*]] = arith.constant 45 : i64 +// EXPAND-DAG: %[[BIT46:.*]] = arith.constant 46 : i64 +// EXPAND-DAG: %[[BIT47:.*]] = arith.constant 47 : i64 +// EXPAND-DAG: %[[BIT48:.*]] = arith.constant 48 : i64 +// EXPAND-DAG: %[[BIT51:.*]] = arith.constant 51 : i64 +// EXPAND: %[[CTRL0:.*]] = pto.get_ctrl : i64 +// EXPAND: pto.sbitset0 {{.*}}, %[[BIT45]] +// EXPAND: pto.sbitset0 {{.*}}, %[[BIT46]] +// EXPAND: pto.sbitset0 {{.*}}, %[[BIT47]] +// EXPAND: pto.sbitset0 {{.*}}, %[[BIT48]] +// EXPAND: pto.sbitset1 {{.*}}, %[[BIT51]] +// EXPAND: pto.set_ctrl +// EXPAND: pto.mad_mx_raw +// EXPAND: pto.set_ctrl %[[CTRL0]] : i64 +// EXPAND: %[[CTRL1:.*]] = pto.get_ctrl : i64 +// EXPAND: pto.sbitset1 {{.*}}, %[[BIT48]] +// EXPAND: pto.mad_mx_raw +// EXPAND: pto.set_ctrl %[[CTRL1]] : i64 +// EXPAND: %[[CTRL2:.*]] = pto.get_ctrl : i64 +// EXPAND: pto.mad_mx_bias_raw +// EXPAND: pto.set_ctrl %[[CTRL2]] : i64 diff --git a/test/lit/vpto/cube/mad_mx_semantic_vpto_llvm.pto b/test/lit/vpto/cube/mad_mx_semantic_vpto_llvm.pto new file mode 100644 index 000000000..f005c9b65 --- /dev/null +++ b/test/lit/vpto/cube/mad_mx_semantic_vpto_llvm.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_mx_semantic_vpto_llvm( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr, + %bias: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad_mx %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_only) disable_gemv sat n_dir + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_mx_acc %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_and_set) nosat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_mx_bias %lhs, %rhs, %dst, %bias, %c16, %c16, %c16 + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// CHECK-LABEL: llvm.func @mad_mx_semantic_vpto_llvm_mix_aic +// CHECK-COUNT-3: llvm.call @llvm.hivm.MMAD.MX.e4m3e4m3 diff --git a/test/lit/vpto/cube/mad_mx_tf32_verify_invalid.pto b/test/lit/vpto/cube/mad_mx_tf32_verify_invalid.pto new file mode 100644 index 000000000..b6cef25e7 --- /dev/null +++ b/test/lit/vpto/cube/mad_mx_tf32_verify_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_mx_tf32_invalid( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad_mx %lhs, %rhs, %dst, %c16, %c16, %c16 tf32_mode(round_even) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// CHECK: expected ':' diff --git a/test/lit/vpto/cube/mad_sat_verify_invalid.pto b/test/lit/vpto/cube/mad_sat_verify_invalid.pto new file mode 100644 index 000000000..f21ccd39d --- /dev/null +++ b/test/lit/vpto/cube/mad_sat_verify_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_integer_sat_invalid( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad %lhs, %rhs, %dst, %c16, %c16, %c16 sat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// CHECK: requires sat/nosat only for floating lhs/rhs/dst element types diff --git a/test/lit/vpto/cube/mad_semantic_to_raw.pto b/test/lit/vpto/cube/mad_semantic_to_raw.pto new file mode 100644 index 000000000..d7a02062b --- /dev/null +++ b/test/lit/vpto/cube/mad_semantic_to_raw.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=SEMANTIC +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_semantic_to_raw( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr, + %lhs_hif8: !pto.ptr, + %rhs_hif8: !pto.ptr, + %bias: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_only) disable_gemv sat tf32_mode(round_even) n_dir + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_acc %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_and_set) nosat tf32_mode(round_away) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_bias %lhs, %rhs, %dst, %bias, %c16, %c16, %c16 + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad %lhs_hif8, %rhs_hif8, %dst, %c16, %c16, %c16 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// SEMANTIC-LABEL: func.func @mad_semantic_to_raw( +// SEMANTIC: pto.mad {{.*}} unit_flag(check_only) disable_gemv sat tf32_mode(round_even) n_dir +// SEMANTIC: pto.mad_acc {{.*}} unit_flag(check_and_set) nosat tf32_mode(round_away) +// SEMANTIC: pto.mad_bias +// SEMANTIC: pto.mad {{.*}}!pto.ptr, !pto.ptr + +// EXPAND-LABEL: func.func @mad_semantic_to_raw( +// EXPAND-DAG: %[[BIT45:.*]] = arith.constant 45 : i64 +// EXPAND-DAG: %[[BIT46:.*]] = arith.constant 46 : i64 +// EXPAND-DAG: %[[BIT47:.*]] = arith.constant 47 : i64 +// EXPAND-DAG: %[[BIT48:.*]] = arith.constant 48 : i64 +// EXPAND-DAG: %[[BIT51:.*]] = arith.constant 51 : i64 +// EXPAND: %[[CTRL0:.*]] = pto.get_ctrl : i64 +// EXPAND: pto.sbitset0 {{.*}}, %[[BIT45]] +// EXPAND: pto.sbitset1 {{.*}}, %[[BIT46]] +// EXPAND: pto.sbitset0 {{.*}}, %[[BIT47]] +// EXPAND: pto.sbitset0 {{.*}}, %[[BIT48]] +// EXPAND: pto.sbitset1 {{.*}}, %[[BIT51]] +// EXPAND: pto.set_ctrl +// EXPAND: pto.mad_raw +// EXPAND: pto.set_ctrl %[[CTRL0]] : i64 +// EXPAND: %[[CTRL1:.*]] = pto.get_ctrl : i64 +// EXPAND: pto.sbitset1 {{.*}}, %[[BIT47]] +// EXPAND: pto.sbitset1 {{.*}}, %[[BIT48]] +// EXPAND: pto.mad_raw +// EXPAND: pto.set_ctrl %[[CTRL1]] : i64 +// EXPAND: %[[CTRL2:.*]] = pto.get_ctrl : i64 +// EXPAND: pto.mad_bias_raw +// EXPAND: pto.set_ctrl %[[CTRL2]] : i64 +// EXPAND: %[[CTRL3:.*]] = pto.get_ctrl : i64 +// EXPAND: pto.sbitset1 {{.*}}, %[[BIT45]] +// EXPAND: pto.mad_raw {{.*}}!pto.ptr, !pto.ptr +// EXPAND: pto.set_ctrl %[[CTRL3]] : i64 diff --git a/test/lit/vpto/cube/mad_semantic_verify_invalid.pto b/test/lit/vpto/cube/mad_semantic_verify_invalid.pto new file mode 100644 index 000000000..ed3ea5fa7 --- /dev/null +++ b/test/lit/vpto/cube/mad_semantic_verify_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_mixed_hif8_invalid( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad %lhs, %rhs, %dst, %c16, %c16, %c16 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// CHECK: requires lhs/rhs to both use hif8 or both use non-hif8 element types diff --git a/test/lit/vpto/cube/mad_semantic_vpto_llvm.pto b/test/lit/vpto/cube/mad_semantic_vpto_llvm.pto new file mode 100644 index 000000000..af5a6ce81 --- /dev/null +++ b/test/lit/vpto/cube/mad_semantic_vpto_llvm.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_semantic_vpto_llvm( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr, + %bias: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_only) disable_gemv sat tf32_mode(round_even) n_dir + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_acc %lhs, %rhs, %dst, %c16, %c16, %c16 unit_flag(check_and_set) nosat tf32_mode(round_away) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.mad_bias %lhs, %rhs, %dst, %bias, %c16, %c16, %c16 + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// CHECK-LABEL: llvm.func @mad_semantic_vpto_llvm_mix_aic +// CHECK-COUNT-3: llvm.call @llvm.hivm.MAD.f322f32.c310 diff --git a/test/lit/vpto/cube/mad_tf32_verify_invalid.pto b/test/lit/vpto/cube/mad_tf32_verify_invalid.pto new file mode 100644 index 000000000..fbf47a095 --- /dev/null +++ b/test/lit/vpto/cube/mad_tf32_verify_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_tf32_non_f32_invalid( + %lhs: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr) attributes {pto.kernel} { + %c16 = arith.constant 16 : i64 + pto.mad %lhs, %rhs, %dst, %c16, %c16, %c16 tf32_mode(round_even) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + return + } +} + +// CHECK: requires tf32_mode only for f32 lhs/rhs/dst element types diff --git a/test/lit/vpto/cube_load_frac.pto b/test/lit/vpto/cube_load_frac.pto new file mode 100644 index 000000000..7699f742e --- /dev/null +++ b/test/lit/vpto/cube_load_frac.pto @@ -0,0 +1,55 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s --check-prefix=EXPAND + +module attributes {"pto.target_arch" = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @cube_load_frac_roundtrip(%arg0: memref>, + %arg1: memref>, + %dst0: memref>, + %dst1: memref>) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c3_i64 = arith.constant 3 : i64 + %c4_i64 = arith.constant 4 : i64 + %c5_i64 = arith.constant 5 : i64 + %c16_i64 = arith.constant 16 : i64 + %c20_i64 = arith.constant 20 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cfalse = arith.constant 0 : i1 + %ctrue = arith.constant 1 : i1 + + pto.mte_gm_l1_frac %arg0, %dst0, nd2nz, + shape(%c16_i64, %c5_i64), + src_layout(%c32_i64, %c256_i64), + dst_group(%c2_i64, %c1_i64, %c4_i64, %c20_i64), + ctrl(%c0_i64, %cfalse) + : memref>, memref>, nd2nz, + shape i64, i64, src_layout(i64, i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.mte_gm_l1_frac %arg1, %dst1, dn2nz, + shape(%c16_i64, %c4_i64), + src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c4_i64, %c20_i64), + ctrl(%c0_i64, %ctrue) + : memref>, memref>, dn2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @cube_load_frac_roundtrip( +// ROUNDTRIP: pto.mte_gm_l1_frac %{{.*}}, %{{.*}}, nd2nz, shape(%{{.*}}, %{{.*}}), src_layout(%{{.*}}, %{{.*}}), dst_group(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}), ctrl(%{{.*}}, %{{.*}}) +// ROUNDTRIP: pto.mte_gm_l1_frac %{{.*}}, %{{.*}}, dn2nz, shape(%{{.*}}, %{{.*}}), src_layout(%{{.*}}), dst_group(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}), ctrl(%{{.*}}, %{{.*}}) + +// EXPAND-LABEL: func.func @cube_load_frac_roundtrip( +// EXPAND: pto.set_mte2_nz_para +// EXPAND: pto.copy_gm_to_cbuf_multi_nd2nz +// EXPAND: pto.set_mte2_nz_para +// EXPAND: pto.copy_gm_to_cbuf_multi_dn2nz diff --git a/test/lit/vpto/cube_store_dma_copy.pto b/test/lit/vpto/cube_store_dma_copy.pto new file mode 100644 index 000000000..ac6fbd48c --- /dev/null +++ b/test/lit/vpto/cube_store_dma_copy.pto @@ -0,0 +1,34 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {"pto.target_arch" = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @cube_store_dma_copy_probe(%src_mat: !pto.ptr, + %dst_ub: !pto.ptr, + %src_ub: !pto.ptr, + %dst_mat: !pto.ptr) { + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + + pto.mte_l1_ub %src_mat, %dst_ub, %c16_i64 + nburst(%c1_i64, %c32_i64, %c64_i64) + loop(%c1_i64, %c64_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, loop i64, i64, i64 + + pto.mte_ub_l1 %src_ub, %dst_mat, %c16_i64 + nburst(%c1_i64, %c32_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + return + } +} + +// ROUNDTRIP-LABEL: func.func @cube_store_dma_copy_probe( +// ROUNDTRIP: pto.mte_l1_ub {{.*}} nburst({{.*}}) loop({{.*}}) +// ROUNDTRIP: pto.mte_ub_l1 {{.*}} nburst({{.*}}) + +// EXPAND-LABEL: func.func @cube_store_dma_copy_probe( +// EXPAND: scf.for +// EXPAND: pto.copy_cbuf_to_ubuf +// EXPAND: pto.copy_ubuf_to_cbuf diff --git a/test/lit/vpto/expand_tile_op_tilelang.pto b/test/lit/vpto/expand_tile_op_tilelang.pto new file mode 100644 index 000000000..cef4e22bb --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tadd via the default TileLang Python DSL template +// lib/TileOps/tadd_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tadd should be lowered to vector-style VPTO IR. +// CHECK: func.func @TADD +// CHECK-NOT: pto.tadd ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tabs.pto b/test/lit/vpto/expand_tile_op_tilelang_tabs.pto new file mode 100644 index 000000000..40d19d663 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tabs.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tabs should be lowered to vector-style VPTO IR. +// CHECK: func.func @TABS +// CHECK-NOT: pto.tabs ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vabs +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TABS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tabs ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tadds.pto b/test/lit/vpto/expand_tile_op_tilelang_tadds.pto new file mode 100644 index 000000000..85e4d0359 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tadds.pto @@ -0,0 +1,38 @@ +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that pto.tadds can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TADDS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tadds should use vadds (vector add scalar). +// CHECK: func.func @TADDS +// CHECK-NOT: pto.tadds ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vadds +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TADDS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tadds ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/expand_tile_op_tilelang_tand.pto b/test/lit/vpto/expand_tile_op_tilelang_tand.pto new file mode 100644 index 000000000..d59c14600 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tand.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tand via the default TileLang Python DSL template +// lib/TileOps/tand_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TAND +// CHECK-NOT: pto.tand ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vand +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TAND() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tand ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tands.pto b/test/lit/vpto/expand_tile_op_tilelang_tands.pto new file mode 100644 index 000000000..f67018eac --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tands.pto @@ -0,0 +1,34 @@ +// Test that pto.tands can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TANDS has a scalar operand (i32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tands should use vbr + vand. +// CHECK: func.func @TANDS +// CHECK-NOT: pto.tands ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vand +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TANDS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0xFF : i32 + + pto.tands ins(%a, %scalar : !pto.tile_buf, + i32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tcmp.pto b/test/lit/vpto/expand_tile_op_tilelang_tcmp.pto new file mode 100644 index 000000000..4acfda692 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tcmp.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcmp via the default TileLang Python DSL template +// lib/TileOps/tcmp_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile op should be lowered to vector-style VPTO IR. + +// CHECK-LABEL: func.func @TCMP +// CHECK-NOT: pto.tcmp ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vcmp +// CHECK: pto.pbitcast +// CHECK: pto.pdintlv_b8 +// CHECK: pto.psts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCMP() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tcmps.pto b/test/lit/vpto/expand_tile_op_tilelang_tcmps.pto new file mode 100644 index 000000000..5a0c42aa2 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tcmps.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that pto.tcmps can be lowered to vector-style VPTO IR via TileLang template. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After TileLang template expansion, pto.tcmps should use vcmps + pdintlv_b8 + psts. +// CHECK: func.func @TCMPS +// CHECK-NOT: pto.tcmps ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vcmps +// CHECK: pto.pdintlv_b8 +// CHECK: pto.psts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCMPS() { + %a = pto.alloc_tile + : !pto.tile_buf + %mask_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0.0 : f32 + + pto.tcmps ins(%a, %scalar {cmpMode = #pto} : !pto.tile_buf, + f32) + outs(%mask_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tcvt.pto b/test/lit/vpto/expand_tile_op_tilelang_tcvt.pto new file mode 100644 index 000000000..1a88c6882 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tcvt.pto @@ -0,0 +1,348 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test the first regular tcvt paths through ExpandTileOp. +// Current scope is intentionally narrow: +// - f32 -> i32: regular vcvt with rnd + sat +// - f32 -> f16: cast32to16_2D_NoPostUpdate-style vcvt(part=EVEN) + PK_B32 store +// - i32 -> f32: regular vcvt with rnd only +// - f16 -> f32: cast16to32-style UNPK_B16 load + vcvt(part=EVEN) +// - round_mode must reach the template so different rmode values materialize +// different vcvt attrs +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// CHECK-LABEL: func.func @TCVT_f32_to_f16 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vsts {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, {{.*}}, !pto.mask + +// CHECK-LABEL: func.func @TCVT_f32_to_i32 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> + +// CHECK-LABEL: func.func @TCVT_f16_to_f32 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vlds {{.*}} {dist = "UNPK_B16"} : {{.*}} -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @TCVT_i32_to_f32 +// CHECK-NOT: pto.tcvt ins +// CHECK: pto.vecscope +// CHECK: pto.vcvt {{.*}} {rnd = "R"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + // 1. f32 -> f16 + func.func @TCVT_f32_to_f16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 2. f32 -> i32 + func.func @TCVT_f32_to_i32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 3. f16 -> f32 + func.func @TCVT_f16_to_f32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 4. i32 -> f32 + func.func @TCVT_i32_to_f32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 5. bf16 -> f32 + func.func @TCVT_bf16_to_f32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 6. i16 -> f32 + func.func @TCVT_i16_to_f32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 7. i16 -> i32 + func.func @TCVT_i16_to_i32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 8. i16 -> ui32 + func.func @TCVT_i16_to_ui32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 9. i32 -> i64 + func.func @TCVT_i32_to_i64() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 10. ui8 -> f16 + func.func @TCVT_ui8_to_f16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 11. ui8 -> ui16 + func.func @TCVT_ui8_to_ui16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 12. si8 -> f16 + func.func @TCVT_si8_to_f16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 13. si8 -> si16 + func.func @TCVT_si8_to_si16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 14. si8 -> i32 + func.func @TCVT_si8_to_i32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 15. f32 -> f32 (vtrc) + func.func @TCVT_f32_to_f32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 16. f16 -> i32 + func.func @TCVT_f16_to_i32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 17. i16 -> f16 + func.func @TCVT_i16_to_f16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 18. i64 -> f32 + func.func @TCVT_i64_to_f32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 19. i16 -> ui8 + func.func @TCVT_i16_to_ui8() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 20. i32 -> i16 + func.func @TCVT_i32_to_i16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 21. i32 -> ui16 + func.func @TCVT_i32_to_ui16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 22. i32 -> ui8 + func.func @TCVT_i32_to_ui8() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 23. ui32 -> i16 + func.func @TCVT_ui32_to_i16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 24. ui32 -> ui16 + func.func @TCVT_ui32_to_ui16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 25. ui32 -> ui8 + func.func @TCVT_ui32_to_ui8() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 26. i64 -> i32 + func.func @TCVT_i64_to_i32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 27. f32 -> bf16 + func.func @TCVT_f32_to_bf16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 28. f32 -> i64 + func.func @TCVT_f32_to_i64() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 29. f16 -> ui8 + func.func @TCVT_f16_to_ui8() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 30. bf16 -> i32 + func.func @TCVT_bf16_to_i32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 31. bf16 -> f16 + func.func @TCVT_bf16_to_f16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 32. f32 -> i16 + func.func @TCVT_f32_to_i16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 33. f16 -> i16 + func.func @TCVT_f16_to_i16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } + + // 34. f16 -> si8 + func.func @TCVT_f16_to_si8() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tcvt ins(%src: !pto.tile_buf) + outs(%dst: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tdiv.pto b/test/lit/vpto/expand_tile_op_tilelang_tdiv.pto new file mode 100644 index 000000000..2811e8316 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tdiv.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tdiv via the default TileLang Python DSL template +// lib/TileOps/tdiv_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tdiv should be lowered to vector-style VPTO IR. +// CHECK: func.func @TDIV +// CHECK-NOT: pto.tdiv ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TDIV() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tdivs.pto b/test/lit/vpto/expand_tile_op_tilelang_tdivs.pto new file mode 100644 index 000000000..35f1fe323 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tdivs.pto @@ -0,0 +1,64 @@ +// Test that pto.tdivs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TDIVS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s --check-prefix=CHECK-TILE-SCALAR --check-prefix=CHECK-SCALAR-TILE + +// tile / scalar form: +// CHECK-TILE-SCALAR-LABEL: func.func @TDIVS_TILE_SCALAR +// CHECK-TILE-SCALAR-NOT: pto.tdivs ins +// CHECK-TILE-SCALAR: pto.vecscope +// CHECK-TILE-SCALAR: pto.castptr +// CHECK-TILE-SCALAR: %[[MASK:.+]], %[[SCALAR_OUT:.+]] = pto.plt_b32 +// CHECK-TILE-SCALAR: %[[LD:.+]] = pto.vlds +// CHECK-TILE-SCALAR: %[[BR:.+]] = pto.vbr +// CHECK-TILE-SCALAR: %[[DIV:.+]] = pto.vdiv %[[LD]], %[[BR]], %[[MASK]] +// CHECK-TILE-SCALAR: pto.vsts %[[DIV]] + +// scalar / tile form: +// CHECK-SCALAR-TILE-LABEL: func.func @TDIVS_SCALAR_TILE +// CHECK-SCALAR-TILE-NOT: pto.tdivs ins +// CHECK-SCALAR-TILE: pto.vecscope +// CHECK-SCALAR-TILE: pto.castptr +// CHECK-SCALAR-TILE: %[[MASK:.+]], %[[SCALAR_OUT:.+]] = pto.plt_b32 +// CHECK-SCALAR-TILE: %[[LD:.+]] = pto.vlds +// CHECK-SCALAR-TILE: %[[BR:.+]] = pto.vbr +// CHECK-SCALAR-TILE: %[[DIV:.+]] = pto.vdiv %[[BR]], %[[LD]], %[[MASK]] +// CHECK-SCALAR-TILE: pto.vsts %[[DIV]] + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TDIVS_TILE_SCALAR() { + %a = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2.0 : f32 + + pto.tdivs ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%dst : !pto.tile_buf) + return + } + + func.func @TDIVS_SCALAR_TILE() { + %a = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2.0 : f32 + + pto.tdivs ins(%scalar, %a : f32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_texp.pto b/test/lit/vpto/expand_tile_op_tilelang_texp.pto new file mode 100644 index 000000000..6243cd09c --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_texp.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.texp should be lowered to vector-style VPTO IR. +// CHECK: func.func @TEXP +// CHECK-NOT: pto.texp ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vexp +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TEXP() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.texp ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_texpand.pto b/test/lit/vpto/expand_tile_op_tilelang_texpand.pto new file mode 100644 index 000000000..0915e1f2d --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_texpand.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.texpands via the default TileLang Python DSL template +// lib/TileOps/texpands_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.texpands should be lowered to vector-style VPTO IR. +// CHECK: func.func @TEXPANDS +// CHECK-NOT: pto.texpands ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vdup +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TEXPANDS() { + %scalar = arith.constant 1.0 : f32 + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tfillpad.pto b/test/lit/vpto/expand_tile_op_tilelang_tfillpad.pto new file mode 100644 index 000000000..7dab0b04c --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tfillpad.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfillpad via the default TileLang Python DSL template +// lib/TileOps/tfillpad_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tfillpad should be lowered to vector-style VPTO IR. +// CHECK: func.func @TFILLPAD +// CHECK-NOT: pto.tfillpad ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK-DAG: pto.vdup +// CHECK-DAG: pto.vlds +// CHECK-DAG: pto.vsts +// Note: vstus is not supported in TileLang DSL v1, so padding uses vsts instead + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TFILLPAD() { + // Source Tile: valid region 8x48, total capacity 16x64 + %src = pto.alloc_tile + : !pto.tile_buf + // Destination Tile: same size as source, valid region also 8x48 + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tfillpad_expand.pto b/test/lit/vpto/expand_tile_op_tilelang_tfillpad_expand.pto new file mode 100644 index 000000000..e97523156 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tfillpad_expand.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfillpad_expand via the default TileLang Python DSL template +// lib/TileOps/tfillpad_expand_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tfillpad_expand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TFILLPAD_EXPAND +// CHECK-NOT: pto.tfillpad_expand ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK-DAG: pto.vdup +// CHECK-DAG: pto.vlds +// CHECK-DAG: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TFILLPAD_EXPAND() { + // 源 Tile: 较小尺寸 8x32 + %src = pto.alloc_tile + : !pto.tile_buf + // 目标 Tile: 较大尺寸 16x64 + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tfillpad_inplace.pto b/test/lit/vpto/expand_tile_op_tilelang_tfillpad_inplace.pto new file mode 100644 index 000000000..d3e7b00e1 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tfillpad_inplace.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfillpad (inplace mode) via the default TileLang Python DSL template +// lib/TileOps/tfillpad_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tfillpad (inplace) should be lowered to vector-style VPTO IR. +// CHECK: func.func @TFILLPAD_INPLACE +// CHECK-NOT: pto.tfillpad ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK-DAG: pto.vdup +// CHECK-DAG: pto.vlds +// CHECK-DAG: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TFILLPAD_INPLACE() { + // 原地操作:src 和 dst 是同一个 Tile + // 有效区域 8x48,总容量 16x64 + %tile = pto.alloc_tile + : !pto.tile_buf + + // src 和 dst 相同,表示原地填充 padding + pto.tfillpad ins(%tile : !pto.tile_buf) + outs(%tile : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tfmod.pto b/test/lit/vpto/expand_tile_op_tilelang_tfmod.pto new file mode 100644 index 000000000..4cfd8b451 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tfmod.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tfmod via the default TileLang Python DSL template +// lib/TileOps/tfmod_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile op should be lowered to vector-style VPTO IR. + +// CHECK-LABEL: func.func @TFMOD +// CHECK-NOT: pto.tfmod ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vtrc +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TFMOD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tfmods.pto b/test/lit/vpto/expand_tile_op_tilelang_tfmods.pto new file mode 100644 index 000000000..99394a6e2 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tfmods.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that pto.tfmods can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TFMODS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tfmods should use vdiv + vtrc + vmuls + vsub. +// CHECK: func.func @TFMODS +// CHECK-NOT: pto.tfmods ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vtrc +// CHECK: pto.vmuls +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TFMODS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 3.0 : f32 + + pto.tfmods ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tlog.pto b/test/lit/vpto/expand_tile_op_tilelang_tlog.pto new file mode 100644 index 000000000..d2656a409 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tlog.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tlog should be lowered to vector-style VPTO IR. +// CHECK: func.func @TLOG +// CHECK-NOT: pto.tlog ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vln +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TLOG() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tlog ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tmax.pto b/test/lit/vpto/expand_tile_op_tilelang_tmax.pto new file mode 100644 index 000000000..4ec289f0e --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tmax.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tmax via the default TileLang Python DSL template +// lib/TileOps/tmax_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TMAX +// CHECK-NOT: pto.tmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TMAX() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tmaxs.pto b/test/lit/vpto/expand_tile_op_tilelang_tmaxs.pto new file mode 100644 index 000000000..6f908ca8a --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tmaxs.pto @@ -0,0 +1,33 @@ +// Test that pto.tmaxs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TMAXS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tmaxs should use vmaxs (vector max scalar). +// CHECK: func.func @TMAXS +// CHECK-NOT: pto.tmaxs ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmaxs +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TMAXS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0.0 : f32 + + pto.tmaxs ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tmin.pto b/test/lit/vpto/expand_tile_op_tilelang_tmin.pto new file mode 100644 index 000000000..621edc55f --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tmin.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tmin via the default TileLang Python DSL template +// lib/TileOps/tmin_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TMIN +// CHECK-NOT: pto.tmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TMIN() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tmins.pto b/test/lit/vpto/expand_tile_op_tilelang_tmins.pto new file mode 100644 index 000000000..f32b0588a --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tmins.pto @@ -0,0 +1,33 @@ +// Test that pto.tmins can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TMINS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tmins should use vmins (vector min scalar). +// CHECK: func.func @TMINS +// CHECK-NOT: pto.tmins ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmins +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TMINS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0.0 : f32 + + pto.tmins ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tmrgsort.pto b/test/lit/vpto/expand_tile_op_tilelang_tmrgsort.pto new file mode 100644 index 000000000..1d50ea4c7 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tmrgsort.pto @@ -0,0 +1,109 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tmrgsort format1 (single input list) via the TileLang Python DSL template +// lib/TileOps/tmrgsort_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tmrgsort (format1: single input list) should be lowered to vmrgsort4. + +// CHECK-LABEL: func.func @TMRGSORT_1LIST_F32 +// CHECK-NOT: pto.tmrgsort ins +// CHECK: pto.vmrgsort4 + +// CHECK-LABEL: func.func @TMRGSORT_1LIST_F16 +// CHECK-NOT: pto.tmrgsort ins +// CHECK: pto.vmrgsort4 + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + // Format 1: single input list (internal block sorting) with f32 + func.func @TMRGSORT_1LIST_F32() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %blockLen = arith.constant 64 : i32 + + pto.tmrgsort ins(%src, %blockLen : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + return + } + + // Format 1: single input list (internal block sorting) with f16 + func.func @TMRGSORT_1LIST_F16() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + %blockLen = arith.constant 64 : i32 + + pto.tmrgsort ins(%src, %blockLen : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + return + } + + // Format2: 2-list merge sort (f32_2list_b64_basic) + func.func @TMRGSORT_f32_2list_b64_basic(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i32 = arith.constant 0 : i32 + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + + %src0_tile = pto.alloc_tile : !pto.tile_buf + %src1_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/expand_tile_op_tilelang_tmul.pto b/test/lit/vpto/expand_tile_op_tilelang_tmul.pto new file mode 100644 index 000000000..de37d6039 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tmul.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tmul via the default TileLang Python DSL template +// lib/TileOps/tmul_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tmul should be lowered to vector-style VPTO IR. +// CHECK: func.func @TMUL +// CHECK-NOT: pto.tmul ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TMUL() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tmuls.pto b/test/lit/vpto/expand_tile_op_tilelang_tmuls.pto new file mode 100644 index 000000000..d99102737 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tmuls.pto @@ -0,0 +1,33 @@ +// Test that pto.tmuls can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TMULS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tmuls should use vmuls (vector multiply scalar). +// CHECK: func.func @TMULS +// CHECK-NOT: pto.tmuls ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmuls +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TMULS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2.0 : f32 + + pto.tmuls ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tneg.pto b/test/lit/vpto/expand_tile_op_tilelang_tneg.pto new file mode 100644 index 000000000..e932bb6af --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tneg.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tneg should be lowered to vector-style VPTO IR. +// CHECK: func.func @TNEG +// CHECK-NOT: pto.tneg ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vneg +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TNEG() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tneg ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tnot.pto b/test/lit/vpto/expand_tile_op_tilelang_tnot.pto new file mode 100644 index 000000000..6e15b0e39 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tnot.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tnot should be lowered to vector-style VPTO IR. +// CHECK: func.func @TNOT +// CHECK-NOT: pto.tnot ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vnot +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TNOT() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tnot ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tor.pto b/test/lit/vpto/expand_tile_op_tilelang_tor.pto new file mode 100644 index 000000000..d12f3e739 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tor.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tor via the default TileLang Python DSL template +// lib/TileOps/tor_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tor should be lowered to vector-style VPTO IR. +// CHECK: func.func @TOR +// CHECK-NOT: pto.tor ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vor +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TOR() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tor ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tors.pto b/test/lit/vpto/expand_tile_op_tilelang_tors.pto new file mode 100644 index 000000000..3653fb0d1 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tors.pto @@ -0,0 +1,34 @@ +// Test that pto.tors can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TORS has a scalar operand (i32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tors should use vbr + vor. +// CHECK: func.func @TORS +// CHECK-NOT: pto.tors ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vor +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TORS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0xFF : i32 + + pto.tors ins(%a, %scalar : !pto.tile_buf, + i32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tpartadd.pto b/test/lit/vpto/expand_tile_op_tilelang_tpartadd.pto new file mode 100644 index 000000000..8f6883c34 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tpartadd.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartAdd checks +// CHECK-LABEL: func.func @TPARTADD +// CHECK-NOT: pto.tpartadd ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TPARTADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/expand_tile_op_tilelang_tpartmax.pto b/test/lit/vpto/expand_tile_op_tilelang_tpartmax.pto new file mode 100644 index 000000000..522d5f3b2 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tpartmax.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartMax checks +// CHECK-LABEL: func.func @TPARTMAX +// CHECK-NOT: pto.tpartmax ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TPARTMAX() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/expand_tile_op_tilelang_tpartmin.pto b/test/lit/vpto/expand_tile_op_tilelang_tpartmin.pto new file mode 100644 index 000000000..7980d7cfe --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tpartmin.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartMin checks +// CHECK-LABEL: func.func @TPARTMIN +// CHECK-NOT: pto.tpartmin ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TPARTMIN() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/expand_tile_op_tilelang_tpartmul.pto b/test/lit/vpto/expand_tile_op_tilelang_tpartmul.pto new file mode 100644 index 000000000..53c140bae --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tpartmul.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tpartadd/tpartmul/tpartmax/tpartmin via TileLang Python DSL templates. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile ops should be lowered to vector-style VPTO IR. + +// TPartMul checks +// CHECK-LABEL: func.func @TPARTMUL +// CHECK-NOT: pto.tpartmul ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TPARTMUL() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/expand_tile_op_tilelang_trecip.pto b/test/lit/vpto/expand_tile_op_tilelang_trecip.pto new file mode 100644 index 000000000..8e082462a --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trecip.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trecip should be lowered to vector-style VPTO IR. +// CHECK: func.func @TRECIP +// CHECK-NOT: pto.trecip ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TRECIP() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.trecip ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trem.pto b/test/lit/vpto/expand_tile_op_tilelang_trem.pto new file mode 100644 index 000000000..2aeca7f55 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trem.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trem via the default TileLang Python DSL template +// lib/TileOps/trem_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// Tile op should be lowered to vector-style VPTO IR. + +// CHECK-LABEL: func.func @TREM +// CHECK-NOT: pto.trem ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vtrc +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vcmps +// CHECK: pto.vadd +// CHECK: pto.vsel +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TREM() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trems.pto b/test/lit/vpto/expand_tile_op_tilelang_trems.pto new file mode 100644 index 000000000..2e54f9585 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trems.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that pto.trems can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TREMS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.trems should use vbr + vdiv + vmuls + vsub. +// CHECK: func.func @TREMS +// CHECK-NOT: pto.trems ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vdiv +// CHECK: pto.vmuls +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TREMS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %tmp_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 3.0 : f32 + + pto.trems ins(%a, %scalar, %tmp_buf : !pto.tile_buf, + f32, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trowargmax.pto b/test/lit/vpto/expand_tile_op_tilelang_trowargmax.pto new file mode 100644 index 000000000..8d0de3bd5 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trowargmax.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWARGMAX +// CHECK-NOT: pto.trowargmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmax +// CHECK: pto.vdintlv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWARGMAX() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trowargmin.pto b/test/lit/vpto/expand_tile_op_tilelang_trowargmin.pto new file mode 100644 index 000000000..88556082d --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trowargmin.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWARGMIN +// CHECK-NOT: pto.trowargmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmin +// CHECK: pto.vdintlv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWARGMIN() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trowmax.pto b/test/lit/vpto/expand_tile_op_tilelang_trowmax.pto new file mode 100644 index 000000000..8be500392 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trowmax.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWMAX +// CHECK-NOT: pto.trowmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmax +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWMAX() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trowmin.pto b/test/lit/vpto/expand_tile_op_tilelang_trowmin.pto new file mode 100644 index 000000000..f9bcb767a --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trowmin.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWMIN +// CHECK-NOT: pto.trowmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcmin +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWMIN() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trowprod.pto b/test/lit/vpto/expand_tile_op_tilelang_trowprod.pto new file mode 100644 index 000000000..9fcb89112 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trowprod.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWPROD +// CHECK-NOT: pto.trowprod ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vintlv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWPROD() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trowsum.pto b/test/lit/vpto/expand_tile_op_tilelang_trowsum.pto new file mode 100644 index 000000000..b5b15f674 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trowsum.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s +// After the full tile-op-expand path on the VPTO backend, the original +// CHECK: func.func @TROWSUM +// CHECK-NOT: pto.trowsum ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vbr +// CHECK: pto.vlds +// CHECK: pto.vcadd +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWSUM() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_trsqrt.pto b/test/lit/vpto/expand_tile_op_tilelang_trsqrt.pto new file mode 100644 index 000000000..6564102bd --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_trsqrt.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trsqrt should be lowered to vector-style VPTO IR. +// CHECK: func.func @TRSQRT +// CHECK-NOT: pto.trsqrt ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsqrt +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TRSQRT() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.trsqrt ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tshl.pto b/test/lit/vpto/expand_tile_op_tilelang_tshl.pto new file mode 100644 index 000000000..a05fa95ec --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tshl.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tshl via the default TileLang Python DSL template +// lib/TileOps/tshl_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tshl should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSHL +// CHECK-NOT: pto.tshl ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vshl +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSHL() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tshl ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tshls.pto b/test/lit/vpto/expand_tile_op_tilelang_tshls.pto new file mode 100644 index 000000000..34e290ea7 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tshls.pto @@ -0,0 +1,33 @@ +// Test that pto.tshls can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TSHLS has a scalar operand (i16), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tshls should use vshls (vector shift left scalar). +// CHECK: func.func @TSHLS +// CHECK-NOT: pto.tshls ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vshls +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSHLS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2 : i16 + + pto.tshls ins(%a, %scalar : !pto.tile_buf, + i16) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tshr.pto b/test/lit/vpto/expand_tile_op_tilelang_tshr.pto new file mode 100644 index 000000000..6ce58add5 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tshr.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tshr via the default TileLang Python DSL template +// lib/TileOps/tshr_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tshr should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSHR +// CHECK-NOT: pto.tshr ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vshr +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSHR() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tshr ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tshrs.pto b/test/lit/vpto/expand_tile_op_tilelang_tshrs.pto new file mode 100644 index 000000000..512f543a8 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tshrs.pto @@ -0,0 +1,33 @@ +// Test that pto.tshrs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TSHRS has a scalar operand (i16), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tshrs should use vshrs (vector shift right scalar). +// CHECK: func.func @TSHRS +// CHECK-NOT: pto.tshrs ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vshrs +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSHRS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 2 : i16 + + pto.tshrs ins(%a, %scalar : !pto.tile_buf, + i16) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tsort32.pto b/test/lit/vpto/expand_tile_op_tilelang_tsort32.pto new file mode 100644 index 000000000..906b72fb4 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tsort32.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tsort32 via the TileLang Python DSL template +// lib/TileOps/tsort32_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsort32 should be lowered to vector-style VPTO IR. +// CHECK-LABEL: func.func @TSORT32_no_tmp +// CHECK-NOT: pto.tsort32 ins +// CHECK: pto.vbitsort + +// CHECK-LABEL: func.func @TSORT32_with_tmp +// CHECK-NOT: pto.tsort32 ins +// CHECK: pto.vbitsort + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + // Case 1: aligned cols (valid_cols % 32 == 0), no tmp needed + func.func @TSORT32_no_tmp() { + %src = pto.alloc_tile : !pto.tile_buf + %idx = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tsort32 ins(%src, %idx : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } + + // Case 2: unaligned cols (valid_cols % 32 != 0), tmp needed + func.func @TSORT32_with_tmp() { + %src = pto.alloc_tile : !pto.tile_buf + %idx = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tsort32 ins(%src, %idx, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/expand_tile_op_tilelang_tsqrt.pto b/test/lit/vpto/expand_tile_op_tilelang_tsqrt.pto new file mode 100644 index 000000000..667223752 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tsqrt.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsqrt should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSQRT +// CHECK-NOT: pto.tsqrt ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsqrt +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSQRT() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tsqrt ins(%a: !pto.tile_buf) + outs(%tile_buf: !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tsub.pto b/test/lit/vpto/expand_tile_op_tilelang_tsub.pto new file mode 100644 index 000000000..7f53712fe --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tsub.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tsub via the default TileLang Python DSL template +// lib/TileOps/tsub_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsub should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSUB +// CHECK-NOT: pto.tsub ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSUB() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_tsubs.pto b/test/lit/vpto/expand_tile_op_tilelang_tsubs.pto new file mode 100644 index 000000000..b7f8d079b --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_tsubs.pto @@ -0,0 +1,34 @@ +// Test that pto.tsubs can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TSUBS has a scalar operand (f32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.tsubs should use vsubs (vector subtract scalar). +// CHECK: func.func @TSUBS +// CHECK-NOT: pto.tsubs ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSUBS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 1.0 : f32 + + pto.tsubs ins(%a, %scalar : !pto.tile_buf, + f32) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_txor.pto b/test/lit/vpto/expand_tile_op_tilelang_txor.pto new file mode 100644 index 000000000..5441bf62c --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_txor.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms of conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.txor via the default TileLang Python DSL template +// lib/TileOps/txor_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.txor should be lowered to vector-style VPTO IR. +// CHECK: func.func @TXOR +// CHECK-NOT: pto.txor ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vxor +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TXOR() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.txor ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tilelang_txors.pto b/test/lit/vpto/expand_tile_op_tilelang_txors.pto new file mode 100644 index 000000000..2686dbeb2 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tilelang_txors.pto @@ -0,0 +1,39 @@ +// Test that pto.txors can be lowered to vector-style VPTO IR. +// +// IMPORTANT: Do NOT use --enable-tile-op-expand for ops with scalar operands. +// TXORS has a scalar operand (i32), so it uses the PTOToVPTO lowering path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// After PTOToVPTO lowering, pto.txors should use vbr + vxor. +// CHECK: func.func @TXORS +// CHECK-NOT: pto.txors ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vbr +// CHECK: pto.vxor +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TXORS() { + %a = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + %tmp_buf = pto.alloc_tile + : !pto.tile_buf + %scalar = arith.constant 0xFF : i32 + + pto.txors ins(%a, %scalar, %tmp_buf : !pto.tile_buf, + i32, + !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tlrelu_tilelang.pto b/test/lit/vpto/expand_tile_op_tlrelu_tilelang.pto new file mode 100644 index 000000000..c5b1b4dc9 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tlrelu_tilelang.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tlrelu via the default TileLang Python DSL template +// lib/TileOps/tlrelu_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tlrelu should be lowered to vector-style VPTO IR. +// CHECK: func.func @TLRelu_test +// CHECK-NOT: pto.tlrelu ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vlrelu +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TLRelu_test() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + %slope = arith.constant 0.1 : f32 + + pto.tlrelu ins(%src, %slope : !pto.tile_buf, + f32) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tprelu_tilelang.pto b/test/lit/vpto/expand_tile_op_tprelu_tilelang.pto new file mode 100644 index 000000000..91d3810f0 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tprelu_tilelang.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tprelu via the default TileLang Python DSL template +// lib/TileOps/tprelu_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tprelu should be lowered to vector-style VPTO IR. +// CHECK: func.func @TPRelu_test +// CHECK-NOT: pto.tprelu ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vprelu +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TPRelu_test() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tprelu ins(%src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_trelu_tilelang.pto b/test/lit/vpto/expand_tile_op_trelu_tilelang.pto new file mode 100644 index 000000000..c11c8b359 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_trelu_tilelang.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trelu via the default TileLang Python DSL template +// lib/TileOps/trelu_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trelu should be lowered to vector-style VPTO IR. +// CHECK: func.func @TRelu_test +// CHECK-NOT: pto.trelu ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vrelu +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TRelu_test() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tsel_tilelang.pto b/test/lit/vpto/expand_tile_op_tsel_tilelang.pto new file mode 100644 index 000000000..42e045bd2 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tsel_tilelang.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tsel via the default TileLang Python DSL template +// lib/TileOps/tsel_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsel should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSEL_test +// CHECK-NOT: pto.tsel ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.plds +// CHECK: pto.vlds +// CHECK: pto.vsel +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSEL_test() { + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_tsels_tilelang.pto b/test/lit/vpto/expand_tile_op_tsels_tilelang.pto new file mode 100644 index 000000000..71efabde1 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_tsels_tilelang.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tsels via the default TileLang Python DSL template +// lib/TileOps/tsels_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tsels should be lowered to vector-style VPTO IR. +// CHECK: func.func @TSELS_test +// CHECK-NOT: pto.tsels ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK-DAG: pto.vdup +// CHECK-DAG: pto.plds +// CHECK-DAG: pto.vlds +// CHECK-DAG: pto.vsel +// CHECK-DAG: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSELS_test() { + %mask = pto.alloc_tile + : !pto.tile_buf + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + %scalar = arith.constant 42 : i8 + + pto.tsels ins(%mask, %src, %tmp, %scalar : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + i8) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/fold_tile_buf_intrinsics.pto b/test/lit/vpto/fold_tile_buf_intrinsics.pto new file mode 100644 index 000000000..0d909ab6f --- /dev/null +++ b/test/lit/vpto/fold_tile_buf_intrinsics.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Unit test for FoldTileBufIntrinsics on the VPTO tile-op expansion path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=pto-fold-tile-buf-intrinsics %s -o /dev/null 2>&1 | FileCheck %s +// +// After FoldTileBufIntrinsics: +// - tile_buf_addr / tile_valid_rows / tile_valid_cols should be gone +// - the expanded body should already use concrete memrefs/constants +// CHECK-LABEL: func.func @TADD +// CHECK-NOT: pto.tile_buf_addr +// CHECK-NOT: pto.tile_valid_rows +// CHECK-NOT: pto.tile_valid_cols +// CHECK-NOT: pto.bind_tile +// CHECK: pto.pointer_cast( +// CHECK: arith.constant 16 : index +// CHECK: arith.constant 64 : index +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/fp_load.pto b/test/lit/vpto/fp_load.pto new file mode 100644 index 000000000..4218c10af --- /dev/null +++ b/test/lit/vpto/fp_load.pto @@ -0,0 +1,22 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-before=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=ROUNDTRIP +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-expand-wrapper-ops %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=EXPAND + +module attributes {"pto.target_arch" = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @fp_load_probe(%src: !pto.ptr, + %dst: !pto.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + + pto.mte_l1_fb %src, %dst, %c2_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + return + } +} + +// ROUNDTRIP-LABEL: func.func @fp_load_probe( +// ROUNDTRIP: pto.mte_l1_fb %{{.*}}, %{{.*}}, %{{.*}} nburst(%{{.*}}, %{{.*}}, %{{.*}}) + +// EXPAND-LABEL: func.func @fp_load_probe( +// EXPAND: pto.copy_cbuf_to_fbuf diff --git a/test/lit/vpto/get_vms4_sr_vpto.pto b/test/lit/vpto/get_vms4_sr_vpto.pto new file mode 100644 index 000000000..b8384f5e7 --- /dev/null +++ b/test/lit/vpto/get_vms4_sr_vpto.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @get_vms4_sr_kernel(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %list0, %list1, %list2, %list3 = pto.get_vms4_sr : i16, i16, i16, i16 + pto.store_scalar %list1, %arg0[%c0] : !pto.ptr, i16 + return + } + } +} + +// CHECK: %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} = pto.get_vms4_sr : i16, i16, i16, i16 +// CHECK: pto.store_scalar diff --git a/test/lit/vpto/inline_libcall_filter_tilelang_scope.pto b/test/lit/vpto/inline_libcall_filter_tilelang_scope.pto new file mode 100644 index 000000000..cf863e940 --- /dev/null +++ b/test/lit/vpto/inline_libcall_filter_tilelang_scope.pto @@ -0,0 +1,38 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=pto-inline-libcall %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%arg0: i32) { + %v0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + %v1 = func.call @__tilelang_template_passthrough_i32(%v0) : (i32) -> i32 + %v2 = func.call @__regular_passthrough_i32(%v1) : (i32) -> i32 + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %y = arith.addi %x, %c1 : i32 + return %y : i32 + } + + // This is intentionally NOT an inline_proc helper. The inline pass should + // still inline it because template helper inlining is enabled in VPTO + // mainline regardless of --enable-tile-op-expand. + func.func private @__tilelang_template_passthrough_i32(%x: i32) -> i32 attributes { pto.tilelang.instance } { + return %x : i32 + } + + // Regular private helper without TileLang/OP-Lib attrs must not be inlined + // by pto-inline-libcall. + func.func private @__regular_passthrough_i32(%x: i32) -> i32 { + return %x : i32 + } +} + +// CHECK-LABEL: func.func @kernel +// CHECK: arith.constant 1 : i32 +// CHECK-NOT: func.call @__tl_inline_add1_i32 +// CHECK-NOT: call @__tilelang_template_passthrough_i32 +// CHECK: call @__regular_passthrough_i32 +// CHECK-NOT: func.func private @__tilelang_template_passthrough_i32 +// CHECK: func.func private @__regular_passthrough_i32 +// CHECK-NOT: func.func private @__tl_inline_add1_i32 diff --git a/test/lit/vpto/inline_libcall_result_rewrite.pto b/test/lit/vpto/inline_libcall_result_rewrite.pto new file mode 100644 index 000000000..351c2de24 --- /dev/null +++ b/test/lit/vpto/inline_libcall_result_rewrite.pto @@ -0,0 +1,34 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=pto-inline-libcall %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%x: i32) { + %a, %b = func.call @__tl_inline_pair_i32(%x) : (i32) -> (i32, i32) + %sum = arith.addi %a, %b : i32 + func.call @__tl_inline_sink_i32(%sum) : (i32) -> () + return + } + + func.func private @__tl_inline_pair_i32(%arg0: i32) -> (i32, i32) attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %v1 = arith.addi %arg0, %c1 : i32 + %v2 = arith.addi %arg0, %c2 : i32 + return %v1, %v2 : i32, i32 + } + + func.func private @__tl_inline_sink_i32(%arg0: i32) attributes { pto.tilelang.inline_proc } { + %c0 = arith.constant 0 : i32 + %_ = arith.addi %arg0, %c0 : i32 + return + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK: arith.constant 1 : i32 +// CHECK: arith.constant 2 : i32 +// CHECK: arith.addi %{{[^,]+}}, %{{[^,]+}} : i32 +// CHECK: arith.addi %{{[^,]+}}, %{{[^,]+}} : i32 +// CHECK: arith.addi %{{[^,]+}}, %{{[^,]+}} : i32 +// CHECK-NOT: func.call @__tl_inline_ +// CHECK-NOT: func.func private @__tl_inline_pair_i32 +// CHECK-NOT: func.func private @__tl_inline_sink_i32 diff --git a/test/lit/vpto/intra_block_sync_vpto_llvm.pto b/test/lit/vpto/intra_block_sync_vpto_llvm.pto new file mode 100644 index 000000000..a2b7d001d --- /dev/null +++ b/test/lit/vpto/intra_block_sync_vpto_llvm.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @intra_block_sync_static_dynamic(%event_id: index) attributes {pto.kernel} { + pto.sync.set , 0 + pto.sync.wait , 0 + pto.sync.set , %event_id + pto.sync.wait , %event_id + return + } +} + +// CHECK-LABEL: llvm.func @intra_block_sync_static_dynamic_mix_aiv +// CHECK: llvm.call @llvm.hivm.SET.INTRA.BLOCK.mode +// CHECK: llvm.call @llvm.hivm.WAIT.INTRA.BLOCK.mode diff --git a/test/lit/vpto/issue220_vrelu_i32_vpto_llvm.pto b/test/lit/vpto/issue220_vrelu_i32_vpto_llvm.pto new file mode 100644 index 000000000..5daed7872 --- /dev/null +++ b/test/lit/vpto/issue220_vrelu_i32_vpto_llvm.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Regression test for issue #220: direct VPTO `pto.vrelu` should accept both +// signless and signed i32 vectors and lower them to the same HIVM callee. +// +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vrelu_signless_i32_store(%value: !pto.vreg<64xi32>, %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %out = pto.vrelu %value, %mask : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + return + } + + func.func @vrelu_signed_i32_store(%value: !pto.vreg<64xsi32>, %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %out = pto.vrelu %value, %mask : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xsi32> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<64xsi32>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @vrelu_signless_i32_store_mix_aiv +// CHECK: llvm.call @llvm.hivm.vrelu.v64s32.x +// CHECK-LABEL: llvm.func @vrelu_signed_i32_store_mix_aiv +// CHECK: llvm.call @llvm.hivm.vrelu.v64s32.x diff --git a/test/lit/vpto/issue_173_vpto_llvm.pto b/test/lit/vpto/issue_173_vpto_llvm.pto new file mode 100644 index 000000000..b62322639 --- /dev/null +++ b/test/lit/vpto/issue_173_vpto_llvm.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +// Regression test for issue #173: signed and signless i16 vector stores should +// share the same LLVM/HIVM declaration after VPTO type conversion. +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @store_signed_i16(%value: !pto.vreg<128xsi16>, %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + pto.vsts %value, %dst[%c0], %mask : !pto.vreg<128xsi16>, !pto.ptr, !pto.mask + } + return + } + + func.func @store_signless_i16(%value: !pto.vreg<128xi16>, %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + pto.vsts %value, %dst[%c0], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @store_signed_i16_mix_aiv +// CHECK: llvm.call @llvm.hivm.vstsx1.v128s16 +// CHECK-LABEL: llvm.func @store_signless_i16_mix_aiv +// CHECK: llvm.call @llvm.hivm.vstsx1.v128s16 diff --git a/test/lit/vpto/issue_247_load_scalar_ptr_normalize.pto b/test/lit/vpto/issue_247_load_scalar_ptr_normalize.pto new file mode 100644 index 000000000..f2b7be1c9 --- /dev/null +++ b/test/lit/vpto/issue_247_load_scalar_ptr_normalize.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-ptr-normalize %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=NORMALIZE +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o /dev/null + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @load_scalar_ui32(%ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %v = pto.load_scalar %ptr[%c0] : !pto.ptr -> ui32 + pto.store_scalar %v, %ptr[%c0] : !pto.ptr, ui32 + return + } +} + +// NORMALIZE: // -----// IR Dump After +// NORMALIZE-SAME: (vpto-ptr-normalize) +// NORMALIZE-LABEL: func.func @load_scalar_ui32(%arg0: !pto.ptr) { +// NORMALIZE: %[[C0:.*]] = arith.constant 0 : index +// NORMALIZE: %[[VAL:.*]] = pto.load_scalar %arg0[%[[C0]]] : !pto.ptr -> ui32 +// NORMALIZE: pto.store_scalar %[[VAL]], %arg0[%[[C0]]] : !pto.ptr, ui32 diff --git a/test/lit/vpto/legacy_aicore_kernel_attr.pto b/test/lit/vpto/legacy_aicore_kernel_attr.pto new file mode 100644 index 000000000..e699f5c1c --- /dev/null +++ b/test/lit/vpto/legacy_aicore_kernel_attr.pto @@ -0,0 +1,35 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s --check-prefix=SPLIT +// RUN: ( ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s --check-prefix=LLVM + +module attributes {pto.target_arch = "a5"} { + func.func @legacy_attr_kernel(%arg0: !pto.ptr) + attributes {pto.aicore} { + %c0 = arith.constant 0 : i64 + %c1 = arith.constant 1 : i64 + %ub = pto.castptr %c0 : i64 -> !pto.ptr + %mat = pto.castptr %c0 : i64 -> !pto.ptr + + pto.section.vector { + pto.mte_gm_ub %arg0, %ub, %c0, %c1 + nburst(%c1, %c1, %c1) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + } + + pto.section.cube { + pto.copy_ubuf_to_cbuf %ub, %mat, %c0, %c1, %c1, %c0, %c0 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + } + + return + } +} + +// SPLIT: module attributes {pto.kernel_kind = #pto.kernel_kind +// SPLIT: func.func @legacy_attr_kernel +// SPLIT-SAME: attributes {pto.aicore} +// SPLIT: module attributes {pto.kernel_kind = #pto.kernel_kind +// SPLIT: func.func @legacy_attr_kernel +// SPLIT-SAME: attributes {pto.aicore} + +// LLVM-DAG: llvm.func @legacy_attr_kernel_mix_aiv +// LLVM-DAG: llvm.func @legacy_attr_kernel_mix_aic diff --git a/test/lit/vpto/load_cbuf_to_ca_vpto_llvm.pto b/test/lit/vpto/load_cbuf_to_ca_vpto_llvm.pto new file mode 100644 index 000000000..5ac243ff1 --- /dev/null +++ b/test/lit/vpto/load_cbuf_to_ca_vpto_llvm.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @load_cbuf_to_ca_2dv2_control() attributes {pto.kernel} { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %c3 = arith.constant 3 : i64 + %c4 = arith.constant 4 : i64 + %c5 = arith.constant 5 : i64 + %c6 = arith.constant 6 : i64 + %src = pto.castptr %c0 : i64 -> !pto.ptr + %dst = pto.castptr %c0 : i64 -> !pto.ptr + + pto.load_cbuf_to_ca %src, %dst, %c0, %c2, %c3, %c4, %c5, %c6 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + return + } +} + +// CHECK-LABEL: llvm.func @load_cbuf_to_ca_2dv2_control_mix_aic +// CHECK: llvm.call @llvm.hivm.LOAD.L1.TO.L0A.2Dv2.f16 diff --git a/test/lit/vpto/load_cbuf_to_cb_vpto_llvm.pto b/test/lit/vpto/load_cbuf_to_cb_vpto_llvm.pto new file mode 100644 index 000000000..434e05e7e --- /dev/null +++ b/test/lit/vpto/load_cbuf_to_cb_vpto_llvm.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @load_cbuf_to_cb_2dv2_control() attributes {pto.kernel} { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %c3 = arith.constant 3 : i64 + %c4 = arith.constant 4 : i64 + %c5 = arith.constant 5 : i64 + %c6 = arith.constant 6 : i64 + %src = pto.castptr %c0 : i64 -> !pto.ptr + %dst = pto.castptr %c0 : i64 -> !pto.ptr + + pto.load_cbuf_to_cb %src, %dst, %c0, %c2, %c3, %c4, %c5, %c6 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + return + } +} + +// CHECK-LABEL: llvm.func @load_cbuf_to_cb_2dv2_control_mix_aic +// CHECK: llvm.call @llvm.hivm.LOAD.L1.TO.L0B.2Dv2.f16 diff --git a/test/lit/vpto/membar_barrier_types_vpto_llvm.pto b/test/lit/vpto/membar_barrier_types_vpto_llvm.pto new file mode 100644 index 000000000..14c2d1ab3 --- /dev/null +++ b/test/lit/vpto/membar_barrier_types_vpto_llvm.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @membar_all_documented_kinds() attributes {pto.kernel} { + pto.vecscope { + pto.mem_bar "VV_ALL" + pto.mem_bar "VST_VLD" + pto.mem_bar "VLD_VST" + pto.mem_bar "VST_VST" + pto.mem_bar "VS_ALL" + pto.mem_bar "VST_LD" + pto.mem_bar "VLD_ST" + pto.mem_bar "VST_ST" + pto.mem_bar "SV_ALL" + pto.mem_bar "ST_VLD" + pto.mem_bar "LD_VST" + pto.mem_bar "ST_VST" + } + return + } +} + +// CHECK-LABEL: llvm.func @membar_all_documented_kinds_mix_aiv +// CHECK: llvm.call @llvm.hivm.mem.bar.vv.all +// CHECK: llvm.call @llvm.hivm.mem.bar.st.vst diff --git a/test/lit/vpto/mlir_print_ir_debug.pto b/test/lit/vpto/mlir_print_ir_debug.pto new file mode 100644 index 000000000..117ab1f10 --- /dev/null +++ b/test/lit/vpto/mlir_print_ir_debug.pto @@ -0,0 +1,30 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-normalize-container %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=MAIN +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=vpto-ptr-normalize %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=VPTO + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%arg0: i32) attributes { pto.tilelang.instance } { + %0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + func.call @__tl_inline_sink_i32(%0) : (i32) -> () + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %v = arith.addi %x, %c1 : i32 + return %v : i32 + } + + func.func private @__tl_inline_sink_i32(%x: i32) attributes { pto.tilelang.inline_proc } { + %c2 = arith.constant 2 : i32 + %t = arith.addi %x, %c2 : i32 + return + } +} + +// MAIN: // -----// IR Dump After +// MAIN-SAME: (vpto-normalize-container) +// MAIN: func.func @kernel + +// VPTO: // -----// IR Dump After +// VPTO-SAME: (vpto-ptr-normalize) +// VPTO: func.func @kernel diff --git a/test/lit/vpto/pipe_event_dyn_sync_llvm.pto b/test/lit/vpto/pipe_event_dyn_sync_llvm.pto new file mode 100644 index 000000000..5b49934c0 --- /dev/null +++ b/test/lit/vpto/pipe_event_dyn_sync_llvm.pto @@ -0,0 +1,179 @@ +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test LowerPipeEventDynSyncOpPattern for SetFlagDynOp and WaitFlagDynOp lowering to LLVM IR + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // ===----------------------------------------------------------------------=== + // Section 1: Basic index type event ID + // ===----------------------------------------------------------------------=== + + func.func @pipe_event_dyn_sync_basic(%event_id: index) attributes {pto.kernel} { + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %event_id] + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %event_id] + return + } + + // ===----------------------------------------------------------------------=== + // Section 2: Pipe value boundary tests + // ===----------------------------------------------------------------------=== + + // Test minimum pipe values: PIPE_S=0, PIPE_V=1, PIPE_M=2 + func.func @pipe_event_dyn_sync_min_pipes(%eid: index) attributes {pto.kernel} { + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %eid] + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %eid] + return + } + + // Test MTE pipes: PIPE_MTE1=3, PIPE_MTE2=4, PIPE_MTE3=5, PIPE_MTE4=7, PIPE_MTE5=8 + func.func @pipe_event_dyn_sync_mte_pipes(%eid: index) attributes {pto.kernel} { + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %eid] + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %eid] + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %eid] + return + } + + // Test high pipe values: PIPE_FIX=10, PIPE_V2=9, PIPE_ALL=6 + func.func @pipe_event_dyn_sync_high_pipes(%eid: index) attributes {pto.kernel} { + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %eid] + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %eid] + return + } + + // Test virtual pipes: VIRTUAL_PIPE_MTE2_L1A=11, VIRTUAL_PIPE_MTE2_L1B=12 + func.func @pipe_event_dyn_sync_virtual_pipes(%eid: index) attributes {pto.kernel} { + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %eid] + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %eid] + return + } + + // ===----------------------------------------------------------------------=== + // Section 3: Computed event IDs + // ===----------------------------------------------------------------------=== + + func.func @pipe_event_dyn_sync_computed(%base: index, %offset: index) attributes {pto.kernel} { + %eid = arith.addi %base, %offset : index + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %eid] + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %eid] + return + } + + // ===----------------------------------------------------------------------=== + // Section 4: Constant event IDs + // ===----------------------------------------------------------------------=== + + func.func @pipe_event_dyn_sync_const() attributes {pto.kernel} { + %c7 = arith.constant 7 : index + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %c7] + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %c7] + return + } + + // ===----------------------------------------------------------------------=== + // Section 5: Standalone operations (not paired) + // ===----------------------------------------------------------------------=== + + func.func @pipe_event_dyn_sync_only_set(%eid: index) attributes {pto.kernel} { + pto.set_flag_dyn [#pto.pipe, #pto.pipe, %eid] + return + } + + func.func @pipe_event_dyn_sync_only_wait(%eid: index) attributes {pto.kernel} { + pto.wait_flag_dyn [#pto.pipe, #pto.pipe, %eid] + return + } +} + +// ===----------------------------------------------------------------------=== +// CHECK patterns for Section 1: Basic +// ===----------------------------------------------------------------------=== + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_basic_mix_aiv +// CHECK: llvm.mlir.constant(4 : i64) : i64 +// CHECK: llvm.mlir.constant(5 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG +// CHECK: llvm.mlir.constant(4 : i64) : i64 +// CHECK: llvm.mlir.constant(5 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG + +// ===----------------------------------------------------------------------=== +// CHECK patterns for Section 2: Pipe boundaries +// ===----------------------------------------------------------------------=== + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_min_pipes_mix_aiv +// CHECK: llvm.mlir.constant(0 : i64) : i64 +// CHECK: llvm.mlir.constant(1 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG +// CHECK: llvm.mlir.constant(0 : i64) : i64 +// CHECK: llvm.mlir.constant(2 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_mte_pipes_mix_aiv +// CHECK: llvm.mlir.constant(3 : i64) : i64 +// CHECK: llvm.mlir.constant(4 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG +// CHECK: llvm.mlir.constant(5 : i64) : i64 +// CHECK: llvm.mlir.constant(7 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG +// CHECK: llvm.mlir.constant(8 : i64) : i64 +// CHECK: llvm.mlir.constant(1 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_high_pipes_mix_aiv +// CHECK: llvm.mlir.constant(10 : i64) : i64 +// CHECK: llvm.mlir.constant(9 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG +// CHECK: llvm.mlir.constant(6 : i64) : i64 +// CHECK: llvm.mlir.constant(10 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_virtual_pipes_mix_aiv +// CHECK: llvm.mlir.constant(11 : i64) : i64 +// CHECK: llvm.mlir.constant(12 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG +// CHECK: llvm.mlir.constant(11 : i64) : i64 +// CHECK: llvm.mlir.constant(12 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG + +// ===----------------------------------------------------------------------=== +// CHECK patterns for Section 3: Computed event IDs +// ===----------------------------------------------------------------------=== + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_computed_mix_aiv +// CHECK: llvm.add +// CHECK: llvm.mlir.constant(1 : i64) : i64 +// CHECK: llvm.mlir.constant(2 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG + +// ===----------------------------------------------------------------------=== +// CHECK patterns for Section 4: Constant event IDs +// ===----------------------------------------------------------------------=== + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_const_mix_aiv +// CHECK: llvm.mlir.constant(5 : i64) : i64 +// CHECK: llvm.mlir.constant(7 : i64) : i64 +// CHECK: llvm.mlir.constant(7 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG +// CHECK: llvm.mlir.constant(5 : i64) : i64 +// CHECK: llvm.mlir.constant(7 : i64) : i64 +// CHECK: llvm.mlir.constant(7 : i64) : i64 +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG + +// ===----------------------------------------------------------------------=== +// CHECK patterns for Section 5: Standalone operations +// ===----------------------------------------------------------------------=== + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_only_set_mix_aiv +// CHECK: llvm.call @llvm.hivm.SET.FLAG.REG + +// CHECK-LABEL: llvm.func @pipe_event_dyn_sync_only_wait_mix_aiv +// CHECK: llvm.call @llvm.hivm.WAIT.FLAG.REG \ No newline at end of file diff --git a/test/lit/vpto/ptr_addrspace_aliases.pto b/test/lit/vpto/ptr_addrspace_aliases.pto new file mode 100644 index 000000000..91da5f313 --- /dev/null +++ b/test/lit/vpto/ptr_addrspace_aliases.pto @@ -0,0 +1,38 @@ +// RUN: ptoas --pto-arch=a5 --mlir-print-ir-after-all %s -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @ptr_addrspace_aliases( + %l1: !pto.ptr, + %l0a: !pto.ptr, + %l0b: !pto.ptr, + %l0c: !pto.ptr, + %bt: !pto.ptr, + %fb: !pto.ptr) attributes {pto.kernel} { + return + } + + func.func @ptr_addrspace_legacy_aliases( + %mat: !pto.ptr, + %left: !pto.ptr, + %right: !pto.ptr, + %acc: !pto.ptr, + %bias: !pto.ptr, + %scaling: !pto.ptr) attributes {pto.kernel} { + return + } +} + +// CHECK-LABEL: func.func @ptr_addrspace_aliases( +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-LABEL: func.func @ptr_addrspace_legacy_aliases( +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.ptr diff --git a/test/lit/vpto/section_sugar_duplicate_invalid.pto b/test/lit/vpto/section_sugar_duplicate_invalid.pto new file mode 100644 index 000000000..16cf5bf0e --- /dev/null +++ b/test/lit/vpto/section_sugar_duplicate_invalid.pto @@ -0,0 +1,14 @@ +// RUN: ! ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @duplicate_section_kernel() attributes {pto.kernel} { + pto.section.vector { + } + pto.section.vector { + } + return + } +} + +// CHECK: contains more than one pto.section.vector + diff --git a/test/lit/vpto/section_sugar_mixed.pto b/test/lit/vpto/section_sugar_mixed.pto new file mode 100644 index 000000000..f855ef619 --- /dev/null +++ b/test/lit/vpto/section_sugar_mixed.pto @@ -0,0 +1,38 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @section_sugar_kernel(%arg0: !pto.ptr) + attributes {pto.kernel} { + %c0 = arith.constant 0 : i64 + %c1 = arith.constant 1 : i64 + %ub = pto.castptr %c0 : i64 -> !pto.ptr + %mat = pto.castptr %c0 : i64 -> !pto.ptr + + pto.section.vector { + pto.mte_gm_ub %arg0, %ub, %c0, %c1 + nburst(%c1, %c1, %c1) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + } + + pto.section.cube { + pto.copy_ubuf_to_cbuf %ub, %mat, %c0, %c1, %c1, %c0, %c0 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + } + + return + } +} + +// CHECK: module attributes {pto.target_arch = "a5"} { +// CHECK: module attributes { +// CHECK-SAME: pto.kernel_kind = #pto.kernel_kind +// CHECK: func.func @section_sugar_kernel +// CHECK: pto.copy_gm_to_ubuf +// CHECK-NOT: pto.section +// CHECK-NOT: pto.copy_ubuf_to_cbuf +// CHECK: module attributes { +// CHECK-SAME: pto.kernel_kind = #pto.kernel_kind +// CHECK: func.func @section_sugar_kernel +// CHECK: pto.copy_ubuf_to_cbuf +// CHECK-NOT: pto.section +// CHECK-NOT: pto.copy_gm_to_ubuf diff --git a/test/lit/vpto/section_sugar_multi_func.pto b/test/lit/vpto/section_sugar_multi_func.pto new file mode 100644 index 000000000..a0f87ff71 --- /dev/null +++ b/test/lit/vpto/section_sugar_multi_func.pto @@ -0,0 +1,25 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @vec_only() attributes {pto.kernel} { + pto.section.vector { + } + return + } + + func.func @cube_only() attributes {pto.kernel} { + pto.section.cube { + } + return + } +} + +// CHECK: module attributes {pto.target_arch = "a5"} { +// CHECK: module attributes { +// CHECK-SAME: pto.kernel_kind = #pto.kernel_kind +// CHECK: func.func @vec_only +// CHECK-NOT: func.func @cube_only +// CHECK: module attributes { +// CHECK-SAME: pto.kernel_kind = #pto.kernel_kind +// CHECK-NOT: func.func @vec_only +// CHECK: func.func @cube_only diff --git a/test/lit/vpto/tcolargmax.pto b/test/lit/vpto/tcolargmax.pto new file mode 100644 index 000000000..e1003375d --- /dev/null +++ b/test/lit/vpto/tcolargmax.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolargmax via the TileLang Python DSL template +// lib/TileOps/tcolargmax_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolargmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLARGMAX +// CHECK-NOT: pto.tcolargmax ins +// CHECK: pto.castptr +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vlds +// CHECK: pto.vcmp +// CHECK: pto.vsel +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLARGMAX() { + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolargmin.pto b/test/lit/vpto/tcolargmin.pto new file mode 100644 index 000000000..10835a04b --- /dev/null +++ b/test/lit/vpto/tcolargmin.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolargmin via the TileLang Python DSL template +// lib/TileOps/tcolargmin_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolargmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLARGMIN +// CHECK-NOT: pto.tcolargmin ins +// CHECK: pto.castptr +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vlds +// CHECK: pto.vcmp +// CHECK: pto.vsel +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLARGMIN() { + %src = pto.alloc_tile + : !pto.tile_buf + + %tmp = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpand_tilelang.pto b/test/lit/vpto/tcolexpand_tilelang.pto new file mode 100644 index 000000000..a35feeb1d --- /dev/null +++ b/test/lit/vpto/tcolexpand_tilelang.pto @@ -0,0 +1,33 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpand via the default TileLang Python DSL template +// lib/TileOps/tcolexpand_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPAND +// CHECK-NOT: pto.tcolexpand ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPAND() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpand ins(%src0 : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpandadd_tilelang.pto b/test/lit/vpto/tcolexpandadd_tilelang.pto new file mode 100644 index 000000000..a5ed9a21c --- /dev/null +++ b/test/lit/vpto/tcolexpandadd_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandadd via the default TileLang Python DSL template +// lib/TileOps/tcolexpandadd_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandadd should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDADD +// CHECK-NOT: pto.tcolexpandadd ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPANDADD() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpanddiv_tilelang.pto b/test/lit/vpto/tcolexpanddiv_tilelang.pto new file mode 100644 index 000000000..3e2b32905 --- /dev/null +++ b/test/lit/vpto/tcolexpanddiv_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpanddiv via the default TileLang Python DSL template +// lib/TileOps/tcolexpanddiv_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpanddiv should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDDIV +// CHECK-NOT: pto.tcolexpanddiv ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPANDDIV() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpandexpdif_tilelang.pto b/test/lit/vpto/tcolexpandexpdif_tilelang.pto new file mode 100644 index 000000000..6e2062252 --- /dev/null +++ b/test/lit/vpto/tcolexpandexpdif_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandexpdif via the default TileLang Python DSL template +// lib/TileOps/tcolexpandexpdif_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandexpdif should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDEXPDIF +// CHECK-NOT: pto.tcolexpandexpdif ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vexpdif +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPANDEXPDIF() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandexpdif ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpandmax_tilelang.pto b/test/lit/vpto/tcolexpandmax_tilelang.pto new file mode 100644 index 000000000..9d0fb20ec --- /dev/null +++ b/test/lit/vpto/tcolexpandmax_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandmax via the default TileLang Python DSL template +// lib/TileOps/tcolexpandmax_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDMAX +// CHECK-NOT: pto.tcolexpandmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPANDMAX() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpandmin_tilelang.pto b/test/lit/vpto/tcolexpandmin_tilelang.pto new file mode 100644 index 000000000..269307986 --- /dev/null +++ b/test/lit/vpto/tcolexpandmin_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandmin via the default TileLang Python DSL template +// lib/TileOps/tcolexpandmin_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDMIN +// CHECK-NOT: pto.tcolexpandmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPANDMIN() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandmin ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpandmul_tilelang.pto b/test/lit/vpto/tcolexpandmul_tilelang.pto new file mode 100644 index 000000000..3e3d148a2 --- /dev/null +++ b/test/lit/vpto/tcolexpandmul_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandmul via the default TileLang Python DSL template +// lib/TileOps/tcolexpandmul_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandmul should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDMUL +// CHECK-NOT: pto.tcolexpandmul ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPANDMUL() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolexpandsub_tilelang.pto b/test/lit/vpto/tcolexpandsub_tilelang.pto new file mode 100644 index 000000000..49bfd7314 --- /dev/null +++ b/test/lit/vpto/tcolexpandsub_tilelang.pto @@ -0,0 +1,39 @@ +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolexpandsub via the default TileLang Python DSL template +// lib/TileOps/tcolexpandsub_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolexpandsub should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLEXPANDSUB +// CHECK-NOT: pto.tcolexpandsub ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLEXPANDSUB() { + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolmax.pto b/test/lit/vpto/tcolmax.pto new file mode 100644 index 000000000..3d6af696e --- /dev/null +++ b/test/lit/vpto/tcolmax.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolmax via the TileLang Python DSL template +// lib/TileOps/tcolmax_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLMAX +// CHECK-NOT: pto.tcolmax ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmax +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLMAX() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolmin.pto b/test/lit/vpto/tcolmin.pto new file mode 100644 index 000000000..83e6a621a --- /dev/null +++ b/test/lit/vpto/tcolmin.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolmin via the TileLang Python DSL template +// lib/TileOps/tcolmin_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLMIN +// CHECK-NOT: pto.tcolmin ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmin +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLMIN() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolprod.pto b/test/lit/vpto/tcolprod.pto new file mode 100644 index 000000000..af31f0595 --- /dev/null +++ b/test/lit/vpto/tcolprod.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolprod via the TileLang Python DSL template +// lib/TileOps/tcolprod_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolprod should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLPROD +// CHECK-NOT: pto.tcolprod ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vmul +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLPROD() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tcolsum.pto b/test/lit/vpto/tcolsum.pto new file mode 100644 index 000000000..ee470c1dc --- /dev/null +++ b/test/lit/vpto/tcolsum.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FIT FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.tcolsum via the TileLang Python DSL template +// lib/TileOps/tcolsum_template.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.tcolsum should be lowered to vector-style VPTO IR. +// CHECK: func.func @TCOLSUM +// CHECK-NOT: pto.tcolsum ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vadd +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TCOLSUM() { + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/tilelang_cross_file_inline_proc_backend_inline.pto b/test/lit/vpto/tilelang_cross_file_inline_proc_backend_inline.pto new file mode 100644 index 000000000..0843d8697 --- /dev/null +++ b/test/lit/vpto/tilelang_cross_file_inline_proc_backend_inline.pto @@ -0,0 +1,25 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto -o - %s | FileCheck %s + +// Models MLIR produced by a template that imported a shared @pto.inline_proc +// helper from another Python file. At this stage imported and local helpers use +// the same pto.tilelang.inline_proc contract; the backend must erase the helper +// boundary. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%arg0: !pto.tile_buf) attributes { pto.tilelang.instance } { + func.call @__tl_inline_shared_touch_0() : () -> () + return + } + + func.func private @__tl_inline_shared_touch_0() attributes { pto.tilelang.inline_proc } { + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + } + return + } +} + +// CHECK-LABEL: module +// CHECK-NOT: func.call @__tl_inline_shared_touch +// CHECK-NOT: func.func private @__tl_inline_shared_touch +// CHECK-NOT: pto.tilelang.inline_proc diff --git a/test/lit/vpto/tilelang_inline_proc_backend_inline.pto b/test/lit/vpto/tilelang_inline_proc_backend_inline.pto new file mode 100644 index 000000000..d9b38d77e --- /dev/null +++ b/test/lit/vpto/tilelang_inline_proc_backend_inline.pto @@ -0,0 +1,28 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --mlir-print-ir-after=pto-inline-libcall %s -o /dev/null 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%arg0: i32) attributes { pto.tilelang.instance } { + %0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + func.call @__tl_inline_sink_i32(%0) : (i32) -> () + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %v = arith.addi %x, %c1 : i32 + return %v : i32 + } + + func.func private @__tl_inline_sink_i32(%x: i32) attributes { pto.tilelang.inline_proc } { + %c2 = arith.constant 2 : i32 + %t = arith.addi %x, %c2 : i32 + return + } +} + +// CHECK-LABEL: func.func @kernel +// CHECK: arith.constant 1 : i32 +// CHECK: arith.addi %arg0, %{{[^,]+}} : i32 +// CHECK-NOT: func.call @__tl_inline_ +// CHECK-NOT: func.func private @__tl_inline_add1_i32 +// CHECK-NOT: func.func private @__tl_inline_sink_i32 diff --git a/test/lit/vpto/tilelang_soft_vmod_backend_inline.pto b/test/lit/vpto/tilelang_soft_vmod_backend_inline.pto new file mode 100644 index 000000000..65c84c2ec --- /dev/null +++ b/test/lit/vpto/tilelang_soft_vmod_backend_inline.pto @@ -0,0 +1,126 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s --emit-vpto -o - | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%arg0: !pto.tile_buf, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } { + %c0 = arith.constant 0 : index + %tmp_0 = pto.tile_buf_addr %arg0 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + %tmp_1 = pto.tile_buf_addr %arg1 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + pto.vecscope { + %mask_0 = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec_1 = pto.vlds %tmp_1[%c0] : memref<8x16xi16, #pto.address_space> -> !pto.vreg<128xi16> + %result_81 = func.call @__tl_inline__tl_soft_vmod_2(%vec_1, %vec_1, %mask_0) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + pto.vsts %result_81, %tmp_0[%c0], %mask_0 : !pto.vreg<128xi16>, memref<8x16xi16, #pto.address_space>, !pto.mask + } + return + } + + func.func private @__tl_inline__tl_soft_vdiv_u16_0(%arg0: !pto.vreg<128xui16>, %arg1: !pto.vreg<128xui16>, %arg2: !pto.mask) -> !pto.vreg<128xui16> attributes { pto.tilelang.inline_proc } { + %c0_i32 = arith.constant 0 : i32 + %c0_ui32 = builtin.unrealized_conversion_cast %c0_i32 : i32 to ui32 + %c65536_0_f32 = arith.constant 65536.0 : f32 + %c65535_i16 = arith.constant 65535 : i16 + %c65535_ui16 = builtin.unrealized_conversion_cast %c65535_i16 : i16 to ui16 + %tmp_0 = arith.constant 0 : i16 + %zero_10 = builtin.unrealized_conversion_cast %tmp_0 : i16 to ui16 + %tmp_1 = arith.constant 1 : i16 + %one_11 = builtin.unrealized_conversion_cast %tmp_1 : i16 to ui16 + %fp32_one_12 = arith.constant 1.0 : f32 + %full_mask_b16_13 = pto.pset_b16 "PAT_ALL" : !pto.mask + %full_mask_b32_14 = pto.pset_b32 "PAT_ALL" : !pto.mask + %zero_mask_15 = pto.vcmps %arg1, %zero_10, %arg2, "eq" : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.mask + %active_mask_16 = pto.pnot %zero_mask_15, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %zero_u16_17 = pto.vbr %zero_10 : ui16 -> !pto.vreg<128xui16> + %vy_lower_u16_18, %vy_higher_u16_19 = pto.vintlv %arg1, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %active_low_22 = pto.vcmps %vy_lower_u32_20, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %active_high_23 = pto.vcmps %vy_higher_u32_21, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %tmp_2 = pto.vbitcast %vy_lower_u32_20 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_lower_f32_24 = pto.vcvt %tmp_2, %active_low_22 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_3 = pto.vbitcast %vy_higher_u32_21 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_higher_f32_25 = pto.vcvt %tmp_3, %active_high_23 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_4 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_lower_26 = pto.vdiv %tmp_4, %vy_lower_f32_24, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_5 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_higher_27 = pto.vdiv %tmp_5, %vy_higher_f32_25, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_6 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_lower_28 = pto.vmul %vy_rec_lower_26, %tmp_6, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_7 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_higher_29 = pto.vmul %vy_rec_higher_27, %tmp_7, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28, %active_low_22 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> + %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29, %active_high_23 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> + %v_lower_u32_32 = pto.vbitcast %v_lower_i32_30 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %v_higher_u32_33 = pto.vbitcast %v_higher_i32_31 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %vx_lower_u16_34, %vx_higher_u16_35 = pto.vintlv %arg0, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %q_tmp_lower_38 = pto.vmul %v_lower_u32_32, %vx_lower_u32_36, %active_low_22 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %q_tmp_higher_39 = pto.vmul %v_higher_u32_33, %vx_higher_u32_37, %active_high_23 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %tmp_8 = pto.vbitcast %q_tmp_lower_38 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %tmp_9 = pto.vbitcast %q_tmp_higher_39 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %_q_lower_40, %q_tmp_41 = pto.vdintlv %tmp_8, %tmp_9 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %yq_tmp_42 = pto.vmul %q_tmp_41, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_43 = pto.vsub %arg0, %yq_tmp_42, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_44 = pto.vcmp %r_tmp_43, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_45 = pto.vsub %r_tmp_43, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_46 = pto.vsel %refined_r_45, %r_tmp_43, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_47 = pto.vadds %q_tmp_41, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_48 = pto.vsel %q_inc_47, %q_tmp_41, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_49 = pto.vcmp %r_tmp_46, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_50 = pto.vsub %r_tmp_46, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_51 = pto.vsel %refined_r_50, %r_tmp_46, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_52 = pto.vadds %q_tmp_48, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_53 = pto.vsel %q_inc_52, %q_tmp_48, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %zero_q_54 = pto.vbr %c65535_ui16 : ui16 -> !pto.vreg<128xui16> + %tmp_10 = pto.vsel %zero_q_54, %q_tmp_53, %zero_mask_15 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + return %tmp_10 : !pto.vreg<128xui16> + } + + func.func private @__tl_inline__tl_soft_vmod_i16_1(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %zero_60 = arith.constant 0 : i16 + %neg_one_61 = arith.constant -1 : i16 + %zero_mask_62 = pto.vcmps %arg1, %zero_60, %arg2, "eq" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %active_mask_63 = pto.pnot %zero_mask_62, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %tmp_11 = pto.vabs %arg0, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_x_64 = pto.vbitcast %tmp_11 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %tmp_12 = pto.vabs %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_y_65 = pto.vbitcast %tmp_12 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %x_xor_y_66 = pto.vxor %arg0, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %p_pos_67 = pto.vcmps %x_xor_y_66, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %q_abs_68 = func.call @__tl_inline__tl_soft_vdiv_u16_0(%abs_x_64, %abs_y_65, %active_mask_63) : (!pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask) -> !pto.vreg<128xui16> + %tmp_13 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %neg_q_69 = pto.vneg %tmp_13, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_14 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %q_70 = pto.vsel %tmp_14, %neg_q_69, %p_pos_67 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %qy_71 = pto.vmul %q_70, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_72 = pto.vsub %arg0, %qy_71, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %nonzero_remainder_73 = pto.vcmps %remainder_72, %zero_60, %active_mask_63, "ne" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_x_74 = pto.vcmps %arg0, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_y_75 = pto.vcmps %arg1, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_diff_76 = pto.pxor %sign_x_74, %sign_y_75, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %need_floor_fix_77 = pto.pand %sign_diff_76, %nonzero_remainder_73, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %amended_remainder_78 = pto.vadd %arg1, %remainder_72, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_79 = pto.vsel %amended_remainder_78, %remainder_72, %need_floor_fix_77 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_15 = pto.vbr %neg_one_61 : i16 -> !pto.vreg<128xi16> + %tmp_16 = pto.vsel %tmp_15, %remainder_79, %zero_mask_62 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + return %tmp_16 : !pto.vreg<128xi16> + } + + func.func private @__tl_inline__tl_soft_vmod_2(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %result_80 = func.call @__tl_inline__tl_soft_vmod_i16_1(%arg0, %arg1, %arg2) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + return %result_80 : !pto.vreg<128xi16> + } +} + +// CHECK-LABEL: func.func @kernel( +// CHECK: pto.vecscope { +// CHECK: pto.vlds +// CHECK: pto.vdiv +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.pxor +// CHECK: pto.pand +// CHECK: pto.vsel +// CHECK: pto.vsts +// CHECK-NOT: func.call @__tl_inline__tl_soft_ +// CHECK-NOT: func.func private @__tl_inline__tl_soft_ diff --git a/test/lit/vpto/trowexpand_tile_op_expand.pto b/test/lit/vpto/trowexpand_tile_op_expand.pto new file mode 100644 index 000000000..54db4ff81 --- /dev/null +++ b/test/lit/vpto/trowexpand_tile_op_expand.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpand via the TileLang Python DSL template +// lib/TileOps/trowexpand.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpand should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPAND +// CHECK-NOT: pto.trowexpand ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPAND() { + // Source tile: 32x8 (32-byte aligned, cols=32/sizeof(f32)=8) + // Only column 0 contains valid data (v_col=8 for alignment, actual valid=1) + %src = pto.alloc_tile + : !pto.tile_buf + // Destination tile: 32x32 (broadcast each scalar across the row) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/trowexpandadd_tile_op_expand.pto b/test/lit/vpto/trowexpandadd_tile_op_expand.pto new file mode 100644 index 000000000..208a3f55b --- /dev/null +++ b/test/lit/vpto/trowexpandadd_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandadd via the TileLang Python DSL template +// lib/TileOps/trowexpandadd.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandadd should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDADD +// CHECK-NOT: pto.trowexpandadd ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vadd +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPANDADD() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/trowexpanddiv_tile_op_expand.pto b/test/lit/vpto/trowexpanddiv_tile_op_expand.pto new file mode 100644 index 000000000..115dd538a --- /dev/null +++ b/test/lit/vpto/trowexpanddiv_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpanddiv via the TileLang Python DSL template +// lib/TileOps/trowexpanddiv.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpanddiv should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDDIV +// CHECK-NOT: pto.trowexpanddiv ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vdiv +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPANDDIV() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/trowexpandexpdif_tile_op_expand.pto b/test/lit/vpto/trowexpandexpdif_tile_op_expand.pto new file mode 100644 index 000000000..b45529c99 --- /dev/null +++ b/test/lit/vpto/trowexpandexpdif_tile_op_expand.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandexpdif via the TileLang Python DSL template +// lib/TileOps/trowexpandexpdif.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandexpdif should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDEXPDIF +// CHECK-NOT: pto.trowexpandexpdif ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vexpdif +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPANDEXPDIF() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/trowexpandmax_tile_op_expand.pto b/test/lit/vpto/trowexpandmax_tile_op_expand.pto new file mode 100644 index 000000000..599d32ff1 --- /dev/null +++ b/test/lit/vpto/trowexpandmax_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandmax via the TileLang Python DSL template +// lib/TileOps/trowexpandmax.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandmax should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDMAX +// CHECK-NOT: pto.trowexpandmax ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vmax +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPANDMAX() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/trowexpandmin_tile_op_expand.pto b/test/lit/vpto/trowexpandmin_tile_op_expand.pto new file mode 100644 index 000000000..e1b00b44a --- /dev/null +++ b/test/lit/vpto/trowexpandmin_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandmin via the TileLang Python DSL template +// lib/TileOps/trowexpandmin.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandmin should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDMIN +// CHECK-NOT: pto.trowexpandmin ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vmin +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPANDMIN() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/trowexpandmul_tile_op_expand.pto b/test/lit/vpto/trowexpandmul_tile_op_expand.pto new file mode 100644 index 000000000..7e50bf2f8 --- /dev/null +++ b/test/lit/vpto/trowexpandmul_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandmul via the TileLang Python DSL template +// lib/TileOps/trowexpandmul.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandmul should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDMUL +// CHECK-NOT: pto.trowexpandmul ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vmul +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPANDMUL() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/trowexpandsub_tile_op_expand.pto b/test/lit/vpto/trowexpandsub_tile_op_expand.pto new file mode 100644 index 000000000..a668612c3 --- /dev/null +++ b/test/lit/vpto/trowexpandsub_tile_op_expand.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that ExpandTileOp + InlineLibCall + FoldTileBufIntrinsics pipeline +// expands pto.trowexpandsub via the TileLang Python DSL template +// lib/TileOps/trowexpandsub.py. +// +// Pipeline: PTOMaterializeTileHandles -> ExpandTileOp -> InlineLibCall -> FoldTileBufIntrinsics +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --enable-tile-op-expand %s -o - 2>/dev/null | FileCheck %s + +// After the full tile-op-expand path on the VPTO backend, the original +// pto.trowexpandsub should be lowered to vector-style VPTO IR. +// CHECK: func.func @TROWEXPANDSUB +// CHECK-NOT: pto.trowexpandsub ins +// CHECK: pto.vecscope +// CHECK: pto.vlds +// CHECK: pto.vdup +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TROWEXPANDSUB() { + // src0: 32x32 matrix (row-major) + %src0 = pto.alloc_tile + : !pto.tile_buf + // src1: 32x8 (one scalar vector per row, width=8 for f32 which is 32/sizeof(f32)) + // For row-major src1, valid_shape[1] must be 32/sizeof(dtype) + %src1 = pto.alloc_tile + : !pto.tile_buf + // dst: 32x32 result (row-major) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} \ No newline at end of file diff --git a/test/lit/vpto/vbitcast_vpto_llvm.pto b/test/lit/vpto/vbitcast_vpto_llvm.pto new file mode 100644 index 000000000..cb921563a --- /dev/null +++ b/test/lit/vpto/vbitcast_vpto_llvm.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vbitcast_f32_to_i32_store(%value: f32, %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %vec = pto.vdup %value, %mask : f32, !pto.mask -> !pto.vreg<64xf32> + %cast = pto.vbitcast %vec : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + pto.vsts %cast, %dst[%c0], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + return + } + + func.func @vbitcast_f32_to_i16x128_store(%value: f32, %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask_f32 = pto.pset_b32 "PAT_ALL" : !pto.mask + %mask_i16 = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec = pto.vdup %value, %mask_f32 : f32, !pto.mask -> !pto.vreg<64xf32> + %cast = pto.vbitcast %vec : !pto.vreg<64xf32> -> !pto.vreg<128xi16> + pto.vsts %cast, %dst[%c0], %mask_i16 : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @vbitcast_f32_to_i32_store_mix_aiv +// CHECK: llvm.call @llvm.hivm.vstsx1.v64s32 +// CHECK-LABEL: llvm.func @vbitcast_f32_to_i16x128_store_mix_aiv +// CHECK: llvm.call @llvm.hivm.vstsx1.v128s16 diff --git a/test/lit/vpto/vcvt_part_modes_verify_invalid.pto b/test/lit/vpto/vcvt_part_modes_verify_invalid.pto new file mode 100644 index 000000000..c8063af0e --- /dev/null +++ b/test/lit/vpto/vcvt_part_modes_verify_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s 2>&1 | FileCheck %s + +// CHECK: error: 'pto.vcvt' op part must be P0, P1, P2, or P3 for 8/32 vcvt forms + +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_u32_to_u8_rejects_even(%seed: ui32) { + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %src = pto.vbr %seed : ui32 -> !pto.vreg<64xui32> + %bad = pto.vcvt %src, %mask {sat = "SAT", part = "EVEN"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + } + return + } +} diff --git a/test/lit/vpto/vcvt_part_modes_verify_invalid_even_odd.pto b/test/lit/vpto/vcvt_part_modes_verify_invalid_even_odd.pto new file mode 100644 index 000000000..c0d126da6 --- /dev/null +++ b/test/lit/vpto/vcvt_part_modes_verify_invalid_even_odd.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto %s 2>&1 | FileCheck %s + +// CHECK: error: 'pto.vcvt' op part must be EVEN or ODD for 8/16 and 16/32 vcvt forms + +module attributes {pto.target_arch = "a5"} { + func.func @vcvt_u16_to_u8_rejects_p0(%seed: ui16) { + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %src = pto.vbr %seed : ui16 -> !pto.vreg<128xui16> + %bad = pto.vcvt %src, %mask {sat = "SAT", part = "P0"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8> + } + return + } +} diff --git a/test/lit/vpto/vcvt_part_modes_vpto_llvm.pto b/test/lit/vpto/vcvt_part_modes_vpto_llvm.pto new file mode 100644 index 000000000..749898d75 --- /dev/null +++ b/test/lit/vpto/vcvt_part_modes_vpto_llvm.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_u32_to_u8_packed_parts(%seed: ui32, %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b8 "PAT_ALL" : !pto.mask + %src_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %src = pto.vbr %seed : ui32 -> !pto.vreg<64xui32> + %p0 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P0"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %p1 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P1"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %p2 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P2"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %p3 = pto.vcvt %src, %src_mask {sat = "SAT", part = "P3"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %m01 = pto.vor %p0, %p1, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %m23 = pto.vor %p2, %p3, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %out = pto.vor %m01, %m23, %mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @vcvt_u32_to_u8_packed_parts_mix_aiv +// CHECK: llvm.call @llvm.hivm.vcvtii.u322u8.x +// CHECK: llvm.call @llvm.hivm.vstsx1.v256u8 diff --git a/test/lit/vpto/vpto_kernel_entry_force_v300_ctrl.pto b/test/lit/vpto/vpto_kernel_entry_force_v300_ctrl.pto new file mode 100644 index 000000000..b5c7255d0 --- /dev/null +++ b/test/lit/vpto/vpto_kernel_entry_force_v300_ctrl.pto @@ -0,0 +1,17 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%value: f16, %dst: !pto.ptr) { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec = pto.vdup %value, %mask : f16, !pto.mask -> !pto.vreg<128xf16> + %out = pto.vcvt %vec, %mask {rnd = "R", sat = "NOSAT", part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %out, %dst[%c0], %mask {dist = "PK_B16"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK: pto.vcvt +// CHECK-SAME: {part = "EVEN", rnd = "R", sat = "NOSAT"} diff --git a/test/lit/vpto/vpto_kernel_entry_skip_v300_ctrl_for_nonsat_insensitive_vcvt.pto b/test/lit/vpto/vpto_kernel_entry_skip_v300_ctrl_for_nonsat_insensitive_vcvt.pto new file mode 100644 index 000000000..07ecae617 --- /dev/null +++ b/test/lit/vpto/vpto_kernel_entry_skip_v300_ctrl_for_nonsat_insensitive_vcvt.pto @@ -0,0 +1,17 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%value: f16, %dst: !pto.ptr) { + %c0 = arith.constant 0 : index + pto.vecscope { + %load_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %store_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %vec = pto.vdup %value, %load_mask : f16, !pto.mask -> !pto.vreg<128xf16> + %out = pto.vcvt %vec, %load_mask {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %dst[%c0], %store_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-NOT: pto.set_ctrl diff --git a/test/lit/vpto/vpto_mainline_inline_proc_cleanup.pto b/test/lit/vpto/vpto_mainline_inline_proc_cleanup.pto new file mode 100644 index 000000000..9197079d4 --- /dev/null +++ b/test/lit/vpto/vpto_mainline_inline_proc_cleanup.pto @@ -0,0 +1,28 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @kernel(%arg0: i32) attributes { pto.tilelang.instance } { + %0 = func.call @__tl_inline_add1_i32(%arg0) : (i32) -> i32 + func.call @__tl_inline_sink_i32(%0) : (i32) -> () + return + } + + func.func private @__tl_inline_add1_i32(%x: i32) -> i32 attributes { pto.tilelang.inline_proc } { + %c1 = arith.constant 1 : i32 + %v = arith.addi %x, %c1 : i32 + return %v : i32 + } + + func.func private @__tl_inline_sink_i32(%x: i32) attributes { pto.tilelang.inline_proc } { + %c2 = arith.constant 2 : i32 + %t = arith.addi %x, %c2 : i32 + return + } +} + +// CHECK-LABEL: func.func @kernel +// CHECK: return +// CHECK-NOT: func.call @__tl_inline_ +// CHECK-NOT: call @__tl_inline_ +// CHECK-NOT: pto.tilelang.inline_proc +// CHECK-NOT: func.func private @__tl_inline_ diff --git a/test/lit/vpto/vrelu_verify_invalid.pto b/test/lit/vpto/vrelu_verify_invalid.pto new file mode 100644 index 000000000..f56901d62 --- /dev/null +++ b/test/lit/vpto/vrelu_verify_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ! ptoas --pto-arch=a5 --pto-backend=vpto %s 2>&1 | FileCheck %s + +// Negative tests for pto.vrelu verifier. +// +// CHECK: error: 'pto.vrelu' op requires si32/i32/f16/f32 vector element type + +func.func @vrelu_bf16_invalid(%src: !pto.ptr, %dst: !pto.ptr) + { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec = pto.vlds %src[%c0] : !pto.ptr -> !pto.vreg<128xbf16> + %out = pto.vrelu %vec, %mask : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %out, %dst[%c0], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + return +} diff --git a/test/tilelang_st/npu/a5/src/st/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt new file mode 100644 index 000000000..a3e167a84 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/CMakeLists.txt @@ -0,0 +1,86 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +cmake_minimum_required(VERSION 3.16) +project(tilelang_st) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# CMake 3.27+ may ask the linker to emit dependency files via +# `--dependency-file`. bisheng/cce-ld does not support that flag, so disable +# linker-generated link dependencies for this standalone ST build. +if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.27) + set(CMAKE_LINK_DEPENDS_USE_LINKER FALSE) +endif() + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + +# -------------------------------------------------------------------------- +# PTOAS binary — passed by run_st.py via -DPTOAS_BIN=... +# -------------------------------------------------------------------------- +if(NOT DEFINED PTOAS_BIN) + message(FATAL_ERROR "PTOAS_BIN is not set. Pass -DPTOAS_BIN=/path/to/ptoas to cmake.") +endif() + +# -------------------------------------------------------------------------- +# ASCEND environment +# -------------------------------------------------------------------------- +if(NOT DEFINED ENV{ASCEND_HOME_PATH}) + message(FATAL_ERROR "Cannot find ASCEND_HOME_PATH, please run set_env.sh.") +else() + set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}) +endif() + +set(PTO_ISA_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../../../../../pto-isa" CACHE PATH "Path to pto-isa repo") +set(PTO_TILELANG_ST_COMMON_DIR + "${CMAKE_CURRENT_LIST_DIR}/common") +set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver) + +set(CMAKE_COMPILER bisheng) +set(CMAKE_C_COMPILER ${CMAKE_COMPILER}) +set(CMAKE_CXX_COMPILER ${CMAKE_COMPILER}) + +add_compile_options( + -D_FORTIFY_SOURCE=2 + -O2 -std=c++17 + -Wno-macro-redefined -Wno-ignored-attributes -Wno-unknown-attributes + -fstack-protector-strong + -fPIC +) +add_link_options( + -s + -Wl,-z,relro + -Wl,-z,now +) + +set(CMAKE_CCE_COMPILE_OPTIONS + -xcce + -fPIC + -Xhost-start -Xhost-end + "SHELL:-mllvm -cce-aicore-stack-size=0x8000" + "SHELL:-mllvm -cce-aicore-function-stack-size=0x8000" + "SHELL:-mllvm -cce-aicore-record-overflow=true" + "SHELL:-mllvm -cce-aicore-addr-transform" + "SHELL:-mllvm -cce-aicore-dcci-insert-for-scalar=false" +) + +set(CMAKE_CPP_COMPILE_OPTIONS + -xc++ + "SHELL:-include stdint.h" + "SHELL:-include stddef.h" +) + +include_directories( + ${ASCEND_HOME_PATH}/include + ${ASCEND_DRIVER_PATH}/kernel/inc +) + +add_subdirectory(testcase) diff --git a/test/tilelang_st/npu/a5/src/st/common/test_common.h b/test/tilelang_st/npu/a5/src/st/common/test_common.h new file mode 100644 index 000000000..661cdbcec --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/common/test_common.h @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace PtoTestCommon { + +inline bool ReadFile(const std::string &filePath, size_t &fileSize, void *buffer, + size_t bufferSize) { + struct stat sBuf; + if (stat(filePath.c_str(), &sBuf) == -1) { + return false; + } + if (!S_ISREG(sBuf.st_mode)) { + return false; + } + + std::ifstream file(filePath, std::ios::binary); + if (!file.is_open()) { + return false; + } + + std::filebuf *buf = file.rdbuf(); + size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (size == 0 || size > bufferSize) { + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), size); + fileSize = size; + return true; +} + +inline bool WriteFile(const std::string &filePath, const void *buffer, size_t size) { + if (buffer == nullptr) { + return false; + } + + int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + return false; + } + + ssize_t writeSize = write(fd, buffer, size); + (void)close(fd); + return writeSize == static_cast(size); +} + +} // namespace PtoTestCommon diff --git a/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt new file mode 100644 index 000000000..22fbc3f73 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/CMakeLists.txt @@ -0,0 +1,209 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# -------------------------------------------------------------------------- +# pto_tilelang_st(NAME) +# +# CMake macro for TileLang ST test cases. Unlike pto-isa's pto_vec_st() +# which compiles a hand-written kernel.cpp with -xcce, this macro: +# 1. Runs ptoas to compile .pto → kernel.fatobj.o +# 2. Links the fatobj object with launch.cpp → shared library +# 3. Builds host executable from main.cpp (no GTest — comparison via compare.py) +# -------------------------------------------------------------------------- +set(PTO_TILELANG_ST_TESTCASE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +function(pto_tilelang_st NAME) + set(options DISABLE_INSERT_SYNC) + set(oneValueArgs PTO_LEVEL AICORE_ARCH) + cmake_parse_arguments(PTO_TILELANG_ST "${options}" "${oneValueArgs}" "" ${ARGN}) + + set(PTOAS_ENABLE_INSERT_SYNC ON) + if(PTO_TILELANG_ST_DISABLE_INSERT_SYNC) + set(PTOAS_ENABLE_INSERT_SYNC OFF) + endif() + + set(PTOAS_PTO_LEVEL "") + if(DEFINED PTO_TILELANG_ST_PTO_LEVEL) + set(PTOAS_PTO_LEVEL "${PTO_TILELANG_ST_PTO_LEVEL}") + endif() + + set(AICORE_ARCH "dav-c310-vec") + if(DEFINED PTO_TILELANG_ST_AICORE_ARCH) + set(AICORE_ARCH "${PTO_TILELANG_ST_AICORE_ARCH}") + endif() + + # Step 1: ptoas .pto → kernel fatobj object + set(PTO_SRC ${CMAKE_CURRENT_SOURCE_DIR}/${NAME}.pto) + set(KERNEL_FATOBJ ${CMAKE_CURRENT_BINARY_DIR}/${NAME}_kernel.o) + set(PTOAS_CAPTURE_SCRIPT + ${PTO_TILELANG_ST_TESTCASE_DIR}/run_ptoas_to_file.cmake) + add_custom_command( + OUTPUT ${KERNEL_FATOBJ} + COMMAND ${CMAKE_COMMAND} + -DPTOAS_BIN=${PTOAS_BIN} + -DPTO_SRC=${PTO_SRC} + -DKERNEL_FATOBJ=${KERNEL_FATOBJ} + -DPTOAS_ENABLE_INSERT_SYNC=${PTOAS_ENABLE_INSERT_SYNC} + -DPTOAS_PTO_LEVEL=${PTOAS_PTO_LEVEL} + -P ${PTOAS_CAPTURE_SCRIPT} + DEPENDS ${PTO_SRC} ${PTOAS_CAPTURE_SCRIPT} + COMMENT "ptoas: ${NAME}.pto -> ${NAME}_kernel.o" + VERBATIM + ) + + # Step 2: link the fatobj object with launch.cpp. + add_library(${NAME}_kernel SHARED launch.cpp ${KERNEL_FATOBJ}) + set_source_files_properties(${KERNEL_FATOBJ} + PROPERTIES EXTERNAL_OBJECT TRUE GENERATED TRUE) + target_compile_options(${NAME}_kernel PRIVATE + ${CMAKE_CCE_COMPILE_OPTIONS} --cce-aicore-arch=${AICORE_ARCH} -std=c++17) + target_include_directories(${NAME}_kernel PRIVATE + ${ASCEND_HOME_PATH}/pkg_inc/ + ${ASCEND_HOME_PATH}/pkg_inc/profiling/ + ${ASCEND_HOME_PATH}/pkg_inc/runtime/runtime + ) + target_link_options(${NAME}_kernel PRIVATE --cce-fatobj-link) + + # Step 3: main.cpp → host executable + add_executable(${NAME} main.cpp) + target_compile_options(${NAME} PRIVATE ${CMAKE_CPP_COMPILE_OPTIONS}) + target_include_directories(${NAME} PRIVATE + ${PTO_TILELANG_ST_COMMON_DIR} + ) + + target_link_directories(${NAME} PUBLIC + ${ASCEND_HOME_PATH}/lib64 + ${ASCEND_HOME_PATH}/tools/simulator/${SOC_VERSION}/lib + ) + + target_link_libraries(${NAME} PRIVATE + ${NAME}_kernel + $:runtime_camodel>> + $:runtime>> + stdc++ ascendcl m tiling_api platform c_sec dl nnopbase pthread + ) +endfunction() + +function(pto_tilelang_vec_st NAME) + pto_tilelang_st( + ${NAME} + AICORE_ARCH dav-c310-vec + ${ARGN} + ) +endfunction() + +function(pto_tilelang_cube_st NAME) + pto_tilelang_st( + ${NAME} + AICORE_ARCH dav-c310-cube + ${ARGN} + ) +endfunction() + +# -------------------------------------------------------------------------- +# Test case registry — add new ops here. +# -------------------------------------------------------------------------- +set(ALL_TESTCASES + tadd + tsub + tmul + tdiv + tmax + tmin + tmov + tmrgsort + tshl + tshr + tand + tor + txor + tcmp + tfmod + trem + tcvt + tload + tlrelu + trelu + tsel + tsels + tcolmax + tcolmin + tcolsum + tcolprod + tcolargmax + tcolargmin + tcolexpand + tcolexpandadd + tcolexpandsub + tcolexpandmul + tcolexpanddiv + tcolexpandmax + tcolexpandmin + tcolexpandexpdif + softmax + tabs + texp + tlog + tneg + tnot + tpartmax + tpartmin + tpartadd + tpartmul + tprelu + trandom + trecip + trowargmax + trowargmin + trowsum + trowmax + trowmin + trowprod + trsqrt + tsort32 + tsqrt + trowexpand + trowexpandadd + trowexpanddiv + trowexpandexpdif + trowexpandmax + trowexpandmin + trowexpandmul + trowexpandsub + texpands + tfillpad + tfillpad_inplace + tfillpad_expand + tadds + tands + tdivs + tmaxs + tmins + tmuls + tors + tshls + tshrs + tsubs + txors + trems + tfmods + tcmps + tmatmul +) + +if((TEST_CASE IN_LIST ALL_TESTCASES) OR (TEST_CASE STREQUAL "all")) + message(STATUS "run: ${TEST_CASE}") +else() + message(FATAL_ERROR "not found TEST_CASE: ${TEST_CASE}, supported: ${ALL_TESTCASES}") +endif() + +foreach(TESTCASE ${ALL_TESTCASES}) + if((DEFINED TEST_CASE AND TEST_CASE STREQUAL TESTCASE) OR (NOT DEFINED TEST_CASE) OR (TEST_CASE STREQUAL "all")) + add_subdirectory(${TESTCASE}) + endif() +endforeach() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake new file mode 100644 index 000000000..b8a3a0070 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/run_ptoas_to_file.cmake @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +if(NOT DEFINED PTOAS_BIN OR NOT DEFINED PTO_SRC OR NOT DEFINED KERNEL_FATOBJ) + message(FATAL_ERROR "PTOAS_BIN, PTO_SRC, and KERNEL_FATOBJ must be provided") +endif() + +get_filename_component(KERNEL_FATOBJ_DIR "${KERNEL_FATOBJ}" DIRECTORY) +file(MAKE_DIRECTORY "${KERNEL_FATOBJ_DIR}") + +if(NOT DEFINED PTOAS_ENABLE_INSERT_SYNC) + set(PTOAS_ENABLE_INSERT_SYNC ON) +endif() + +set(PTOAS_COMMAND + "${PTOAS_BIN}" + --pto-arch=a5 +) + +if(DEFINED PTOAS_PTO_LEVEL AND NOT PTOAS_PTO_LEVEL STREQUAL "") + list(APPEND PTOAS_COMMAND "--pto-level=${PTOAS_PTO_LEVEL}") +endif() + +list(APPEND PTOAS_COMMAND --pto-backend=vpto) + +if(PTOAS_ENABLE_INSERT_SYNC) + list(APPEND PTOAS_COMMAND --enable-insert-sync) +endif() + +list(APPEND PTOAS_COMMAND + --enable-tile-op-expand + "${PTO_SRC}" + -o + "${KERNEL_FATOBJ}" +) + +execute_process( + COMMAND ${PTOAS_COMMAND} + ERROR_VARIABLE PTOAS_STDERR + RESULT_VARIABLE PTOAS_RESULT +) + +if(NOT PTOAS_RESULT EQUAL 0) + string(STRIP "${PTOAS_STDERR}" PTOAS_STDERR) + if(PTOAS_STDERR) + message(FATAL_ERROR "ptoas failed while generating ${KERNEL_FATOBJ}:\n${PTOAS_STDERR}") + endif() + message(FATAL_ERROR "ptoas failed while generating ${KERNEL_FATOBJ}") +endif() + +if(NOT EXISTS "${KERNEL_FATOBJ}") + message(FATAL_ERROR "ptoas completed without producing ${KERNEL_FATOBJ}") +endif() + +file(SIZE "${KERNEL_FATOBJ}" KERNEL_FATOBJ_SIZE) +if(KERNEL_FATOBJ_SIZE EQUAL 0) + file(REMOVE "${KERNEL_FATOBJ}") + string(STRIP "${PTOAS_STDERR}" PTOAS_STDERR) + if(PTOAS_STDERR) + message(FATAL_ERROR + "ptoas produced empty fatobj for ${PTO_SRC}:\n${PTOAS_STDERR}") + endif() + message(FATAL_ERROR "ptoas produced empty fatobj for ${PTO_SRC}") +endif() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/softmax/CMakeLists.txt new file mode 100644 index 000000000..3c5224444 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(softmax DISABLE_INSERT_SYNC PTO_LEVEL level3) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/softmax/cases.py new file mode 100644 index 000000000..8b865c96a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/cases.py @@ -0,0 +1,25 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np + + +CASES = [ + { + "name": "f32_rows24_seq73", + "dtype": np.float32, + "shape": (24, 128), + "valid_shape": (24, 73), + "eps": 1e-4, + "rows": 24, + "cols": 128, + "seq": 73, + "seed": 19, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/softmax/compare.py new file mode 100644 index 000000000..6a5c89eb8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/compare.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def load_array(path, dtype, shape): + if not os.path.exists(path): + raise FileNotFoundError(path) + return np.fromfile(path, dtype=dtype).reshape(shape) + + +def compare_case(case): + case_dir = case["name"] + rows = int(case["rows"]) + cols = int(case["cols"]) + seq = int(case["seq"]) + dtype = case["dtype"] + eps = case["eps"] + + try: + golden_v4 = load_array(os.path.join(case_dir, "golden_v4.bin"), dtype, (rows,)) + output_v4 = load_array(os.path.join(case_dir, "v4.bin"), dtype, (rows,)) + golden_v5 = load_array(os.path.join(case_dir, "golden_v5.bin"), dtype, (rows,)) + output_v5 = load_array(os.path.join(case_dir, "v5.bin"), dtype, (rows,)) + golden_v6 = load_array(os.path.join(case_dir, "golden_v6.bin"), dtype, (rows,)) + output_v6 = load_array(os.path.join(case_dir, "v6.bin"), dtype, (rows,)) + golden_v7 = load_array( + os.path.join(case_dir, "golden_v7.bin"), dtype, (rows, cols) + ) + output_v7 = load_array(os.path.join(case_dir, "v7.bin"), dtype, (rows, cols)) + except FileNotFoundError as exc: + print(style_fail(f"[ERROR] {case['name']}: missing file {exc}")) + return False + + ok = True + ok = result_cmp(golden_v4, output_v4, eps) and ok + ok = result_cmp(golden_v5, output_v5, eps) and ok + ok = result_cmp(golden_v6, output_v6, eps) and ok + ok = result_cmp(golden_v7[:, :seq], output_v7[:, :seq], eps) and ok + return ok + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + matched_case = case_filter is None + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + matched_case = True + ok = compare_case(case) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not matched_case: + print(style_fail(f"[ERROR] unknown case filter: {case_filter}")) + sys.exit(2) + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/softmax/gen_data.py new file mode 100644 index 000000000..05bcef759 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/gen_data.py @@ -0,0 +1,64 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np + +from cases import CASES +from st_common import save_case_data, validate_cases + + +validate_cases(CASES) + +for case in CASES: + rows = int(case["rows"]) + cols = int(case["cols"]) + seq = int(case["seq"]) + seed = int(case["seed"]) + + rng = np.random.default_rng(seed) + oldmax = rng.uniform(-3.0, 1.5, size=(rows,)).astype(np.float32) + oldsum = rng.uniform(0.5, 4.0, size=(rows,)).astype(np.float32) + qk = rng.normal(loc=0.0, scale=1.5, size=(rows, cols)).astype(np.float32) + + qk_active = qk[:, :seq] + qk_rowmax = np.max(qk_active, axis=1) + newmax = np.maximum(qk_rowmax, oldmax) + tmp_active = np.exp(qk_active - newmax[:, None], dtype=np.float32) + cursum = np.sum(tmp_active, axis=1, dtype=np.float32) + raw_expmax = np.exp(oldmax - newmax, dtype=np.float32) + newsum = raw_expmax * oldsum + cursum + expmax = (raw_expmax * oldsum) / newsum + out = np.zeros((rows, cols), dtype=np.float32) + out[:, :seq] = tmp_active / newsum[:, None] + + zeros_state = np.zeros((rows,), dtype=np.float32) + zeros_out = np.zeros((rows, cols), dtype=np.float32) + + save_case_data( + case["name"], + { + "v1": oldmax, + "v2": oldsum, + "v3": qk.reshape(-1), + "v4": zeros_state, + "v5": zeros_state, + "v6": zeros_state, + "v7": zeros_out.reshape(-1), + "v8": np.array([seq], dtype=np.int32), + "v9": np.array([rows], dtype=np.int32), + "golden_v4": newmax, + "golden_v5": newsum, + "golden_v6": expmax, + "golden_v7": out.reshape(-1), + }, + ) + print( + f"[INFO] gen_data: {case['name']} rows={rows} cols={cols} " + f"seq={seq} dtype={case['dtype'].__name__}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/softmax/launch.cpp new file mode 100644 index 000000000..dd702e189 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void online_softmax_update_kernel_2d(__gm__ float *v1, __gm__ float *v2, __gm__ float *v3, __gm__ float *v4, __gm__ float *v5, __gm__ float *v6, __gm__ float *v7, int32_t v8, int32_t v9); + +void LaunchSOFTMAX_f32_rows24_seq73(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream) { + const int32_t blockRows = 8; + const int32_t blocks = (v9 + blockRows - 1) / blockRows; + online_softmax_update_kernel_2d<<>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ float *)v3, + (__gm__ float *)v4, (__gm__ float *)v5, (__gm__ float *)v6, + (__gm__ float *)v7, v8, v9); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/softmax/main.cpp new file mode 100644 index 000000000..4018ecf57 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/main.cpp @@ -0,0 +1,197 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +namespace pto { +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +} // namespace pto +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchSOFTMAX_f32_rows24_seq73(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream); + +using LaunchFn = void (*)(float *, float *, float *, float *, float *, float *, + float *, int32_t, int32_t, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; +}; + +static const TestCase kCases[] = { + {"f32_rows24_seq73", LaunchSOFTMAX_f32_rows24_seq73, 24, 128}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, aclrtStream stream) { + const size_t scalarBytes = sizeof(int32_t); + const size_t stateElems = tc.rows; + const size_t outElems = tc.rows * tc.cols; + const size_t stateBytes = stateElems * sizeof(float); + const size_t outBytes = outElems * sizeof(float); + std::string caseDir = std::string("./") + tc.name; + + float *v1Host = nullptr, *v2Host = nullptr, *v3Host = nullptr; + float *v4Host = nullptr, *v5Host = nullptr, *v6Host = nullptr, *v7Host = nullptr; + float *v1Device = nullptr, *v2Device = nullptr, *v3Device = nullptr; + float *v4Device = nullptr, *v5Device = nullptr, *v6Device = nullptr, *v7Device = nullptr; + int32_t seqHost = 0; + int32_t rowsHost = 0; + size_t fileSize = 0; + int rc = 0; + + std::printf("[INFO] === case: %s (rows=%zu, cols=%zu) ===\n", + tc.name, tc.rows, tc.cols); + + if (!ReadFile(caseDir + "/v8.bin", fileSize, &seqHost, scalarBytes) || + !ReadFile(caseDir + "/v9.bin", fileSize, &rowsHost, scalarBytes)) { + std::fprintf(stderr, "[ERROR] failed to read scalar inputs for %s\n", tc.name); + return 1; + } + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v5Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v6Host), stateBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&v7Host), outBytes)); + + ACL_CHECK(aclrtMalloc((void **)&v1Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v5Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v6Device, stateBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v7Device, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + if (!ReadFile(caseDir + "/v1.bin", fileSize, v1Host, stateBytes) || + !ReadFile(caseDir + "/v2.bin", fileSize, v2Host, stateBytes) || + !ReadFile(caseDir + "/v3.bin", fileSize, v3Host, outBytes) || + !ReadFile(caseDir + "/v4.bin", fileSize, v4Host, stateBytes) || + !ReadFile(caseDir + "/v5.bin", fileSize, v5Host, stateBytes) || + !ReadFile(caseDir + "/v6.bin", fileSize, v6Host, stateBytes) || + !ReadFile(caseDir + "/v7.bin", fileSize, v7Host, outBytes)) { + std::fprintf(stderr, "[ERROR] failed to read tensor inputs for %s\n", tc.name); + rc = 1; + goto cleanup; + } + + ACL_CHECK(aclrtMemcpy(v1Device, stateBytes, v1Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, stateBytes, v2Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, outBytes, v3Host, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, stateBytes, v4Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v5Device, stateBytes, v5Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v6Device, stateBytes, v6Host, stateBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v7Device, outBytes, v7Host, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + + tc.launch(v1Device, v2Device, v3Device, v4Device, v5Device, v6Device, + v7Device, seqHost, rowsHost, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v4Host, stateBytes, v4Device, stateBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v5Host, stateBytes, v5Device, stateBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v6Host, stateBytes, v6Device, stateBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v7Host, outBytes, v7Device, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + + if (!WriteFile(caseDir + "/v4.bin", v4Host, stateBytes) || + !WriteFile(caseDir + "/v5.bin", v5Host, stateBytes) || + !WriteFile(caseDir + "/v6.bin", v6Host, stateBytes) || + !WriteFile(caseDir + "/v7.bin", v7Host, outBytes)) { + std::fprintf(stderr, "[ERROR] failed to write outputs for %s\n", tc.name); + rc = 1; + } + +cleanup: + if (v1Device != nullptr) aclrtFree(v1Device); + if (v2Device != nullptr) aclrtFree(v2Device); + if (v3Device != nullptr) aclrtFree(v3Device); + if (v4Device != nullptr) aclrtFree(v4Device); + if (v5Device != nullptr) aclrtFree(v5Device); + if (v6Device != nullptr) aclrtFree(v6Device); + if (v7Device != nullptr) aclrtFree(v7Device); + if (v1Host != nullptr) aclrtFreeHost(v1Host); + if (v2Host != nullptr) aclrtFreeHost(v2Host); + if (v3Host != nullptr) aclrtFreeHost(v3Host); + if (v4Host != nullptr) aclrtFreeHost(v4Host); + if (v5Host != nullptr) aclrtFreeHost(v5Host); + if (v6Host != nullptr) aclrtFreeHost(v6Host); + if (v7Host != nullptr) aclrtFreeHost(v7Host); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + bool matchedCase = (caseFilter == nullptr); + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) + continue; + matchedCase = true; + if (RunCase(kCases[i], stream) != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (!matchedCase) { + std::fprintf(stderr, "[ERROR] unknown case filter: %s\n", caseFilter); + rc = 1; + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto new file mode 100644 index 000000000..c6acdbe40 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto @@ -0,0 +1,238 @@ +// TileLang ST kernel for online softmax update with mixed pto.tload/pto.tstore +// and raw VPTO vecscope compute. +// This testcase keeps manual sync in the source, so ST compilation disables +// --enable-insert-sync and enables --pto-level=level3 for alloc_tile addr=. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @online_softmax_update_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr, + %arg4: !pto.ptr, + %arg5: !pto.ptr, + %arg6: !pto.ptr, + %arg7: i32, + %arg8: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c8448_i64 = arith.constant 8448 : i64 + %c16640_i64 = arith.constant 16640 : i64 + %c16768_i64 = arith.constant 16768 : i64 + %c16896_i64 = arith.constant 16896 : i64 + + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + + %block = pto.get_block_idx + %block_idx = arith.index_cast %block : i64 to index + %row_base = arith.muli %block_idx, %c8 : index + %block_rows_i32 = arith.index_cast %c8 : index to i32 + %row_base_i32 = arith.index_cast %row_base : index to i32 + %remaining_rows = arith.subi %arg8, %row_base_i32 : i32 + %has_rows = arith.cmpi sgt, %remaining_rows, %c0_i32 : i32 + %too_many_rows = arith.cmpi sgt, %remaining_rows, %c8_i32 : i32 + %row_count_i32 = arith.select %too_many_rows, %c8_i32, %remaining_rows : i32 + %row_count = arith.index_cast %row_count_i32 : i32 to index + %seq = arith.index_cast %arg7 : i32 to index + %rows = arith.index_cast %arg8 : i32 to index + %rows_x_128 = arith.muli %rows, %c128 : index + + scf.if %has_rows { + %oldmax_view = pto.make_tensor_view %arg0, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %oldsum_view = pto.make_tensor_view %arg1, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %qk_view = pto.make_tensor_view %arg2, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + %newmax_view = pto.make_tensor_view %arg3, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %newsum_view = pto.make_tensor_view %arg4, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %expmax_view = pto.make_tensor_view %arg5, + shape = [%c1, %c1, %c1, %rows, %c1], + strides = [%rows, %rows, %rows, %c1, %rows] + : !pto.tensor_view + %out_view = pto.make_tensor_view %arg6, + shape = [%c1, %c1, %c1, %rows, %c128], + strides = [%rows_x_128, %rows_x_128, %rows_x_128, %c128, %c1] + : !pto.tensor_view + + %oldmax_part = pto.partition_view %oldmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %oldsum_part = pto.partition_view %oldsum_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %qk_part = pto.partition_view %qk_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view -> !pto.partition_tensor_view + %newmax_part = pto.partition_view %newmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %newsum_part = pto.partition_view %newsum_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %expmax_part = pto.partition_view %expmax_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %c1] + : !pto.tensor_view -> !pto.partition_tensor_view + %out_part = pto.partition_view %out_view, + offsets = [%c0, %c0, %c0, %row_base, %c0], + sizes = [%c1, %c1, %c1, %row_count, %seq] + : !pto.tensor_view -> !pto.partition_tensor_view + + // Tile domain: alloc_tile creates UB tile handles; tload/tstore operate + // on tile_buf values before/after the vector scope compute region. + %oldmax_tile = pto.alloc_tile addr = %c0_i64 valid_row = %row_count + : !pto.tile_buf + %oldsum_tile = pto.alloc_tile addr = %c128_i64 valid_row = %row_count + : !pto.tile_buf + %qk_tile = pto.alloc_tile addr = %c256_i64 valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %out_tile = pto.alloc_tile addr = %c8448_i64 valid_row = %row_count valid_col = %seq + : !pto.tile_buf + %newmax_tile = pto.alloc_tile addr = %c16640_i64 valid_row = %row_count + : !pto.tile_buf + %newsum_tile = pto.alloc_tile addr = %c16768_i64 valid_row = %row_count + : !pto.tile_buf + %expmax_tile = pto.alloc_tile addr = %c16896_i64 valid_row = %row_count + : !pto.tile_buf + + pto.tload ins(%oldmax_part : !pto.partition_tensor_view) + outs(%oldmax_tile : !pto.tile_buf) + pto.tload ins(%oldsum_part : !pto.partition_tensor_view) + outs(%oldsum_tile : !pto.tile_buf) + pto.tload ins(%qk_part : !pto.partition_tensor_view) + outs(%qk_tile : !pto.tile_buf) + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + // Boundary into vecscope instructions: tile_buf_addr materializes UB + // pointers from tile handles so vecscope can use vlds/vsts. + %ub_oldmax = pto.tile_buf_addr %oldmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_oldsum = pto.tile_buf_addr %oldsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_qk = pto.tile_buf_addr %qk_tile + : !pto.tile_buf + -> !pto.ptr + %ub_out = pto.tile_buf_addr %out_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newmax = pto.tile_buf_addr %newmax_tile + : !pto.tile_buf + -> !pto.ptr + %ub_newsum = pto.tile_buf_addr %newsum_tile + : !pto.tile_buf + -> !pto.ptr + %ub_expmax = pto.tile_buf_addr %expmax_tile + : !pto.tile_buf + -> !pto.ptr + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + scf.for %row = %c0 to %row_count step %c1 { + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + scf.yield %running_max, %running_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + scf.if %has_chunk { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + // Back in the tile domain: tstore writes the tile_buf results to GM + // partitions after the VPTO vecscope finishes. + pto.tstore ins(%newmax_tile : !pto.tile_buf) + outs(%newmax_part : !pto.partition_tensor_view) + pto.tstore ins(%newsum_tile : !pto.tile_buf) + outs(%newsum_part : !pto.partition_tensor_view) + pto.tstore ins(%expmax_tile : !pto.tile_buf) + outs(%expmax_part : !pto.partition_tensor_view) + pto.tstore ins(%out_tile : !pto.tile_buf) + outs(%out_part : !pto.partition_tensor_view) + } + pto.barrier #pto.pipe + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/st_common.py b/test/tilelang_st/npu/a5/src/st/testcase/st_common.py new file mode 100644 index 000000000..d0401b202 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/st_common.py @@ -0,0 +1,143 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Shared utilities for TileLang ST test cases. + +Provides: + - Data helpers: setup_case_rng(), save_case_data() + - Compare: result_cmp() + - Styling: supports_color(), style_pass(), style_fail() +""" + +import os +import sys +import numpy as np + + +# --------------------------------------------------------------------------- +# Case helpers +# --------------------------------------------------------------------------- + +REQUIRED_CASE_KEYS = {"name", "dtype", "shape", "valid_shape", "eps"} + + +def _to_shape_tuple(shape): + if not isinstance(shape, (tuple, list)): + raise ValueError(f"shape must be tuple/list, got {type(shape).__name__}: {shape!r}") + if not shape: + raise ValueError("shape must not be empty") + dims = tuple(int(dim) for dim in shape) + if any(dim <= 0 for dim in dims): + raise ValueError(f"shape dimensions must be > 0, got {dims}") + return dims + + +def _validate_shape_pair(shape, valid_shape, label): + shape = _to_shape_tuple(shape) + valid_shape = _to_shape_tuple(valid_shape) + if len(shape) != len(valid_shape): + raise ValueError(f"{label}: shape rank mismatch: {shape} vs {valid_shape}") + if any(valid_dim > dim for dim, valid_dim in zip(shape, valid_shape)): + raise ValueError(f"{label}: valid shape {valid_shape} exceeds shape {shape}") + return shape, valid_shape + + +def validate_cases(cases): + """Check that every case has all required keys.""" + for i, case in enumerate(cases): + missing = REQUIRED_CASE_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + _validate_shape_pair(case["shape"], case["valid_shape"], "shape") + has_dst_shape = "dst_shape" in case + has_dst_valid_shape = "dst_valid_shape" in case + if has_dst_shape != has_dst_valid_shape: + raise ValueError( + f"cases[{i}] ({case.get('name', '?')}) must define both dst_shape and dst_valid_shape" + ) + if has_dst_shape: + _validate_shape_pair(case["dst_shape"], case["dst_valid_shape"], "dst") + + +# --------------------------------------------------------------------------- +# Data generation helpers +# --------------------------------------------------------------------------- + +def setup_case_rng(case): + """Set a per-case deterministic random seed. + + Using hash(name) ensures that adding/reordering cases does not change + the random data of existing cases. + """ + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry in data_dict. + + Args: + case_name: subdirectory name (e.g. "f32_16x64"). + data_dict: mapping from file stem to numpy array, + e.g. {"input1": arr1, "input2": arr2, "golden": golden}. + """ + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +# --------------------------------------------------------------------------- +# Terminal styling +# --------------------------------------------------------------------------- + +ANSI_RESET = "\033[0m" +ANSI_BOLD_GREEN = "\033[1;32m" +ANSI_BOLD_RED = "\033[1;31m" + + +def supports_color(): + return sys.stdout.isatty() and os.environ.get("TERM") not in (None, "", "dumb") + + +def style_pass(text): + if not supports_color(): + return text + return f"{ANSI_BOLD_GREEN}{text}{ANSI_RESET}" + + +def style_fail(text): + if not supports_color(): + return text + return f"{ANSI_BOLD_RED}{text}{ANSI_RESET}" + + +# --------------------------------------------------------------------------- +# Comparison +# --------------------------------------------------------------------------- + +def result_cmp(golden, output, eps): + """Compare already prepared golden/output arrays. + + The caller is responsible for loading, reshaping and slicing data. + """ + g = np.asarray(golden).astype(np.float64, copy=False) + o = np.asarray(output).astype(np.float64, copy=False) + + if g.shape != o.shape: + print(style_fail(f"[ERROR] Shape mismatch: golden {g.shape} vs output {o.shape}")) + return False + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(style_fail(f"[ERROR] Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at flat idx={idx} " + f"(golden={g.flat[idx]}, output={o.flat[idx]})")) + return False + return True diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tabs/CMakeLists.txt new file mode 100644 index 000000000..b776efb52 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tabs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tabs/cases.py new file mode 100644 index 000000000..d63eb85f8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tabs ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tabs/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tabs/gen_data.py new file mode 100644 index 000000000..22bf5d95d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/gen_data.py @@ -0,0 +1,32 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input = np.random.randn(*shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.abs(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tabs/launch.cpp new file mode 100644 index 000000000..dd39abd15 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TABS_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTABS_f32_16x64(void *a, void *b, void *stream) { + TABS_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TABS_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTABS_f32_32x32(void *a, void *b, void *stream) { + TABS_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TABS_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTABS_f16_16x64(void *a, void *b, void *stream) { + TABS_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TABS_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTABS_f16_32x32(void *a, void *b, void *stream) { + TABS_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tabs/main.cpp new file mode 100644 index 000000000..681510ddf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tabs ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTABS_f32_16x64(void *a, void *b, void *stream); +void LaunchTABS_f32_32x32(void *a, void *b, void *stream); +void LaunchTABS_f16_16x64(void *a, void *b, void *stream); +void LaunchTABS_f16_32x32(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTABS_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTABS_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTABS_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTABS_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tabs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tabs/tabs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tabs/tabs.pto new file mode 100644 index 000000000..bf702ecdf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tabs/tabs.pto @@ -0,0 +1,180 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tabs: tload(a) + tabs(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TABS_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TABS_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TABS_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TABS_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tabs ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tadd/CMakeLists.txt new file mode 100644 index 000000000..84928bcdb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tadd) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py new file mode 100644 index 000000000..5958f05d2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tadd ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py new file mode 100644 index 000000000..6a4d5d1aa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py new file mode 100644 index 000000000..986dba17d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] + input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp new file mode 100644 index 000000000..f1074c838 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TADD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream) { + TADD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TADD_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream) { + TADD_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp new file mode 100644 index 000000000..1a010623f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tadd ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTADD_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTADD_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTADD_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTADD_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tadd [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto new file mode 100644 index 000000000..94d01af64 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tadd: tload(a) + tload(b) + tadd(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TADD_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TADD_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tadds/CMakeLists.txt new file mode 100644 index 000000000..d4535a569 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tadds) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tadds/cases.py new file mode 100644 index 000000000..5b24462bf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/cases.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tadds ST test cases. + +Shapes and dtypes match testcase/tadds (C++ GTest suite): + case1: float, 32x64, valid 32x64 + case2: float16, 63x64, valid 63x64 + case3: int32, 31x128, valid 31x128 + case4: int16, 15x192, valid 15x192 + case5: float, 7x448, valid 7x448 + case6: float, 256x16, valid 256x16 + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tadds/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tadds/gen_data.py new file mode 100644 index 000000000..c4f47c5f4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value added to every element (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] + scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadds/launch.cpp new file mode 100644 index 000000000..49f0c98ec --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value added to every element (must match gen_data.py SCALAR) +static constexpr float TADDS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TADDS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTADDS_f32_32x64(float *src, float *dst, void *stream) { + TADDS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TADDS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TADDS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTADDS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TADDS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TADDS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTADDS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TADDS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TADDS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTADDS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TADDS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TADDS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTADDS_f32_7x448(float *src, float *dst, void *stream) { + TADDS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TADDS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TADDS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTADDS_f32_256x16(float *src, float *dst, void *stream) { + TADDS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TADDS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tadds/main.cpp new file mode 100644 index 000000000..4c6f409dc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tadds ST — case-table driven. +// tadds: dst = src + scalar (single input + scalar, unlike tadd which has two inputs). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTADDS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTADDS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTADDS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTADDS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTADDS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTADDS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTADDS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTADDS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTADDS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTADDS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTADDS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTADDS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tadds [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tadds/tadds.pto b/test/tilelang_st/npu/a5/src/st/testcase/tadds/tadds.pto new file mode 100644 index 000000000..2057fc8df --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tadds/tadds.pto @@ -0,0 +1,256 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tadds: tload(src) + tadds(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 (2048 elements) + func.func @TADDS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TADDS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TADDS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TADDS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TADDS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TADDS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tadds ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tand/CMakeLists.txt new file mode 100644 index 000000000..230a97296 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py new file mode 100644 index 000000000..8c40489b1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tand ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py new file mode 100644 index 000000000..64829832e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(0, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 100, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] & input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp new file mode 100644 index 000000000..ed3149c6e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TAND_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTAND_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TAND_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TAND_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTAND_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TAND_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp new file mode 100644 index 000000000..21b90e9b3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tand ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTAND_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTAND_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTAND_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTAND_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tand [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto b/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto new file mode 100644 index 000000000..c1380f2b5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tand/tand.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tand: tload(a) + tload(b) + tand(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: i32 16x64 (1024 elements) + func.func @TAND_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tand ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TAND_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tand ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tands/CMakeLists.txt new file mode 100644 index 000000000..0ff088f8d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tands) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tands/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tands/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tands/gen_data.py new file mode 100644 index 000000000..9f187cb03 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for bitwise AND (must match launch.cpp) +SCALAR = 3 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] & scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tands/launch.cpp new file mode 100644 index 000000000..8226ac79e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for bitwise AND (must match gen_data.py SCALAR) +static constexpr int32_t TANDS_SCALAR_I32 = 3; +static constexpr int16_t TANDS_SCALAR_I16 = 3; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TANDS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTANDS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TANDS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TANDS_SCALAR_I32); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TANDS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTANDS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TANDS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TANDS_SCALAR_I16); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TANDS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTANDS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TANDS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TANDS_SCALAR_I32); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TANDS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTANDS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TANDS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TANDS_SCALAR_I16); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tands/main.cpp new file mode 100644 index 000000000..e0f93f2e7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tands ST — case-table driven. +// tands: dst = src & scalar (single input + scalar, bitwise AND). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTANDS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTANDS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTANDS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTANDS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTANDS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTANDS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTANDS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTANDS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tands [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tands/tands.pto b/test/tilelang_st/npu/a5/src/st/testcase/tands/tands.pto new file mode 100644 index 000000000..24272369f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tands/tands.pto @@ -0,0 +1,176 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tands: tload(src) + tands(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: i32 32x64 (2048 elements) + func.func @TANDS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TANDS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TANDS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TANDS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt new file mode 100644 index 000000000..a863ea151 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcmp) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py new file mode 100644 index 000000000..079a7ae94 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/cases.py @@ -0,0 +1,140 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcmp ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src and dst). + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - dst_dtype: output mask dtype (i8 - packed mask, same shape as input). + - cmp_mode: comparison mode: "eq", "ne", "lt", "gt", "ge", "le". + - eps: tolerance (exact match for masks, eps=0). + +Aligned with testcase/tcmp test cases. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # Case 1: f16 32x32 EQ (half_32x32_32x32) + { + "name": "f16_32x32_eq", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "dst_dtype": np.int8, + "cmp_mode": "eq", + "eps": 0, + }, + # Case 2: f32 8x64 GT (float_8x64_8x64) + { + "name": "f32_8x64_gt", + "dtype": np.float32, + "shape": (8, 64), + "valid_shape": (8, 64), + "dst_dtype": np.int8, + "cmp_mode": "gt", + "eps": 0, + }, + # Case 3: i32 4x64 NE (int32_4x64_4x64) + { + "name": "i32_4x64_ne", + "dtype": np.int32, + "shape": (4, 64), + "valid_shape": (4, 64), + "dst_dtype": np.int8, + "cmp_mode": "ne", + "eps": 0, + }, + # Case 4: i32 128x128 LT with valid 64x64 (int32_128x128_64x64) + { + "name": "i32_128x128_lt", + "dtype": np.int32, + "shape": (128, 128), + "valid_shape": (64, 64), + "dst_dtype": np.int8, + "cmp_mode": "lt", + "eps": 0, + }, + # Case 5: i32 64x64 EQ with valid 32x32 (int32_64x64_32x32) + { + "name": "i32_64x64_eq", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (32, 32), + "dst_dtype": np.int8, + "cmp_mode": "eq", + "eps": 0, + }, + # Case 6: i32 16x32 EQ (int32_16x32_16x32) + { + "name": "i32_16x32_eq", + "dtype": np.int32, + "shape": (16, 32), + "valid_shape": (16, 32), + "dst_dtype": np.int8, + "cmp_mode": "eq", + "eps": 0, + }, + # Case 7: f32 128x128 LE with valid 64x64 (float_128x128_64x64) + { + "name": "f32_128x128_le", + "dtype": np.float32, + "shape": (128, 128), + "valid_shape": (64, 64), + "dst_dtype": np.int8, + "cmp_mode": "le", + "eps": 0, + }, + # Case 8: i32 77x96 EQ with valid 32x32 (int32_77x96_32x32) + { + "name": "i32_77x96_eq", + "dtype": np.int32, + "shape": (77, 96), + "valid_shape": (32, 32), + "dst_dtype": np.int8, + "cmp_mode": "eq", + "eps": 0, + }, + # Case 9: i32 32x32 EQ (int32_32x32_32x32) + { + "name": "i32_32x32_eq", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "dst_dtype": np.int8, + "cmp_mode": "eq", + "eps": 0, + }, + # Case 10: i16 32x32 EQ with valid 16x32 (int16_32x32_16x32) + { + "name": "i16_32x32_eq", + "dtype": np.int16, + "shape": (32, 32), + "valid_shape": (16, 32), + "dst_dtype": np.int8, + "cmp_mode": "eq", + "eps": 0, + }, + # Case 11: i16 77x96 LE with valid 32x32 (int16_77x96_32x32) + { + "name": "i16_77x96_le", + "dtype": np.int16, + "shape": (77, 96), + "valid_shape": (32, 32), + "dst_dtype": np.int8, + "cmp_mode": "le", + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py new file mode 100644 index 000000000..fd65cbe72 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/compare.py @@ -0,0 +1,54 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dst_dtype = case["dst_dtype"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape + + # Only compare the packed mask region: rows x (cols//8) + packed_cols = vc // 8 + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dst_dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dst_dtype).reshape(shape) + + # Compare packed mask output in valid region + ok = result_cmp(golden[:vr, :packed_cols], output[:vr, :packed_cols], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py new file mode 100644 index 000000000..b415c1e00 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/gen_data.py @@ -0,0 +1,70 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_dtype = case["dst_dtype"] + cmp_mode = case["cmp_mode"] + + # Generate random input data + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + # Compute comparison mask (boolean) + vr, vc = valid_shape + mask_bits = np.zeros(shape, dtype=np.bool_) + input1_valid = input1[:vr, :vc] + input2_valid = input2[:vr, :vc] + + if cmp_mode == "eq": + mask_bits[:vr, :vc] = (input1_valid == input2_valid) + elif cmp_mode == "ne": + mask_bits[:vr, :vc] = (input1_valid != input2_valid) + elif cmp_mode == "lt": + mask_bits[:vr, :vc] = (input1_valid < input2_valid) + elif cmp_mode == "gt": + mask_bits[:vr, :vc] = (input1_valid > input2_valid) + elif cmp_mode == "ge": + mask_bits[:vr, :vc] = (input1_valid >= input2_valid) + elif cmp_mode == "le": + mask_bits[:vr, :vc] = (input1_valid <= input2_valid) + + # dst shape is same as src shape, but only first cols//8 columns store packed mask bytes + # remaining columns are padding (zeros) + # Use uint8 first to avoid overflow, then cast to int8 + golden = np.zeros(shape, dtype=np.uint8) + + # Pack mask bits: each byte stores 8 comparison results (1 bit each) + packed_cols = vc // 8 # number of byte columns that store actual packed data + + for row in range(vr): + for col_byte in range(packed_cols): + byte_val = 0 + for bit in range(8): + src_col = col_byte * 8 + bit + if src_col < vc and mask_bits[row, src_col]: + byte_val |= (1 << bit) + golden[row, col_byte] = byte_val + + # Cast to int8 for final output + golden = golden.astype(dst_dtype) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} cmp_mode={cmp_mode}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp new file mode 100644 index 000000000..b3114e2f1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/launch.cpp @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: f16 32x32 eq (half_32x32_32x32) +extern "C" __global__ AICORE void TCMP_f16_32x32_eq(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ int8_t *c); + +void LaunchTCMP_f16_32x32_eq(uint16_t *a, uint16_t *b, int8_t *c, void *stream) { + TCMP_f16_32x32_eq<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ int8_t *)c); +} + +// Case 2: f32 8x64 gt (float_8x64_8x64) +extern "C" __global__ AICORE void TCMP_f32_8x64_gt(__gm__ float *a, __gm__ float *b, __gm__ int8_t *c); + +void LaunchTCMP_f32_8x64_gt(float *a, float *b, int8_t *c, void *stream) { + TCMP_f32_8x64_gt<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ int8_t *)c); +} + +// Case 3: i32 4x64 ne (int32_4x64_4x64) +extern "C" __global__ AICORE void TCMP_i32_4x64_ne(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_4x64_ne(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_4x64_ne<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int8_t *)c); +} + +// Case 4: i32 128x128 lt with valid 64x64 (int32_128x128_64x64) +extern "C" __global__ AICORE void TCMP_i32_128x128_lt(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_128x128_lt(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_128x128_lt<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int8_t *)c); +} + +// Case 5: i32 64x64 eq with valid 32x32 (int32_64x64_32x32) +extern "C" __global__ AICORE void TCMP_i32_64x64_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_64x64_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_64x64_eq<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int8_t *)c); +} + +// Case 6: i32 16x32 eq (int32_16x32_16x32) +extern "C" __global__ AICORE void TCMP_i32_16x32_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_16x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_16x32_eq<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int8_t *)c); +} + +// Case 7: f32 128x128 le with valid 64x64 (float_128x128_64x64) +extern "C" __global__ AICORE void TCMP_f32_128x128_le(__gm__ float *a, __gm__ float *b, __gm__ int8_t *c); + +void LaunchTCMP_f32_128x128_le(float *a, float *b, int8_t *c, void *stream) { + TCMP_f32_128x128_le<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ int8_t *)c); +} + +// Case 8: i32 77x96 eq with valid 32x32 (int32_77x96_32x32) +extern "C" __global__ AICORE void TCMP_i32_77x96_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_77x96_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_77x96_eq<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int8_t *)c); +} + +// Case 9: i32 32x32 eq (int32_32x32_32x32) +extern "C" __global__ AICORE void TCMP_i32_32x32_eq(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i32_32x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream) { + TCMP_i32_32x32_eq<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int8_t *)c); +} + +// Case 10: i16 32x32 eq with valid 16x32 (int16_32x32_16x32) +extern "C" __global__ AICORE void TCMP_i16_32x32_eq(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i16_32x32_eq(int16_t *a, int16_t *b, int8_t *c, void *stream) { + TCMP_i16_32x32_eq<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int8_t *)c); +} + +// Case 11: i16 77x96 le with valid 32x32 (int16_77x96_32x32) +extern "C" __global__ AICORE void TCMP_i16_77x96_le(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int8_t *c); + +void LaunchTCMP_i16_77x96_le(int16_t *a, int16_t *b, int8_t *c, void *stream) { + TCMP_i16_77x96_le<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int8_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp new file mode 100644 index 000000000..a1199de5c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/main.cpp @@ -0,0 +1,172 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcmp ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. +// Aligned with testcase/tcmp test cases. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCMP_f16_32x32_eq(uint16_t *a, uint16_t *b, int8_t *c, void *stream); +void LaunchTCMP_f32_8x64_gt(float *a, float *b, int8_t *c, void *stream); +void LaunchTCMP_i32_4x64_ne(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_i32_128x128_lt(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_i32_64x64_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_i32_16x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_f32_128x128_le(float *a, float *b, int8_t *c, void *stream); +void LaunchTCMP_i32_77x96_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_i32_32x32_eq(int32_t *a, int32_t *b, int8_t *c, void *stream); +void LaunchTCMP_i16_32x32_eq(int16_t *a, int16_t *b, int8_t *c, void *stream); +void LaunchTCMP_i16_77x96_le(int16_t *a, int16_t *b, int8_t *c, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *, void *); + size_t rows; + size_t cols; + size_t srcElemSize; + size_t dstElemSize; +}; + +static const TestCase kCases[] = { + // Case 1: f16 32x32 eq (half_32x32_32x32) + {"f16_32x32_eq", (void (*)(void*, void*, void*, void*))LaunchTCMP_f16_32x32_eq, 32, 32, sizeof(uint16_t), sizeof(int8_t)}, + // Case 2: f32 8x64 gt (float_8x64_8x64) + {"f32_8x64_gt", (void (*)(void*, void*, void*, void*))LaunchTCMP_f32_8x64_gt, 8, 64, sizeof(float), sizeof(int8_t)}, + // Case 3: i32 4x64 ne (int32_4x64_4x64) + {"i32_4x64_ne", (void (*)(void*, void*, void*, void*))LaunchTCMP_i32_4x64_ne, 4, 64, sizeof(int32_t), sizeof(int8_t)}, + // Case 4: i32 128x128 lt with valid 64x64 (int32_128x128_64x64) + {"i32_128x128_lt", (void (*)(void*, void*, void*, void*))LaunchTCMP_i32_128x128_lt, 128, 128, sizeof(int32_t), sizeof(int8_t)}, + // Case 5: i32 64x64 eq with valid 32x32 (int32_64x64_32x32) + {"i32_64x64_eq", (void (*)(void*, void*, void*, void*))LaunchTCMP_i32_64x64_eq, 64, 64, sizeof(int32_t), sizeof(int8_t)}, + // Case 6: i32 16x32 eq (int32_16x32_16x32) + {"i32_16x32_eq", (void (*)(void*, void*, void*, void*))LaunchTCMP_i32_16x32_eq, 16, 32, sizeof(int32_t), sizeof(int8_t)}, + // Case 7: f32 128x128 le with valid 64x64 (float_128x128_64x64) + {"f32_128x128_le", (void (*)(void*, void*, void*, void*))LaunchTCMP_f32_128x128_le, 128, 128, sizeof(float), sizeof(int8_t)}, + // Case 8: i32 77x96 eq with valid 32x32 (int32_77x96_32x32) + {"i32_77x96_eq", (void (*)(void*, void*, void*, void*))LaunchTCMP_i32_77x96_eq, 77, 96, sizeof(int32_t), sizeof(int8_t)}, + // Case 9: i32 32x32 eq (int32_32x32_32x32) + {"i32_32x32_eq", (void (*)(void*, void*, void*, void*))LaunchTCMP_i32_32x32_eq, 32, 32, sizeof(int32_t), sizeof(int8_t)}, + // Case 10: i16 32x32 eq with valid 16x32 (int16_32x32_16x32) + {"i16_32x32_eq", (void (*)(void*, void*, void*, void*))LaunchTCMP_i16_32x32_eq, 32, 32, sizeof(int16_t), sizeof(int8_t)}, + // Case 11: i16 77x96 le with valid 32x32 (int16_77x96_32x32) + {"i16_77x96_le", (void (*)(void*, void*, void*, void*))LaunchTCMP_i16_77x96_le, 77, 96, sizeof(int16_t), sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcFileSize = tc.rows * tc.cols * tc.srcElemSize; + const size_t dstFileSize = tc.rows * tc.cols * tc.dstElemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), srcFileSize); + aclrtMallocHost((void **)(&src1Host), srcFileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t src0FileSize = srcFileSize; + size_t src1FileSize = srcFileSize; + size_t dstFileSizeActual = dstFileSize; + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, srcFileSize, src1Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeActual)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto new file mode 100644 index 000000000..99f31fb03 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmp/tcmp.pto @@ -0,0 +1,717 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcmp: tload(a) + tload(b) + tcmp(a,b)->c(mask) + tstore(c). +// Output mask is packed: 1 bit per element, stored as i8 array (same shape as input). +// Aligned with testcase/tcmp test cases. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: f16 32x32 EQ (half_32x32_32x32) + func.func @TCMP_f16_32x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) + return + } + + // Case 2: f32 8x64 GT (float_8x64_8x64) + func.func @TCMP_f32_8x64_gt(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xi8> -> !pto.partition_tensor_view<1x1x1x8x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x64xi8>) + return + } + + // Case 3: i32 4x64 NE (int32_4x64_4x64) + func.func @TCMP_i32_4x64_ne(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c4, %c64], + strides = [%c256, %c256, %c256, %c64, %c1] + : !pto.tensor_view<1x1x1x4x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c4, %c64], + strides = [%c256, %c256, %c256, %c64, %c1] + : !pto.tensor_view<1x1x1x4x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c4, %c64], + strides = [%c256, %c256, %c256, %c64, %c1] + : !pto.tensor_view<1x1x1x4x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c64] + : !pto.tensor_view<1x1x1x4x64xi32> -> !pto.partition_tensor_view<1x1x1x4x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c64] + : !pto.tensor_view<1x1x1x4x64xi32> -> !pto.partition_tensor_view<1x1x1x4x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c64] + : !pto.tensor_view<1x1x1x4x64xi8> -> !pto.partition_tensor_view<1x1x1x4x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x4x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x4x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x4x64xi8>) + return + } + + // Case 4: i32 128x128 LT with valid 64x64 (int32_128x128_64x64) + func.func @TCMP_i32_128x128_lt(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c16384 = arith.constant 16384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xi32> -> !pto.partition_tensor_view<1x1x1x128x128xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xi32> -> !pto.partition_tensor_view<1x1x1x128x128xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x128x128xi8> -> !pto.partition_tensor_view<1x1x1x64x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x128x128xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x128x128xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi8>) + return + } + + // Case 5: i32 64x64 EQ with valid 32x32 (int32_64x64_32x32) + func.func @TCMP_i32_64x64_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c32 = arith.constant 32 : index + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) + return + } + + // Case 6: i32 16x32 EQ (int32_16x32_16x32) + func.func @TCMP_i32_16x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) + return + } + + // Case 7: f32 128x128 LE with valid 64x64 (float_128x128_64x64) + func.func @TCMP_f32_128x128_le(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c16384 = arith.constant 16384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + %c64 = arith.constant 64 : index + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x128x128xi8> -> !pto.partition_tensor_view<1x1x1x64x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi8>) + return + } + + // Case 8: i32 77x96 EQ with valid 32x32 (int32_77x96_32x32) + func.func @TCMP_i32_77x96_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c77 = arith.constant 77 : index + %c96 = arith.constant 96 : index + %c7392 = arith.constant 7392 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c77, %c96], + strides = [%c7392, %c7392, %c7392, %c96, %c1] + : !pto.tensor_view<1x1x1x77x96xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c77, %c96], + strides = [%c7392, %c7392, %c7392, %c96, %c1] + : !pto.tensor_view<1x1x1x77x96xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c77, %c96], + strides = [%c7392, %c7392, %c7392, %c96, %c1] + : !pto.tensor_view<1x1x1x77x96xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c77, %c96] + : !pto.tensor_view<1x1x1x77x96xi32> -> !pto.partition_tensor_view<1x1x1x77x96xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c77, %c96] + : !pto.tensor_view<1x1x1x77x96xi32> -> !pto.partition_tensor_view<1x1x1x77x96xi32> + %c32 = arith.constant 32 : index + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x77x96xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x77x96xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x77x96xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) + return + } + + // Case 9: i32 32x32 EQ (int32_32x32_32x32) + func.func @TCMP_i32_32x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) + return + } + + // Case 10: i16 32x32 EQ with valid 16x32 (int16_32x32_16x32) + func.func @TCMP_i16_32x32_eq(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi16> -> !pto.partition_tensor_view<1x1x1x32x32xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi16> -> !pto.partition_tensor_view<1x1x1x32x32xi16> + %c16 = arith.constant 16 : index + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x32x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi16>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) + return + } + + // Case 11: i16 77x96 LE with valid 32x32 (int16_77x96_32x32) + func.func @TCMP_i16_77x96_le(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c77 = arith.constant 77 : index + %c96 = arith.constant 96 : index + %c7392 = arith.constant 7392 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c77, %c96], + strides = [%c7392, %c7392, %c7392, %c96, %c1] + : !pto.tensor_view<1x1x1x77x96xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c77, %c96], + strides = [%c7392, %c7392, %c7392, %c96, %c1] + : !pto.tensor_view<1x1x1x77x96xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c77, %c96], + strides = [%c7392, %c7392, %c7392, %c96, %c1] + : !pto.tensor_view<1x1x1x77x96xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c77, %c96] + : !pto.tensor_view<1x1x1x77x96xi16> -> !pto.partition_tensor_view<1x1x1x77x96xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c77, %c96] + : !pto.tensor_view<1x1x1x77x96xi16> -> !pto.partition_tensor_view<1x1x1x77x96xi16> + %c32 = arith.constant 32 : index + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x77x96xi8> -> !pto.partition_tensor_view<1x1x1x32x32xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x77x96xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x77x96xi16>) + outs(%b : !pto.tile_buf) + + pto.tcmp ins(%a, %b {cmpMode = #pto} : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmps/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/CMakeLists.txt new file mode 100644 index 000000000..5b766cc09 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcmps) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmps/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/cases.py new file mode 100644 index 000000000..329945869 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/cases.py @@ -0,0 +1,131 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcmps ST test cases. + +tcmps: packed mask of (src < scalar), dst stores packed predicate mask. +Supports 32-bit source types: f32, i32. Output dtype is uint8. + +Cases reference testcase/tcmps with various shapes and valid regions. +""" + +import numpy as np + +CASES = [ + # float32 cases matching testcase/tcmps + { + "name": "f32_1x64", + "dtype": np.float32, + "out_dtype": np.uint8, + "shape": (1, 64), + "valid_shape": (1, 64), + "eps": 0, + }, + { + "name": "f32_4x64", + "dtype": np.float32, + "out_dtype": np.uint8, + "shape": (4, 64), + "valid_shape": (4, 64), + "eps": 0, + }, + { + "name": "f32_8x64", + "dtype": np.float32, + "out_dtype": np.uint8, + "shape": (8, 64), + "valid_shape": (8, 64), + "eps": 0, + }, + { + "name": "f32_32x64", + "dtype": np.float32, + "out_dtype": np.uint8, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "f32_128x128", + "dtype": np.float32, + "out_dtype": np.uint8, + "shape": (128, 128), + "valid_shape": (128, 128), + "eps": 0, + }, + # int32 cases matching testcase/tcmps + { + "name": "i32_16x32", + "dtype": np.int32, + "out_dtype": np.uint8, + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "out_dtype": np.uint8, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, + { + "name": "i32_32x64_valid32x64", + "dtype": np.int32, + "out_dtype": np.uint8, + "shape": (64, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + # Non-aligned cases + { + "name": "f32_7x448", + "dtype": np.float32, + "out_dtype": np.uint8, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 0, + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "out_dtype": np.uint8, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "out_dtype": np.uint8, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + # 16B cases (f16, i16) + { + "name": "f16_32x128", + "dtype": np.float16, + "out_dtype": np.uint8, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "i16_32x128", + "dtype": np.int16, + "out_dtype": np.uint8, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmps/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/compare.py new file mode 100644 index 000000000..3d012d09c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/compare.py @@ -0,0 +1,89 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + dtype = case["dtype"] + out_dtype = case["out_dtype"] + elem_size = np.dtype(dtype).itemsize + lanes = 256 // elem_size + + # Calculate expected output size (same as gen_data.py) + total_elm = vr * vc + if elem_size == 4: # 32B + bytes_per_iter = 16 + repeat_times = (total_elm + lanes - 1) // lanes + 1 + total_iters = repeat_times // 2 + expected_bytes = total_iters * bytes_per_iter + elif elem_size == 2: # 16B + bytes_per_iter = 16 + iters_per_row = (vc + lanes - 1) // lanes + expected_bytes = vr * iters_per_row * bytes_per_iter + else: # 8B + bytes_per_iter = 32 + iters_per_row = (vc + lanes - 1) // lanes + expected_bytes = vr * iters_per_row * bytes_per_iter + + # Read golden (already correct size from gen_data.py) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=np.uint8) + + # Read output and truncate/zero-pad to expected size + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=np.uint8) + if len(output) > expected_bytes: + output = output[:expected_bytes] + elif len(output) < expected_bytes: + output = np.pad(output, (0, expected_bytes - len(output)), mode='constant') + + # Compare byte-by-byte + ok = np.array_equal(golden, output) + if not ok: + # Find first mismatch for debugging + diff_mask = golden != output + diff_indices = np.where(diff_mask)[0] + if len(diff_indices) > 0: + diff_idx = diff_indices[0] + max_diff = int(np.max(np.abs(golden.astype(int) - output.astype(int)))) + print(style_fail(f"[ERROR] Mismatch: max diff={max_diff} at byte idx={diff_idx} " + f"(golden=0x{golden[diff_idx]:02x}, output=0x{output[diff_idx]:02x})")) + else: + print(style_fail(f"[ERROR] Mismatch: shapes differ golden={golden.shape} output={output.shape}")) + + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmps/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/gen_data.py new file mode 100644 index 000000000..95a8264f3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/gen_data.py @@ -0,0 +1,109 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for comparison (matches the scalar passed in launch.cpp) +SCALAR = 5.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + out_dtype = case["out_dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Generate random input matching testcase/tcmps pattern + if np.issubdtype(dtype, np.floating): + input1 = np.random.randint(-5, 5, size=shape).astype(dtype) + else: + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + if np.issubdtype(dtype, np.floating): + scalar_val = dtype(SCALAR) + else: + scalar_val = dtype(int(SCALAR)) + + # Compute element-wise comparison result (0 or 1 per element) + # Using "lt" mode to match the template + cmp_result = (input1[:vr, :vc] < scalar_val).astype(np.uint8, copy=False) + + # tcmps output uses psts: + # - 32B: 64 elements -> 32 bytes (NORM mode, sequential, bit_pos = col_in_iter * 4) + # - 16B: 128 elements -> 16 bytes (PK mode, bit_pos = col_in_iter) + # - 8B: 256 elements -> 32 bytes (NORM mode, sequential, bit_pos = col_in_iter) + elem_size = np.dtype(dtype).itemsize + lanes = 256 // elem_size + if elem_size == 4: # 32B: 2 vcmps + dintlv_b8 -> PK mode (16 bytes per iteration) + bytes_per_iter = 16 + bit_multiplier = 1 + # For 32B, each iteration processes 2 repeats (128 elements) + # Element linear index maps to bit position after dintlv_b8 + elif elem_size == 2: # 16B: PK mode (16 bytes per iteration) + bytes_per_iter = 16 + bit_multiplier = 1 + else: # 8B: NORM mode (32 bytes per iteration) + bytes_per_iter = 32 + bit_multiplier = 1 + + # Calculate iterations (total) + total_elm = vr * vc + if elem_size == 4: # 32B: special handling for linear offset + repeat_times = (total_elm + lanes - 1) // lanes + 1 + total_iters = repeat_times // 2 + else: + iters_per_row = (vc + lanes - 1) // lanes + + total_elm = vr * vc + if elem_size == 4: # 32B: special handling for linear offset + repeat_times = (total_elm + lanes - 1) // lanes + 1 + total_iters = repeat_times // 2 + total_output_bytes = total_iters * bytes_per_iter + else: + iters_per_row = (vc + lanes - 1) // lanes + total_iters = vr * iters_per_row + total_output_bytes = total_iters * bytes_per_iter + + # Output buffer size matches actual output + golden = np.zeros(total_output_bytes, dtype=np.uint8) + + for row in range(vr): + for col in range(vc): + if cmp_result[row, col]: + if elem_size == 4: # 32B: PK mode after dintlv_b8 with linear offset + # Linear element index + linear_idx = row * vc + col + # Each iteration processes 128 elements (2 repeats of 64) + iter_idx = linear_idx // (2 * lanes) + # Position within the 128-element block + pos_in_block = linear_idx % (2 * lanes) + # PK mode: bit position = pos_in_block + bit_pos = pos_in_block + # Byte offset (linear) + byte_idx = iter_idx * bytes_per_iter + (bit_pos // 8) + bit_idx = bit_pos % 8 + else: # 16B and 8B + col_in_iter = col % lanes + bit_pos = col_in_iter * bit_multiplier + byte_idx = (row * iters_per_row + col // lanes) * bytes_per_iter + (bit_pos // 8) + bit_idx = bit_pos % 8 + + if byte_idx < total_output_bytes: + golden[byte_idx] |= (1 << bit_idx) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} out_dtype={out_dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmps/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/launch.cpp new file mode 100644 index 000000000..ee145de17 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/launch.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for comparison (must match gen_data.py SCALAR) +static constexpr float TCMP_SCALAR_F32 = 5.0f; +static constexpr int32_t TCMP_SCALAR_I32 = 5; + +// Case 0: f32 1x64 +extern "C" __global__ AICORE void TCMP_f32_1x64(__gm__ float *src, __gm__ uint8_t *dst, float scalar); + +void LaunchTCMP_f32_1x64(float *src, uint8_t *dst, void *stream) { + TCMP_f32_1x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F32); +} + +// Case 1: f32 4x64 +extern "C" __global__ AICORE void TCMP_f32_4x64(__gm__ float *src, __gm__ uint8_t *dst, float scalar); + +void LaunchTCMP_f32_4x64(float *src, uint8_t *dst, void *stream) { + TCMP_f32_4x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F32); +} + +// Case 2: f32 8x64 +extern "C" __global__ AICORE void TCMP_f32_8x64(__gm__ float *src, __gm__ uint8_t *dst, float scalar); + +void LaunchTCMP_f32_8x64(float *src, uint8_t *dst, void *stream) { + TCMP_f32_8x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F32); +} + +// Case 3: f32 32x64 +extern "C" __global__ AICORE void TCMP_f32_32x64(__gm__ float *src, __gm__ uint8_t *dst, float scalar); + +void LaunchTCMP_f32_32x64(float *src, uint8_t *dst, void *stream) { + TCMP_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F32); +} + +// Case 4: f32 128x128 +extern "C" __global__ AICORE void TCMP_f32_128x128(__gm__ float *src, __gm__ uint8_t *dst, float scalar); + +void LaunchTCMP_f32_128x128(float *src, uint8_t *dst, void *stream) { + TCMP_f32_128x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F32); +} + +// Case 5: i32 16x32 +extern "C" __global__ AICORE void TCMP_i32_16x32(__gm__ int32_t *src, __gm__ uint8_t *dst, int32_t scalar); + +void LaunchTCMP_i32_16x32(int32_t *src, uint8_t *dst, void *stream) { + TCMP_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_I32); +} + +// Case 6: i32 32x32 +extern "C" __global__ AICORE void TCMP_i32_32x32(__gm__ int32_t *src, __gm__ uint8_t *dst, int32_t scalar); + +void LaunchTCMP_i32_32x32(int32_t *src, uint8_t *dst, void *stream) { + TCMP_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_I32); +} + +// Case 7: i32 64x64 tile with valid 32x64 +extern "C" __global__ AICORE void TCMP_i32_32x64_valid32x64(__gm__ int32_t *src, __gm__ uint8_t *dst, int32_t scalar); + +void LaunchTCMP_i32_32x64_valid32x64(int32_t *src, uint8_t *dst, void *stream) { + TCMP_i32_32x64_valid32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_I32); +} + +// Case 8: f32 7x448 +extern "C" __global__ AICORE void TCMP_f32_7x448(__gm__ float *src, __gm__ uint8_t *dst, float scalar); + +void LaunchTCMP_f32_7x448(float *src, uint8_t *dst, void *stream) { + TCMP_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F32); +} + +// Case 9: f32 256x16 +extern "C" __global__ AICORE void TCMP_f32_256x16(__gm__ float *src, __gm__ uint8_t *dst, float scalar); + +void LaunchTCMP_f32_256x16(float *src, uint8_t *dst, void *stream) { + TCMP_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F32); +} + +// Case 10: i32 31x128 +extern "C" __global__ AICORE void TCMP_i32_31x128(__gm__ int32_t *src, __gm__ uint8_t *dst, int32_t scalar); + +void LaunchTCMP_i32_31x128(int32_t *src, uint8_t *dst, void *stream) { + TCMP_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_I32); +} + +// Case 11: f16 32x128 +static constexpr uint16_t TCMP_SCALAR_F16 = 0x4500; // 5.0 in half precision + +extern "C" __global__ AICORE void TCMP_f16_32x128(__gm__ uint16_t *src, __gm__ uint8_t *dst, uint16_t scalar); + +void LaunchTCMP_f16_32x128(uint16_t *src, uint8_t *dst, void *stream) { + TCMP_f16_32x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_F16); +} + +// Case 12: i16 32x128 +static constexpr int16_t TCMP_SCALAR_I16 = 5; + +extern "C" __global__ AICORE void TCMP_i16_32x128(__gm__ int16_t *src, __gm__ uint8_t *dst, int16_t scalar); + +void LaunchTCMP_i16_32x128(int16_t *src, uint8_t *dst, void *stream) { + TCMP_i16_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst, TCMP_SCALAR_I16); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmps/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/main.cpp new file mode 100644 index 000000000..481be2990 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/main.cpp @@ -0,0 +1,156 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcmps ST — case-table driven. +// tcmps: dst = packed mask of (src < scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCMP_f32_1x64(float *src, uint8_t *dst, void *stream); +void LaunchTCMP_f32_4x64(float *src, uint8_t *dst, void *stream); +void LaunchTCMP_f32_8x64(float *src, uint8_t *dst, void *stream); +void LaunchTCMP_f32_32x64(float *src, uint8_t *dst, void *stream); +void LaunchTCMP_f32_128x128(float *src, uint8_t *dst, void *stream); +void LaunchTCMP_i32_16x32(int32_t *src, uint8_t *dst, void *stream); +void LaunchTCMP_i32_32x32(int32_t *src, uint8_t *dst, void *stream); +void LaunchTCMP_i32_32x64_valid32x64(int32_t *src, uint8_t *dst, void *stream); +void LaunchTCMP_f32_7x448(float *src, uint8_t *dst, void *stream); +void LaunchTCMP_f32_256x16(float *src, uint8_t *dst, void *stream); +void LaunchTCMP_i32_31x128(int32_t *src, uint8_t *dst, void *stream); +void LaunchTCMP_f16_32x128(uint16_t *src, uint8_t *dst, void *stream); +void LaunchTCMP_i16_32x128(int16_t *src, uint8_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t srcElemSize; // bytes per source element + size_t dstElemSize; // bytes per destination element +}; + +static const TestCase kCases[] = { + {"f32_1x64", (void (*)(void*,void*,void*))LaunchTCMP_f32_1x64, 1, 64, 1, 64, sizeof(float), sizeof(uint8_t)}, + {"f32_4x64", (void (*)(void*,void*,void*))LaunchTCMP_f32_4x64, 4, 64, 4, 64, sizeof(float), sizeof(uint8_t)}, + {"f32_8x64", (void (*)(void*,void*,void*))LaunchTCMP_f32_8x64, 8, 64, 8, 64, sizeof(float), sizeof(uint8_t)}, + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTCMP_f32_32x64, 32, 64, 32, 64, sizeof(float), sizeof(uint8_t)}, + {"f32_128x128", (void (*)(void*,void*,void*))LaunchTCMP_f32_128x128, 128, 128, 128, 128, sizeof(float), sizeof(uint8_t)}, + {"i32_16x32", (void (*)(void*,void*,void*))LaunchTCMP_i32_16x32, 16, 32, 16, 32, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_32x32", (void (*)(void*,void*,void*))LaunchTCMP_i32_32x32, 32, 32, 32, 32, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_32x64_valid32x64", (void (*)(void*,void*,void*))LaunchTCMP_i32_32x64_valid32x64, 64, 64, 32, 64, sizeof(int32_t), sizeof(uint8_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTCMP_f32_7x448, 7, 448, 7, 448, sizeof(float), sizeof(uint8_t)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTCMP_f32_256x16, 256, 16, 256, 16, sizeof(float), sizeof(uint8_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTCMP_i32_31x128, 31, 128, 31, 128, sizeof(int32_t), sizeof(uint8_t)}, + {"f16_32x128", (void (*)(void*,void*,void*))LaunchTCMP_f16_32x128, 32, 128, 32, 128, sizeof(uint16_t), sizeof(uint8_t)}, + {"i16_32x128", (void (*)(void*,void*,void*))LaunchTCMP_i16_32x128, 32, 128, 32, 128, sizeof(int16_t), sizeof(uint8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t dstElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.srcElemSize; + const size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t inputFileSize = srcFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), inputFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tcmps [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcmps/tcmps.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/tcmps.pto new file mode 100644 index 000000000..6a5db5f85 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcmps/tcmps.pto @@ -0,0 +1,534 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcmps: tload(src) + tcmps(src, scalar {cmpMode=lt})->dst + tstore(dst). +// Packed mask of (src < scalar), output stored as packed predicate mask (uint8). +// Supports 32-bit source types: f32, i32. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 1x64 (64 elements) + func.func @TCMP_f32_1x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c64i = arith.constant 64 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xui8> -> !pto.partition_tensor_view<1x1x1x1x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x64xui8>) + return + } + + // Case 1: f32 4x64 (256 elements) + func.func @TCMP_f32_4x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c64], + strides = [%c256, %c256, %c256, %c64, %c1] + : !pto.tensor_view<1x1x1x4x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c64], + strides = [%c256, %c256, %c256, %c64, %c1] + : !pto.tensor_view<1x1x1x4x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c64] + : !pto.tensor_view<1x1x1x4x64xf32> -> !pto.partition_tensor_view<1x1x1x4x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c64] + : !pto.tensor_view<1x1x1x4x64xui8> -> !pto.partition_tensor_view<1x1x1x4x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x64xf32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x64xui8>) + return + } + + // Case 2: f32 8x64 (512 elements) + func.func @TCMP_f32_8x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c64], + strides = [%c512, %c512, %c512, %c64, %c1] + : !pto.tensor_view<1x1x1x8x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c64] + : !pto.tensor_view<1x1x1x8x64xui8> -> !pto.partition_tensor_view<1x1x1x8x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x64xui8>) + return + } + + // Case 3: f32 32x64 (2048 elements) + func.func @TCMP_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xui8> -> !pto.partition_tensor_view<1x1x1x32x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xui8>) + return + } + + // Case 4: f32 128x128 (16384 elements) + func.func @TCMP_f32_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c16384 = arith.constant 16384 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xui8> -> !pto.partition_tensor_view<1x1x1x128x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x128xui8>) + return + } + + // Case 5: i32 16x32 (512 elements) + func.func @TCMP_i32_16x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xui8> -> !pto.partition_tensor_view<1x1x1x16x32xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xui8>) + return + } + + // Case 6: i32 32x32 (1024 elements) + func.func @TCMP_i32_32x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xui8> -> !pto.partition_tensor_view<1x1x1x32x32xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xui8>) + return + } + + // Case 7: i32 64x64 tile with valid 32x64 (4096 elements, 2048 valid) + func.func @TCMP_i32_32x64_valid32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x64x64xui8> -> !pto.partition_tensor_view<1x1x1x32x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xui8>) + return + } + + // Case 8: f32 7x448 (3136 elements) + func.func @TCMP_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xui8> -> !pto.partition_tensor_view<1x1x1x7x448xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xui8>) + return + } + + // Case 9: f32 256x16 (4096 elements) + func.func @TCMP_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xui8> -> !pto.partition_tensor_view<1x1x1x256x16xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xui8>) + return + } + + // Case 10: i32 31x128 (3968 elements) + func.func @TCMP_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xui8> -> !pto.partition_tensor_view<1x1x1x31x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xui8>) + return + } + + // Case 11: f16 32x128 (4096 elements) + func.func @TCMP_f16_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xf16> -> !pto.partition_tensor_view<1x1x1x32x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xui8> -> !pto.partition_tensor_view<1x1x1x32x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xf16>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x128xui8>) + return + } + + // Case 12: i16 32x128 (4096 elements) + func.func @TCMP_i16_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xui8> -> !pto.partition_tensor_view<1x1x1x32x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%src : !pto.tile_buf) + pto.tcmps ins(%src, %scalar {cmpMode = #pto} : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x128xui8>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/CMakeLists.txt new file mode 100644 index 000000000..142f38980 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolargmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/cases.py new file mode 100644 index 000000000..11addaa2f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/cases.py @@ -0,0 +1,210 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolargmax ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype for input data (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input (src and tmp). + - valid_shape: (valid_rows, valid_cols) — effective computation region for input (src and tmp). + - dst_shape: (1, cols) — allocated tile dimensions for output (indices). + - dst_valid_shape: (1, valid_cols) — effective computation region for output (indices). + - dst_dtype: numpy dtype for output indices (np.int32). + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui8_1x256", + "dtype": np.uint8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui8_16x128", + "dtype": np.uint8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui8_16x256", + "dtype": np.uint8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/compare.py new file mode 100644 index 000000000..5be940218 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dst_dtype = case["dst_dtype"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dst_dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dst_dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/gen_data.py new file mode 100644 index 000000000..8041f9aac --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/gen_data.py @@ -0,0 +1,36 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dst_dtype = case["dst_dtype"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dst_dtype) + golden_result = np.argmax(input1[:vr, :vc], axis=0, keepdims=True).astype(dst_dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/launch.cpp new file mode 100644 index 000000000..8fddc4a1e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/launch.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_f32_1x256(__gm__ int32_t *dst, __gm__ float *tmp, __gm__ float *src); + +void LaunchTCOLARGMAX_f32_1x256(int32_t *dst, float *tmp, float *src, void *stream) { + TCOLARGMAX_f32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ float *)tmp, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMAX_f32_16x128(__gm__ int32_t *dst, __gm__ float *tmp, __gm__ float *src); + +void LaunchTCOLARGMAX_f32_16x128(int32_t *dst, float *tmp, float *src, void *stream) { + TCOLARGMAX_f32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ float *)tmp, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_f32_16x256(__gm__ int32_t *dst, __gm__ float *tmp, __gm__ float *src); + +void LaunchTCOLARGMAX_f32_16x256(int32_t *dst, float *tmp, float *src, void *stream) { + TCOLARGMAX_f32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ float *)tmp, (__gm__ float *)src); +} + +// Case 3: f16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_f16_1x256(__gm__ int32_t *dst, __gm__ half *tmp, __gm__ half *src); + +void LaunchTCOLARGMAX_f16_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_f16_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ half *)tmp, (__gm__ half *)src); +} + +// Case 4: f16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMAX_f16_16x128(__gm__ int32_t *dst, __gm__ half *tmp, __gm__ half *src); + +void LaunchTCOLARGMAX_f16_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_f16_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ half *)tmp, (__gm__ half *)src); +} + +// Case 5: f16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_f16_16x256(__gm__ int32_t *dst, __gm__ half *tmp, __gm__ half *src); + +void LaunchTCOLARGMAX_f16_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_f16_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ half *)tmp, (__gm__ half *)src); +} + +// Case 6: ui32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui32_1x256(__gm__ int32_t *dst, __gm__ uint32_t *tmp, __gm__ uint32_t *src); + +void LaunchTCOLARGMAX_ui32_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint32_t *)tmp, (__gm__ uint32_t *)src); +} + +// Case 7: ui32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui32_16x128(__gm__ int32_t *dst, __gm__ uint32_t *tmp, __gm__ uint32_t *src); + +void LaunchTCOLARGMAX_ui32_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint32_t *)tmp, (__gm__ uint32_t *)src); +} + +// Case 8: ui32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui32_16x256(__gm__ int32_t *dst, __gm__ uint32_t *tmp, __gm__ uint32_t *src); + +void LaunchTCOLARGMAX_ui32_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint32_t *)tmp, (__gm__ uint32_t *)src); +} + +// Case 9: ui16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui16_1x256(__gm__ int32_t *dst, __gm__ uint16_t *tmp, __gm__ uint16_t *src); + +void LaunchTCOLARGMAX_ui16_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui16_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint16_t *)tmp, (__gm__ uint16_t *)src); +} + +// Case 10: ui16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui16_16x128(__gm__ int32_t *dst, __gm__ uint16_t *tmp, __gm__ uint16_t *src); + +void LaunchTCOLARGMAX_ui16_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui16_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint16_t *)tmp, (__gm__ uint16_t *)src); +} + +// Case 11: ui16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui16_16x256(__gm__ int32_t *dst, __gm__ uint16_t *tmp, __gm__ uint16_t *src); + +void LaunchTCOLARGMAX_ui16_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui16_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint16_t *)tmp, (__gm__ uint16_t *)src); +} + +// Case 12: ui8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui8_1x256(__gm__ int32_t *dst, __gm__ uint8_t *tmp, __gm__ uint8_t *src); + +void LaunchTCOLARGMAX_ui8_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui8_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint8_t *)tmp, (__gm__ uint8_t *)src); +} + +// Case 13: ui8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui8_16x128(__gm__ int32_t *dst, __gm__ uint8_t *tmp, __gm__ uint8_t *src); + +void LaunchTCOLARGMAX_ui8_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui8_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint8_t *)tmp, (__gm__ uint8_t *)src); +} + +// Case 14: ui8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_ui8_16x256(__gm__ int32_t *dst, __gm__ uint8_t *tmp, __gm__ uint8_t *src); + +void LaunchTCOLARGMAX_ui8_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_ui8_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint8_t *)tmp, (__gm__ uint8_t *)src); +} + +// Case 15: i8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_i8_1x256(__gm__ int32_t *dst, __gm__ int8_t *tmp, __gm__ int8_t *src); + +void LaunchTCOLARGMAX_i8_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_i8_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int8_t *)tmp, (__gm__ int8_t *)src); +} + +// Case 16: i8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMAX_i8_16x128(__gm__ int32_t *dst, __gm__ int8_t *tmp, __gm__ int8_t *src); + +void LaunchTCOLARGMAX_i8_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_i8_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int8_t *)tmp, (__gm__ int8_t *)src); +} + +// Case 17: i8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMAX_i8_16x256(__gm__ int32_t *dst, __gm__ int8_t *tmp, __gm__ int8_t *src); + +void LaunchTCOLARGMAX_i8_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMAX_i8_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int8_t *)tmp, (__gm__ int8_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/main.cpp new file mode 100644 index 000000000..18712044c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/main.cpp @@ -0,0 +1,195 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolargmax ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLARGMAX_f32_1x256(int32_t *dst, float *tmp, float *src, void *stream); +void LaunchTCOLARGMAX_f32_16x128(int32_t *dst, float *tmp, float *src, void *stream); +void LaunchTCOLARGMAX_f32_16x256(int32_t *dst, float *tmp, float *src, void *stream); +void LaunchTCOLARGMAX_f16_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_f16_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_f16_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui32_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui32_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui32_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui16_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui16_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui16_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui8_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui8_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_ui8_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_i8_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_i8_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMAX_i8_16x256(void *dst, void *tmp, void *src, void *stream); + +using LaunchFnFloat = void (*)(int32_t *, float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t tmpRows; + size_t tmpCols; + size_t tmpValidRows; + size_t tmpValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t srcElemSize; + size_t dstElemSize; + bool isFp16; + bool isUi32; + bool isUi16; + bool isUi8; + bool isI8; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLARGMAX_f32_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(float), sizeof(int32_t), false, false, false, false, false}, + {"f32_16x128", (void*)LaunchTCOLARGMAX_f32_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(float), sizeof(int32_t), false, false, false, false, false}, + {"f32_16x256", (void*)LaunchTCOLARGMAX_f32_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(float), sizeof(int32_t), false, false, false, false, false}, + {"f16_1x256", (void*)LaunchTCOLARGMAX_f16_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, 2, sizeof(int32_t), true, false, false, false, false}, + {"f16_16x128", (void*)LaunchTCOLARGMAX_f16_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, 2, sizeof(int32_t), true, false, false, false, false}, + {"f16_16x256", (void*)LaunchTCOLARGMAX_f16_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, 2, sizeof(int32_t), true, false, false, false, false}, + {"ui32_1x256", (void*)LaunchTCOLARGMAX_ui32_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(uint32_t), sizeof(int32_t), false, true, false, false, false}, + {"ui32_16x128", (void*)LaunchTCOLARGMAX_ui32_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(uint32_t), sizeof(int32_t), false, true, false, false, false}, + {"ui32_16x256", (void*)LaunchTCOLARGMAX_ui32_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(uint32_t), sizeof(int32_t), false, true, false, false, false}, + {"ui16_1x256", (void*)LaunchTCOLARGMAX_ui16_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(uint16_t), sizeof(int32_t), false, false, true, false, false}, + {"ui16_16x128", (void*)LaunchTCOLARGMAX_ui16_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(uint16_t), sizeof(int32_t), false, false, true, false, false}, + {"ui16_16x256", (void*)LaunchTCOLARGMAX_ui16_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(uint16_t), sizeof(int32_t), false, false, true, false, false}, + {"ui8_1x256", (void*)LaunchTCOLARGMAX_ui8_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(uint8_t), sizeof(int32_t), false, false, false, true, false}, + {"ui8_16x128", (void*)LaunchTCOLARGMAX_ui8_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(uint8_t), sizeof(int32_t), false, false, false, true, false}, + {"ui8_16x256", (void*)LaunchTCOLARGMAX_ui8_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(uint8_t), sizeof(int32_t), false, false, false, true, false}, + {"i8_1x256", (void*)LaunchTCOLARGMAX_i8_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(int8_t), sizeof(int32_t), false, false, false, false, true}, + {"i8_16x128", (void*)LaunchTCOLARGMAX_i8_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(int8_t), sizeof(int32_t), false, false, false, false, true}, + {"i8_16x256", (void*)LaunchTCOLARGMAX_i8_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(int8_t), sizeof(int32_t), false, false, false, false, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.srcElemSize; + const size_t tmpElemCount = tc.tmpRows * tc.tmpCols; + const size_t tmpFileSize = tmpElemCount * tc.srcElemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, tmp=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.tmpRows, tc.tmpCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t tmpFileSizeVar = tmpFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *tmpHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *tmpDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&tmpHost, tmpFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&tmpDevice, tmpFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16 || tc.isUi32 || tc.isUi16 || tc.isUi8 || tc.isI8) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, tmpDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((int32_t*)dstDevice, (float*)tmpDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (tmpDevice != nullptr) + aclrtFree(tmpDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (tmpHost != nullptr) + aclrtFreeHost(tmpHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/tcolargmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/tcolargmax.pto new file mode 100644 index 000000000..027aefe09 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmax/tcolargmax.pto @@ -0,0 +1,926 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolargmax: tload(src) + tcolargmax(src, tmp, dst) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { +// Case 0: f32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMAX_f32_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMAX_f32_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMAX_f32_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 3: f16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMAX_f16_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 4: f16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMAX_f16_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 5: f16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMAX_f16_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 6: ui32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMAX_ui32_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 7: ui32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMAX_ui32_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 8: ui32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMAX_ui32_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 9: ui16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMAX_ui16_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 10: ui16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMAX_ui16_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 11: ui16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMAX_ui16_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 12: ui8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMAX_ui8_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 13: ui8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMAX_ui8_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 14: ui8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMAX_ui8_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 15: i8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMAX_i8_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 16: i8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMAX_i8_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 17: i8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMAX_i8_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/CMakeLists.txt new file mode 100644 index 000000000..260ee1ded --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolargmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/cases.py new file mode 100644 index 000000000..e5847c346 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/cases.py @@ -0,0 +1,210 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolargmin ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype for input data (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input (src and tmp). + - valid_shape: (valid_rows, valid_cols) — effective computation region for input (src and tmp). + - dst_shape: (1, cols) — allocated tile dimensions for output (indices). + - dst_valid_shape: (1, valid_cols) — effective computation region for output (indices). + - dst_dtype: numpy dtype for output indices (np.int32). + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui8_1x256", + "dtype": np.uint8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui8_16x128", + "dtype": np.uint8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "ui8_16x256", + "dtype": np.uint8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "dst_dtype": np.int32, + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "dst_dtype": np.int32, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/compare.py new file mode 100644 index 000000000..5be940218 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dst_dtype = case["dst_dtype"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dst_dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dst_dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/gen_data.py new file mode 100644 index 000000000..2b38a6372 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/gen_data.py @@ -0,0 +1,36 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dst_dtype = case["dst_dtype"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dst_dtype) + golden_result = np.argmin(input1[:vr, :vc], axis=0, keepdims=True).astype(dst_dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/launch.cpp new file mode 100644 index 000000000..b204c98cf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/launch.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_f32_1x256(__gm__ int32_t *dst, __gm__ float *tmp, __gm__ float *src); + +void LaunchTCOLARGMIN_f32_1x256(int32_t *dst, float *tmp, float *src, void *stream) { + TCOLARGMIN_f32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ float *)tmp, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMIN_f32_16x128(__gm__ int32_t *dst, __gm__ float *tmp, __gm__ float *src); + +void LaunchTCOLARGMIN_f32_16x128(int32_t *dst, float *tmp, float *src, void *stream) { + TCOLARGMIN_f32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ float *)tmp, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_f32_16x256(__gm__ int32_t *dst, __gm__ float *tmp, __gm__ float *src); + +void LaunchTCOLARGMIN_f32_16x256(int32_t *dst, float *tmp, float *src, void *stream) { + TCOLARGMIN_f32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ float *)tmp, (__gm__ float *)src); +} + +// Case 3: f16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_f16_1x256(__gm__ int32_t *dst, __gm__ half *tmp, __gm__ half *src); + +void LaunchTCOLARGMIN_f16_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_f16_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ half *)tmp, (__gm__ half *)src); +} + +// Case 4: f16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMIN_f16_16x128(__gm__ int32_t *dst, __gm__ half *tmp, __gm__ half *src); + +void LaunchTCOLARGMIN_f16_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_f16_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ half *)tmp, (__gm__ half *)src); +} + +// Case 5: f16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_f16_16x256(__gm__ int32_t *dst, __gm__ half *tmp, __gm__ half *src); + +void LaunchTCOLARGMIN_f16_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_f16_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ half *)tmp, (__gm__ half *)src); +} + +// Case 6: ui32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui32_1x256(__gm__ int32_t *dst, __gm__ uint32_t *tmp, __gm__ uint32_t *src); + +void LaunchTCOLARGMIN_ui32_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint32_t *)tmp, (__gm__ uint32_t *)src); +} + +// Case 7: ui32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui32_16x128(__gm__ int32_t *dst, __gm__ uint32_t *tmp, __gm__ uint32_t *src); + +void LaunchTCOLARGMIN_ui32_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint32_t *)tmp, (__gm__ uint32_t *)src); +} + +// Case 8: ui32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui32_16x256(__gm__ int32_t *dst, __gm__ uint32_t *tmp, __gm__ uint32_t *src); + +void LaunchTCOLARGMIN_ui32_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint32_t *)tmp, (__gm__ uint32_t *)src); +} + +// Case 9: ui16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui16_1x256(__gm__ int32_t *dst, __gm__ uint16_t *tmp, __gm__ uint16_t *src); + +void LaunchTCOLARGMIN_ui16_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui16_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint16_t *)tmp, (__gm__ uint16_t *)src); +} + +// Case 10: ui16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui16_16x128(__gm__ int32_t *dst, __gm__ uint16_t *tmp, __gm__ uint16_t *src); + +void LaunchTCOLARGMIN_ui16_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui16_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint16_t *)tmp, (__gm__ uint16_t *)src); +} + +// Case 11: ui16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui16_16x256(__gm__ int32_t *dst, __gm__ uint16_t *tmp, __gm__ uint16_t *src); + +void LaunchTCOLARGMIN_ui16_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui16_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint16_t *)tmp, (__gm__ uint16_t *)src); +} + +// Case 12: ui8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui8_1x256(__gm__ int32_t *dst, __gm__ uint8_t *tmp, __gm__ uint8_t *src); + +void LaunchTCOLARGMIN_ui8_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui8_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint8_t *)tmp, (__gm__ uint8_t *)src); +} + +// Case 13: ui8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui8_16x128(__gm__ int32_t *dst, __gm__ uint8_t *tmp, __gm__ uint8_t *src); + +void LaunchTCOLARGMIN_ui8_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui8_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint8_t *)tmp, (__gm__ uint8_t *)src); +} + +// Case 14: ui8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_ui8_16x256(__gm__ int32_t *dst, __gm__ uint8_t *tmp, __gm__ uint8_t *src); + +void LaunchTCOLARGMIN_ui8_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_ui8_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ uint8_t *)tmp, (__gm__ uint8_t *)src); +} + +// Case 15: i8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_i8_1x256(__gm__ int32_t *dst, __gm__ int8_t *tmp, __gm__ int8_t *src); + +void LaunchTCOLARGMIN_i8_1x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_i8_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int8_t *)tmp, (__gm__ int8_t *)src); +} + +// Case 16: i8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) +extern "C" __global__ AICORE void TCOLARGMIN_i8_16x128(__gm__ int32_t *dst, __gm__ int8_t *tmp, __gm__ int8_t *src); + +void LaunchTCOLARGMIN_i8_16x128(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_i8_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int8_t *)tmp, (__gm__ int8_t *)src); +} + +// Case 17: i8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) +extern "C" __global__ AICORE void TCOLARGMIN_i8_16x256(__gm__ int32_t *dst, __gm__ int8_t *tmp, __gm__ int8_t *src); + +void LaunchTCOLARGMIN_i8_16x256(void *dst, void *tmp, void *src, void *stream) { + TCOLARGMIN_i8_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int8_t *)tmp, (__gm__ int8_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/main.cpp new file mode 100644 index 000000000..2347e7717 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/main.cpp @@ -0,0 +1,195 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolargmin ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLARGMIN_f32_1x256(int32_t *dst, float *tmp, float *src, void *stream); +void LaunchTCOLARGMIN_f32_16x128(int32_t *dst, float *tmp, float *src, void *stream); +void LaunchTCOLARGMIN_f32_16x256(int32_t *dst, float *tmp, float *src, void *stream); +void LaunchTCOLARGMIN_f16_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_f16_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_f16_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui32_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui32_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui32_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui16_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui16_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui16_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui8_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui8_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_ui8_16x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_i8_1x256(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_i8_16x128(void *dst, void *tmp, void *src, void *stream); +void LaunchTCOLARGMIN_i8_16x256(void *dst, void *tmp, void *src, void *stream); + +using LaunchFnFloat = void (*)(int32_t *, float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t tmpRows; + size_t tmpCols; + size_t tmpValidRows; + size_t tmpValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t srcElemSize; + size_t dstElemSize; + bool isFp16; + bool isUi32; + bool isUi16; + bool isUi8; + bool isI8; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLARGMIN_f32_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(float), sizeof(int32_t), false, false, false, false, false}, + {"f32_16x128", (void*)LaunchTCOLARGMIN_f32_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(float), sizeof(int32_t), false, false, false, false, false}, + {"f32_16x256", (void*)LaunchTCOLARGMIN_f32_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(float), sizeof(int32_t), false, false, false, false, false}, + {"f16_1x256", (void*)LaunchTCOLARGMIN_f16_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, 2, sizeof(int32_t), true, false, false, false, false}, + {"f16_16x128", (void*)LaunchTCOLARGMIN_f16_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, 2, sizeof(int32_t), true, false, false, false, false}, + {"f16_16x256", (void*)LaunchTCOLARGMIN_f16_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, 2, sizeof(int32_t), true, false, false, false, false}, + {"ui32_1x256", (void*)LaunchTCOLARGMIN_ui32_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(uint32_t), sizeof(int32_t), false, true, false, false, false}, + {"ui32_16x128", (void*)LaunchTCOLARGMIN_ui32_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(uint32_t), sizeof(int32_t), false, true, false, false, false}, + {"ui32_16x256", (void*)LaunchTCOLARGMIN_ui32_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(uint32_t), sizeof(int32_t), false, true, false, false, false}, + {"ui16_1x256", (void*)LaunchTCOLARGMIN_ui16_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(uint16_t), sizeof(int32_t), false, false, true, false, false}, + {"ui16_16x128", (void*)LaunchTCOLARGMIN_ui16_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(uint16_t), sizeof(int32_t), false, false, true, false, false}, + {"ui16_16x256", (void*)LaunchTCOLARGMIN_ui16_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(uint16_t), sizeof(int32_t), false, false, true, false, false}, + {"ui8_1x256", (void*)LaunchTCOLARGMIN_ui8_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(uint8_t), sizeof(int32_t), false, false, false, true, false}, + {"ui8_16x128", (void*)LaunchTCOLARGMIN_ui8_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(uint8_t), sizeof(int32_t), false, false, false, true, false}, + {"ui8_16x256", (void*)LaunchTCOLARGMIN_ui8_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(uint8_t), sizeof(int32_t), false, false, false, true, false}, + {"i8_1x256", (void*)LaunchTCOLARGMIN_i8_1x256, 1, 256, 1, 255, 1, 256, 1, 255, 1, 256, 255, sizeof(int8_t), sizeof(int32_t), false, false, false, false, true}, + {"i8_16x128", (void*)LaunchTCOLARGMIN_i8_16x128, 16, 128, 16, 127, 16, 128, 16, 127, 1, 128, 127, sizeof(int8_t), sizeof(int32_t), false, false, false, false, true}, + {"i8_16x256", (void*)LaunchTCOLARGMIN_i8_16x256, 16, 256, 15, 255, 16, 256, 15, 255, 1, 256, 255, sizeof(int8_t), sizeof(int32_t), false, false, false, false, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.srcElemSize; + const size_t tmpElemCount = tc.tmpRows * tc.tmpCols; + const size_t tmpFileSize = tmpElemCount * tc.srcElemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, tmp=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.tmpRows, tc.tmpCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t tmpFileSizeVar = tmpFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *tmpHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *tmpDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&tmpHost, tmpFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&tmpDevice, tmpFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16 || tc.isUi32 || tc.isUi16 || tc.isUi8 || tc.isI8) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, tmpDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((int32_t*)dstDevice, (float*)tmpDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (tmpDevice != nullptr) + aclrtFree(tmpDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (tmpHost != nullptr) + aclrtFreeHost(tmpHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/tcolargmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/tcolargmin.pto new file mode 100644 index 000000000..8e76e0af1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolargmin/tcolargmin.pto @@ -0,0 +1,926 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolargmin: tload(src) + tcolargmin(src, tmp, dst) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { +// Case 0: f32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMIN_f32_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMIN_f32_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMIN_f32_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 3: f16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMIN_f16_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 4: f16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMIN_f16_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 5: f16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMIN_f16_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 6: ui32 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMIN_ui32_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 7: ui32 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMIN_ui32_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 8: ui32 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMIN_ui32_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 9: ui16 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMIN_ui16_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 10: ui16 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMIN_ui16_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 11: ui16 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMIN_ui16_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 12: ui8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMIN_ui8_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 13: ui8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMIN_ui8_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 14: ui8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMIN_ui8_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 15: i8 1x256 (input: 1x256, tmp: 1x256, output: 1x256 indices) + func.func @TCOLARGMIN_i8_1x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 16: i8 16x128 (input: 16x128, tmp: 16x128, output: 1x128 indices) + func.func @TCOLARGMIN_i8_16x128(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 17: i8 16x256 (input: 16x256, tmp: 16x256, output: 1x256 indices) + func.func @TCOLARGMIN_i8_16x256(%dst_ptr: !pto.ptr, %tmp_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %tmp_view = pto.make_tensor_view %tmp_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %tmp_part = pto.partition_view %tmp_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/CMakeLists.txt new file mode 100644 index 000000000..f0c992931 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/cases.py new file mode 100644 index 000000000..2ea12b560 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/cases.py @@ -0,0 +1,75 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpand ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpand/ + +TCOLEXPAND: expand src first row to dst all rows by broadcasting. + - src_shape: (src_row, cols) - input tile (only first row is used for broadcast) + - dst_shape: (dst_row, cols) - expanded output + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + +Case naming: {dtype}_{src_row}_{dst_row}_{cols}_{valid_col} +""" + +import numpy as np + +CASES = [ + { + "name": "half_1_16_512_512", + "dtype": np.float16, + "src_shape": (1, 512), + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "int8_2_32_256_255", + "dtype": np.int8, + "src_shape": (2, 256), + "shape": (32, 256), + "valid_shape": (32, 255), + "eps": 0, + }, + { + "name": "float_1_8_128_63", + "dtype": np.float32, + "src_shape": (1, 128), + "shape": (8, 128), + "valid_shape": (8, 63), + "eps": 1e-6, + }, + { + "name": "half_1_33_512_512", + "dtype": np.float16, + "src_shape": (1, 512), + "shape": (33, 512), + "valid_shape": (33, 512), + "eps": 1e-3, + }, + { + "name": "int8_2_17_256_44", + "dtype": np.int8, + "src_shape": (2, 256), + "shape": (17, 256), + "valid_shape": (17, 44), + "eps": 0, + }, + { + "name": "float_1_54_64_63", + "dtype": np.float32, + "src_shape": (1, 64), + "shape": (54, 64), + "valid_shape": (54, 63), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/gen_data.py new file mode 100644 index 000000000..7f727226d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/gen_data.py @@ -0,0 +1,34 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src_shape = case["src_shape"] + dst_shape = case["shape"] + valid_shape = case["valid_shape"] + + src = np.random.randint(1, 10, size=src_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + valid_row, valid_col = valid_shape + for i in range(valid_row): + golden[i, :valid_col] = src[0, :valid_col] + + save_case_data(case["name"], {"input0": src, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src={src_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/launch.cpp new file mode 100644 index 000000000..c22f39010 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: half_1_16_512_512 +extern "C" __global__ AICORE void TCOLEXPAND_half_1_16_512_512(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTCOLEXPAND_half_1_16_512_512(uint16_t *src, uint16_t *dst, void *stream) { + TCOLEXPAND_half_1_16_512_512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// Case 2: int8_2_32_256_255 +extern "C" __global__ AICORE void TCOLEXPAND_int8_2_32_256_255(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTCOLEXPAND_int8_2_32_256_255(int8_t *src, int8_t *dst, void *stream) { + TCOLEXPAND_int8_2_32_256_255<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} + +// Case 3: float_1_8_128_63 +extern "C" __global__ AICORE void TCOLEXPAND_float_1_8_128_63(__gm__ float *src, __gm__ float *dst); + +void LaunchTCOLEXPAND_float_1_8_128_63(float *src, float *dst, void *stream) { + TCOLEXPAND_float_1_8_128_63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case 4: half_1_33_512_512 +extern "C" __global__ AICORE void TCOLEXPAND_half_1_33_512_512(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTCOLEXPAND_half_1_33_512_512(uint16_t *src, uint16_t *dst, void *stream) { + TCOLEXPAND_half_1_33_512_512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// Case 5: int8_2_17_256_44 +extern "C" __global__ AICORE void TCOLEXPAND_int8_2_17_256_44(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTCOLEXPAND_int8_2_17_256_44(int8_t *src, int8_t *dst, void *stream) { + TCOLEXPAND_int8_2_17_256_44<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} + +// Case 6: float_1_54_64_63 +extern "C" __global__ AICORE void TCOLEXPAND_float_1_54_64_63(__gm__ float *src, __gm__ float *dst); + +void LaunchTCOLEXPAND_float_1_54_64_63(float *src, float *dst, void *stream) { + TCOLEXPAND_float_1_54_64_63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/main.cpp new file mode 100644 index 000000000..717b68094 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/main.cpp @@ -0,0 +1,143 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpand ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpand/ +// TCOLEXPAND: expand src first row to dst all rows by broadcasting + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPAND_half_1_16_512_512(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTCOLEXPAND_int8_2_32_256_255(int8_t *src, int8_t *dst, void *stream); +void LaunchTCOLEXPAND_float_1_8_128_63(float *src, float *dst, void *stream); +void LaunchTCOLEXPAND_half_1_33_512_512(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTCOLEXPAND_int8_2_17_256_44(int8_t *src, int8_t *dst, void *stream); +void LaunchTCOLEXPAND_float_1_54_64_63(float *src, float *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; + size_t srcCols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"half_1_16_512_512", (LaunchFn)LaunchTCOLEXPAND_half_1_16_512_512, 1, 512, 16, 512, 16, 512, sizeof(uint16_t)}, + {"int8_2_32_256_255", (LaunchFn)LaunchTCOLEXPAND_int8_2_32_256_255, 2, 256, 32, 256, 32, 255, sizeof(int8_t)}, + {"float_1_8_128_63", (LaunchFn)LaunchTCOLEXPAND_float_1_8_128_63, 1, 128, 8, 128, 8, 63, sizeof(float)}, + {"half_1_33_512_512", (LaunchFn)LaunchTCOLEXPAND_half_1_33_512_512, 1, 512, 33, 512, 33, 512, sizeof(uint16_t)}, + {"int8_2_17_256_44", (LaunchFn)LaunchTCOLEXPAND_int8_2_17_256_44, 2, 256, 17, 256, 17, 44, sizeof(int8_t)}, + {"float_1_54_64_63", (LaunchFn)LaunchTCOLEXPAND_float_1_54_64_63, 1, 64, 54, 64, 54, 63, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrcFileSize = srcFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), srcFileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/tcolexpand.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/tcolexpand.pto new file mode 100644 index 000000000..87ec41ceb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpand/tcolexpand.pto @@ -0,0 +1,291 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpand: expand src (src_row x valid_col) to dst (dst_row x valid_col). +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: half_1_16_512_512 (fp16, valid_col=512, cols=512) + func.func @TCOLEXPAND_half_1_16_512_512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf16> -> !pto.partition_tensor_view<1x1x1x16x512xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x512xf16>) + return + } + + // Case 2: int8_2_32_256_255 (int8, cols=256, valid_col=255) + func.func @TCOLEXPAND_int8_2_32_256_255(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c255 = arith.constant 255 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c256], + strides = [%c512, %c512, %c512, %c256, %c1] + : !pto.tensor_view<1x1x1x2x256xi8> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c255] + : !pto.tensor_view<1x1x1x2x256xi8> -> !pto.partition_tensor_view<1x1x1x2x255xi8> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c255] + : !pto.tensor_view<1x1x1x32x256xi8> -> !pto.partition_tensor_view<1x1x1x32x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x255xi8>) + return + } + + // Case 3: float_1_8_128_63 (float32, cols=128, valid_col=63) + func.func @TCOLEXPAND_float_1_8_128_63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c128], + strides = [%c1024, %c1024, %c1024, %c128, %c1] + : !pto.tensor_view<1x1x1x8x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c63] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x63xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c63] + : !pto.tensor_view<1x1x1x8x128xf32> -> !pto.partition_tensor_view<1x1x1x8x63xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x63xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x63xf32>) + return + } + + // Case 4: half_1_33_512_512 (fp16, cols=512, valid_col=512) + func.func @TCOLEXPAND_half_1_33_512_512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c33 = arith.constant 33 : index + %c512 = arith.constant 512 : index + %c16896 = arith.constant 16896 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c33, %c512], + strides = [%c16896, %c16896, %c16896, %c512, %c1] + : !pto.tensor_view<1x1x1x33x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c33, %c512] + : !pto.tensor_view<1x1x1x33x512xf16> -> !pto.partition_tensor_view<1x1x1x33x512xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x33x512xf16>) + return + } + + // Case 5: int8_2_17_256_44 (int8, cols=256, valid_col=44) + func.func @TCOLEXPAND_int8_2_17_256_44(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c17 = arith.constant 17 : index + %c44 = arith.constant 44 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c4352 = arith.constant 4352 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c256], + strides = [%c512, %c512, %c512, %c256, %c1] + : !pto.tensor_view<1x1x1x2x256xi8> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c17, %c256], + strides = [%c4352, %c4352, %c4352, %c256, %c1] + : !pto.tensor_view<1x1x1x17x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c44] + : !pto.tensor_view<1x1x1x2x256xi8> -> !pto.partition_tensor_view<1x1x1x2x44xi8> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c17, %c44] + : !pto.tensor_view<1x1x1x17x256xi8> -> !pto.partition_tensor_view<1x1x1x17x44xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x44xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x17x44xi8>) + return + } + + // Case 6: float_1_54_64_63 (float32, cols=64, valid_col=63) + func.func @TCOLEXPAND_float_1_54_64_63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c54 = arith.constant 54 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c3456 = arith.constant 3456 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c54, %c64], + strides = [%c3456, %c3456, %c3456, %c64, %c1] + : !pto.tensor_view<1x1x1x54x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c63] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x63xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c54, %c63] + : !pto.tensor_view<1x1x1x54x64xf32> -> !pto.partition_tensor_view<1x1x1x54x63xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x63xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x54x63xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/CMakeLists.txt new file mode 100644 index 000000000..7151caba5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandadd) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/cases.py new file mode 100644 index 000000000..fa496319d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/cases.py @@ -0,0 +1,77 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandadd ST test cases. + +TCOLEXPANDADD: expand src1 then add with src0. + - src0_shape: (dst_row, dst_col) - already expanded (src0_shape = shape) + - src1_shape: (src1_row, src1_col) - to be expanded (usually src1_row=1) + - shape: (dst_row, dst_col) - output shape +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-3, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/gen_data.py new file mode 100644 index 000000000..d8556e3a3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + valid_row, valid_col = valid_shape + src1_row, src1_col = src1_shape + reps = dst_shape[0] // src1_row + expanded_src1 = np.tile(src1, (reps, 1))[:, :valid_col] + golden[:valid_row, :valid_col] = (src0[:valid_row, :valid_col] + expanded_src1[:valid_row, :valid_col]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/launch.cpp new file mode 100644 index 000000000..fe60cd7c2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDADD_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDADD_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDADD_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDADD_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDADD_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDADD_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDADD_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDADD_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDADD_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDADD_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDADD_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDADD_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDADD_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDADD_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDADD_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/main.cpp new file mode 100644 index 000000000..93a9581fd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/main.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandadd ST +// TCOLEXPANDADD: src0 + expand(src1) -> dst + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDADD_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDADD_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDADD_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDADD_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDADD_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDADD_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDADD_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDADD_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDADD_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDADD_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDADD_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDADD_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/tcolexpandadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/tcolexpandadd.pto new file mode 100644 index 000000000..5d941560b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandadd/tcolexpandadd.pto @@ -0,0 +1,386 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandadd: expand src1 and add with src0. +// TCOLEXPANDADD: src0 + expand(src1) -> dst +// - src0: (dst_row, dst_col) +// - src1: (src1_row, src1_col), usually src1_row=1 +// - dst: (dst_row, dst_col) + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDADD_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDADD_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDADD_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDADD_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDADD_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDADD_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandadd ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/CMakeLists.txt new file mode 100644 index 000000000..4fe30b496 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpanddiv) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/cases.py new file mode 100644 index 000000000..c4710cb93 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/cases.py @@ -0,0 +1,115 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpanddiv ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpanddiv/ + +TCOLEXPANDDIV: column-wise broadcast divide - dst[i,j] = src0[i,j] / src1[0,j] + - src0_shape: (src0_row, cols) - dividend input tile + - src1_shape: (1, cols) - divisor input tile (single row, broadcast) + - dst_shape: (dst_row, cols) - output tile + - valid_shape: (valid_row, valid_col) - effective computation region + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{src1_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_32_64_1_64", + "dtype": np.float32, + "src0_shape": (32, 64), + "src1_shape": (1, 64), + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "fp32_8_32_1_32", + "dtype": np.float32, + "src0_shape": (8, 32), + "src1_shape": (1, 32), + "shape": (8, 32), + "valid_shape": (8, 32), + "eps": 1e-6, + }, + { + "name": "fp16_16_64_1_64", + "dtype": np.float16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "fp16_4_128_1_128", + "dtype": np.float16, + "src0_shape": (4, 128), + "src1_shape": (1, 128), + "shape": (4, 128), + "valid_shape": (4, 128), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "fp32_40_32_1_32", + "dtype": np.float32, + "src0_shape": (40, 32), + "src1_shape": (1, 32), + "shape": (40, 32), + "valid_shape": (40, 32), + "eps": 1e-6, + }, + { + "name": "fp16_16_128_1_128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-3, + }, + { + "name": "fp32_20_64_1_64", + "dtype": np.float32, + "src0_shape": (20, 64), + "src1_shape": (1, 64), + "shape": (20, 64), + "valid_shape": (20, 64), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/gen_data.py new file mode 100644 index 000000000..60cad96b9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.uniform(1.0, 10.0, size=src0_shape).astype(dtype) + src1 = np.random.uniform(1.0, 10.0, size=src1_shape).astype(dtype) + + valid_row, valid_col = valid_shape + reps = dst_shape[0] // src1_shape[0] + + golden = np.zeros(dst_shape, dtype=dtype) + expanded_src1 = np.tile(src1, (reps, 1))[:, :valid_col] + golden[:valid_row, :valid_col] = src0[:valid_row, :valid_col] / expanded_src1 + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/launch.cpp new file mode 100644 index 000000000..c50222329 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/launch.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_32_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_32_64_1_64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_32_64_1_64(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_32_64_1_64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_8_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_8_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_8_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_8_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp16_16_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDDIV_fp16_16_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDDIV_fp16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_4_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp16_4_128_1_128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDDIV_fp16_4_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDDIV_fp16_4_128_1_128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDDIV_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDDIV_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDDIV_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDDIV_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDDIV_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Case: fp32_40_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_40_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_40_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_40_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case: fp16_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp16_16_128_1_128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDDIV_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDDIV_fp16_16_128_1_128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case: fp32_20_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDDIV_fp32_20_64_1_64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDDIV_fp32_20_64_1_64(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDDIV_fp32_20_64_1_64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/main.cpp new file mode 100644 index 000000000..48b46c797 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpanddiv ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpanddiv/ +// TCOLEXPANDDIV: column-wise broadcast divide - dst[i,j] = src0[i,j] / src1[0,j] + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDDIV_fp32_32_64_1_64(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp32_8_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp16_16_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp16_4_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp32_40_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDDIV_fp32_20_64_1_64(float *src0, float *src1, float *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_32_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_32_64_1_64, 32, 64, 1, 64, 32, 64, 32, 64, sizeof(float)}, + {"fp32_8_32_1_32", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_8_32_1_32, 8, 32, 1, 32, 8, 32, 8, 32, sizeof(float)}, + {"fp16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_fp16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"fp16_4_128_1_128", (LaunchFn)LaunchTCOLEXPANDDIV_fp16_4_128_1_128, 4, 128, 1, 128, 4, 128, 4, 128, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDDIV_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, + {"fp32_40_32_1_32", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_40_32_1_32, 40, 32, 1, 32, 40, 32, 40, 32, sizeof(float)}, + {"fp16_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDDIV_fp16_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(uint16_t)}, + {"fp32_20_64_1_64", (LaunchFn)LaunchTCOLEXPANDDIV_fp32_20_64_1_64, 20, 64, 1, 64, 20, 64, 20, 64, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, + tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/tcolexpanddiv.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/tcolexpanddiv.pto new file mode 100644 index 000000000..23427cc79 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpanddiv/tcolexpanddiv.pto @@ -0,0 +1,571 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpanddiv: column-wise broadcast divide. +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: fp32_32_64_1_64 (float32, src0=(32,64), src1=(1,64), dst=(32,64)) + func.func @TCOLEXPANDDIV_fp32_32_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 2: fp32_8_32_1_32 (float32, src0=(8,32), src1=(1,32), dst=(8,32)) + func.func @TCOLEXPANDDIV_fp32_8_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + return + } + + // Case 3: fp16_16_64_1_64 (float16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDDIV_fp16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 4: fp16_4_128_1_128 (float16, src0=(4,128), src1=(1,128), dst=(4,128)) + func.func @TCOLEXPANDDIV_fp16_4_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c128] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x128xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c128] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x128xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + +pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x128xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDDIV_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDDIV_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } + + // Case 7: fp32_40_32_1_32 (float32, src0=(40,32), src1=(1,32), dst=(40,32)) + func.func @TCOLEXPANDDIV_fp32_40_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c40 = arith.constant 40 : index + %c32 = arith.constant 32 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + return + } + + // Case 8: fp16_16_128_1_128 (float16, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDDIV_fp16_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // Case 9: fp32_20_64_1_64 (float32, src0=(20,64), src1=(1,64), dst=(20,64)) + func.func @TCOLEXPANDDIV_fp32_20_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c20 = arith.constant 20 : index + %c64 = arith.constant 64 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpanddiv ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/CMakeLists.txt new file mode 100644 index 000000000..acf3e74d4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandexpdif) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/cases.py new file mode 100644 index 000000000..e4ea5f005 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/cases.py @@ -0,0 +1,67 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandexpdif ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandexpdif/ + +TCOLEXPANDEXPDIF: compute exp(src0) - exp(expanded_src1) where src1 is expanded by tiling. + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (src1_row, cols) - second input tile (tiled to match src0 rows) + - dst_shape: (dst_row, dst_col) - output tile + - shape: (dst_row, dst_col) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + +Golden: np.exp(src0) - np.exp(np.tile(src1, (reps, 1))[:, :dst_col]) + where reps = dst_row // src1_row + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{src1_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_32_16_1_16", + "dtype": np.float32, + "src0_shape": (32, 16), + "src1_shape": (1, 16), + "shape": (32, 16), + "valid_shape": (32, 16), + "eps": 1e-5, + }, + { + "name": "fp32_16_32_1_32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 1e-5, + }, + { + "name": "fp16_32_32_1_32", + "dtype": np.float16, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-2, + }, + { + "name": "fp16_16_128_1_128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-2, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/gen_data.py new file mode 100644 index 000000000..4606b3571 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/gen_data.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.uniform(-255, 255, size=src0_shape).astype(dtype) + src1 = np.random.uniform(1, 255, size=src1_shape).astype(dtype) + + dst_row, dst_col = dst_shape + src1_row = src1_shape[0] + reps = (dst_row + src1_row - 1) // src1_row + + expanded_src1 = np.tile(src1, (reps, 1))[:dst_row, :dst_col] + golden = np.exp((src0.astype(np.float64) - expanded_src1.astype(np.float64))) + golden = golden.astype(dtype) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/launch.cpp new file mode 100644 index 000000000..b6f8ad19d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_32_16_1_16 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp32_32_16_1_16(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDEXPDIF_fp32_32_16_1_16(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDEXPDIF_fp32_32_16_1_16<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp32_16_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDEXPDIF_fp32_16_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDEXPDIF_fp32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp16_32_32_1_32(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDEXPDIF_fp16_32_32_1_32(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDEXPDIF_fp16_32_32_1_32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDEXPDIF_fp16_16_128_1_128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDEXPDIF_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDEXPDIF_fp16_16_128_1_128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/main.cpp new file mode 100644 index 000000000..afa716050 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/main.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandexpdif ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandexpdif/ +// TCOLEXPANDEXPDIF: compute exp(src0) - exp(tiled_src1) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDEXPDIF_fp32_32_16_1_16(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDEXPDIF_fp32_16_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDEXPDIF_fp16_32_32_1_32(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDEXPDIF_fp16_16_128_1_128(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_32_16_1_16", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp32_32_16_1_16, 32, 16, 1, 16, 32, 16, 32, 16, sizeof(float)}, + {"fp32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(float)}, + {"fp16_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp16_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"fp16_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDEXPDIF_fp16_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/tcolexpandexpdif.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/tcolexpandexpdif.pto new file mode 100644 index 000000000..f12431f56 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandexpdif/tcolexpandexpdif.pto @@ -0,0 +1,260 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandexpdif: compute exp(src0) - exp(tiled_src1). +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: fp32_32_16_1_16 (float32, src0=(32,16), src1=(1,16), dst=(32,16)) + func.func @TCOLEXPANDEXPDIF_fp32_32_16_1_16(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c16], + strides = [%c16, %c16, %c16, %c16, %c1] + : !pto.tensor_view<1x1x1x1x16xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf32> -> !pto.partition_tensor_view<1x1x1x32x16xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16] + : !pto.tensor_view<1x1x1x1x16xf32> -> !pto.partition_tensor_view<1x1x1x1x16xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf32> -> !pto.partition_tensor_view<1x1x1x32x16xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x16xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x16xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x16xf32>) + return + } + + // Case 2: fp32_16_32_1_32 (float32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDEXPDIF_fp32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // Case 3: fp16_32_32_1_32 (float16, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDEXPDIF_fp16_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf16> -> !pto.partition_tensor_view<1x1x1x1x32xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: fp16_16_128_1_128 (float16, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDEXPDIF_fp16_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandexpdif ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/CMakeLists.txt new file mode 100644 index 000000000..c132ea923 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/cases.py new file mode 100644 index 000000000..78c7d0d6b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/cases.py @@ -0,0 +1,91 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandmax ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandmax/ + +TCOLEXPANDMAX: compute elementwise maximum of src0 and tiled src1. + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (1, cols) - second input tile (single row, broadcasted) + - dst_shape: (dst_row, cols) - output tile + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + - reps: number of times to tile src1 (equals src0_row) + +Golden: np.maximum(src0, np.tile(src1, (reps, 1))[:, :dst_col]) + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{dst_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "reps": 16, + "eps": 1e-6, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "reps": 32, + "eps": 1e-6, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "reps": 4, + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "reps": 10, + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "reps": 16, + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "reps": 16, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/gen_data.py new file mode 100644 index 000000000..010fafb80 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/gen_data.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + reps = case["reps"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.maximum(src0, np.tile(src1, (reps, 1))[:, :dst_shape[1]]) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/launch.cpp new file mode 100644 index 000000000..a9797c562 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMAX_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMAX_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMAX_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMAX_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMAX_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMAX_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMAX_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMAX_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMAX_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMAX_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDMAX_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDMAX_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMAX_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDMAX_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDMAX_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/main.cpp new file mode 100644 index 000000000..3972121e9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandmax ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandmax/ +// TCOLEXPANDMAX: elementwise maximum of src0 and tiled src1 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDMAX_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMAX_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMAX_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMAX_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMAX_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDMAX_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDMAX_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDMAX_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDMAX_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDMAX_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDMAX_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDMAX_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/tcolexpandmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/tcolexpandmax.pto new file mode 100644 index 000000000..27813530c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmax/tcolexpandmax.pto @@ -0,0 +1,384 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandmax: elementwise maximum of src0 and tiled src1. +// Matches PTO-ISA testcase parameters. +// Key: tile_buf cols = full tensor width, v_col = valid portion + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDMAX_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDMAX_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDMAX_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDMAX_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + +pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDMAX_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDMAX_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmax ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) +outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/CMakeLists.txt new file mode 100644 index 000000000..f541b2396 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/cases.py new file mode 100644 index 000000000..f77ae2eb2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/cases.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandmin ST test cases. + +TCOLEXPANDMIN: compute elementwise minimum of src0 and tiled src1. + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (1, cols) - second input tile (single row, broadcasted) + - dst_shape: (dst_row, cols) - output tile + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + - reps: number of times to tile src1 (equals src0_row) + +Golden: np.minimum(src0, np.tile(src1, (reps, 1))[:, :dst_col]) + +Case naming: {dtype}_{src0_row}_{src0_col}_{src1_row}_{dst_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "reps": 16, + "eps": 1e-6, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "reps": 32, + "eps": 1e-6, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "reps": 4, + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "reps": 10, + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "reps": 16, + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "reps": 16, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/gen_data.py new file mode 100644 index 000000000..a0b5e82be --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/gen_data.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + reps = case["reps"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.minimum(src0, np.tile(src1, (reps, 1))[:, :dst_shape[1]]) + golden = golden.astype(dtype) + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/launch.cpp new file mode 100644 index 000000000..c8cdf39e3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMIN_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMIN_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMIN_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMIN_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMIN_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMIN_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMIN_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMIN_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMIN_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMIN_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDMIN_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDMIN_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMIN_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDMIN_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDMIN_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/main.cpp new file mode 100644 index 000000000..64219b470 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandmin ST +// Test cases match PTO-ISA +// TCOLEXPANDMIN: compute elementwise minimum of src0 and tiled src1 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDMIN_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMIN_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMIN_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMIN_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMIN_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDMIN_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDMIN_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDMIN_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDMIN_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDMIN_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDMIN_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDMIN_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, + tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/tcolexpandmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/tcolexpandmin.pto new file mode 100644 index 000000000..c79f9aa89 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmin/tcolexpandmin.pto @@ -0,0 +1,383 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandmin: elementwise minimum of src0 and tiled src1. +// Matches PTO-ISA testcase parameters. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDMIN_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDMIN_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDMIN_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDMIN_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + +pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDMIN_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDMIN_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + + %src1_tile = pto.alloc_tile + : !pto.tile_buf + + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0_tile : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tcolexpandmin ins(%src0_tile, %src1_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/CMakeLists.txt new file mode 100644 index 000000000..c957813e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/cases.py new file mode 100644 index 000000000..7bc14eb3a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/cases.py @@ -0,0 +1,84 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandmul ST test cases. + +TCOLEXPANDMUL: expand src1 then multiply with src0. + - src0_shape: (dst_row, dst_col) - already expanded + - src1_shape: (src1_row, src1_col) - to be expanded (usually src1_row=1) + - dst_shape: (dst_row, dst_col) +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_16_128_1_128", + "dtype": np.float32, + "src0_shape": (16, 128), + "src1_shape": (1, 128), + "shape": (16, 128), + "valid_shape": (16, 128), + "eps": 1e-3, + }, + { + "name": "fp32_32_32_1_32", + "dtype": np.float32, + "src0_shape": (32, 32), + "src1_shape": (1, 32), + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "fp16_4_256_1_256", + "dtype": np.float16, + "src0_shape": (4, 256), + "src1_shape": (1, 256), + "shape": (4, 256), + "valid_shape": (4, 256), + "eps": 1e-3, + }, + { + "name": "fp16_10_64_1_64", + "dtype": np.float16, + "src0_shape": (10, 64), + "src1_shape": (1, 64), + "shape": (10, 64), + "valid_shape": (10, 64), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/gen_data.py new file mode 100644 index 000000000..d8ee0880f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/gen_data.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + dst_row, dst_col = dst_shape + reps = dst_row + golden = src0 * np.tile(src1, (reps, 1))[:, :dst_col] + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/launch.cpp new file mode 100644 index 000000000..1fb3521d9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_16_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp32_16_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMUL_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMUL_fp32_16_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_32_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp32_32_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDMUL_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDMUL_fp32_32_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_4_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp16_4_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMUL_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMUL_fp16_4_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_10_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMUL_fp16_10_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDMUL_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDMUL_fp16_10_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDMUL_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDMUL_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDMUL_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDMUL_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDMUL_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDMUL_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/main.cpp new file mode 100644 index 000000000..6186d0068 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/main.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandmul ST +// TCOLEXPANDMUL: expand src1 then multiply with src0 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDMUL_fp32_16_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMUL_fp32_32_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDMUL_fp16_4_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMUL_fp16_10_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDMUL_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDMUL_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_16_128_1_128", (LaunchFn)LaunchTCOLEXPANDMUL_fp32_16_128_1_128, 16, 128, 1, 128, 16, 128, 16, 128, sizeof(float)}, + {"fp32_32_32_1_32", (LaunchFn)LaunchTCOLEXPANDMUL_fp32_32_32_1_32, 32, 32, 1, 32, 32, 32, 32, 32, sizeof(float)}, + {"fp16_4_256_1_256", (LaunchFn)LaunchTCOLEXPANDMUL_fp16_4_256_1_256, 4, 256, 1, 256, 4, 256, 4, 256, sizeof(uint16_t)}, + {"fp16_10_64_1_64", (LaunchFn)LaunchTCOLEXPANDMUL_fp16_10_64_1_64, 10, 64, 1, 64, 10, 64, 10, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDMUL_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDMUL_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/tcolexpandmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/tcolexpandmul.pto new file mode 100644 index 000000000..9d371d043 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandmul/tcolexpandmul.pto @@ -0,0 +1,382 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandmul: expand src1 then multiply with src0. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: fp32_16_128_1_128 (float32, src0=(16,128), src1=(1,128), dst=(16,128)) + func.func @TCOLEXPANDMUL_fp32_16_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 2: fp32_32_32_1_32 (float32, src0=(32,32), src1=(1,32), dst=(32,32)) + func.func @TCOLEXPANDMUL_fp32_32_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 3: fp16_4_256_1_256 (float16, src0=(4,256), src1=(1,256), dst=(4,256)) + func.func @TCOLEXPANDMUL_fp16_4_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case 4: fp16_10_64_1_64 (float16, src0=(10,64), src1=(1,64), dst=(10,64)) + func.func @TCOLEXPANDMUL_fp16_10_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c64 = arith.constant 64 : index + %c640 = arith.constant 640 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c64], + strides = [%c640, %c640, %c640, %c64, %c1] + : !pto.tensor_view<1x1x1x10x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c64] + : !pto.tensor_view<1x1x1x10x64xf16> -> !pto.partition_tensor_view<1x1x1x10x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + +pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDMUL_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDMUL_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandmul ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/CMakeLists.txt new file mode 100644 index 000000000..0eacb2968 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolexpandsub) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/cases.py new file mode 100644 index 000000000..47d04eed7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/cases.py @@ -0,0 +1,84 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolexpandsub ST test cases. +Matches PTO-ISA testcase definitions in /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandsub/ + +TCOLEXPANDSUB: subtract src0 by expanded src1 (broadcast src1 first row). + - src0_shape: (src0_row, cols) - first input tile + - src1_shape: (src1_row, cols) - second input tile (only first row used for broadcast) + - dst_shape: (dst_row, cols) - result output + - shape: (dst_row, cols) - alias of dst_shape, for compare.py compatibility + - valid_shape: (valid_row, valid_col) - effective computation region + +Golden: src0 - np.tile(src1, (reps, 1))[:, :dst_col] # expand then subtract + +Case naming: {dtype}_{src0_row}_{cols}_{src1_row}_{dst_col} +""" + +import numpy as np + +CASES = [ + { + "name": "fp32_6_128_1_128", + "dtype": np.float32, + "src0_shape": (6, 128), + "src1_shape": (1, 128), + "shape": (6, 128), + "valid_shape": (6, 128), + "eps": 1e-6, + }, + { + "name": "fp32_18_32_1_32", + "dtype": np.float32, + "src0_shape": (18, 32), + "src1_shape": (1, 32), + "shape": (18, 32), + "valid_shape": (18, 32), + "eps": 1e-6, + }, + { + "name": "fp16_10_256_1_256", + "dtype": np.float16, + "src0_shape": (10, 256), + "src1_shape": (1, 256), + "shape": (10, 256), + "valid_shape": (10, 256), + "eps": 1e-3, + }, + { + "name": "fp16_12_64_1_64", + "dtype": np.float16, + "src0_shape": (12, 64), + "src1_shape": (1, 64), + "shape": (12, 64), + "valid_shape": (12, 64), + "eps": 1e-3, + }, + { + "name": "int32_16_32_1_32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src1_shape": (1, 32), + "shape": (16, 32), + "valid_shape": (16, 32), + "eps": 0, + }, + { + "name": "int16_16_64_1_64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src1_shape": (1, 64), + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/compare.py new file mode 100644 index 000000000..b8ddb1131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/compare.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You can not use this file in the compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/gen_data.py new file mode 100644 index 000000000..a2fc88aab --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + dst_shape = case["shape"] + src1_shape = case["src1_shape"] + valid_shape = case["valid_shape"] + + src0 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + src1 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + valid_row, valid_col = valid_shape + reps = valid_row + golden = src0 - np.tile(src1, (reps, 1))[:, :valid_col] + + save_case_data(case["name"], {"input0": src0, "input1": src1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} valid={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/launch.cpp new file mode 100644 index 000000000..6d9ff773f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 1: fp32_6_128_1_128 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp32_6_128_1_128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDSUB_fp32_6_128_1_128(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDSUB_fp32_6_128_1_128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: fp32_18_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp32_18_32_1_32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTCOLEXPANDSUB_fp32_18_32_1_32(float *src0, float *src1, float *dst, void *stream) { + TCOLEXPANDSUB_fp32_18_32_1_32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: fp16_10_256_1_256 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp16_10_256_1_256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDSUB_fp16_10_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDSUB_fp16_10_256_1_256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: fp16_12_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDSUB_fp16_12_64_1_64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTCOLEXPANDSUB_fp16_12_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TCOLEXPANDSUB_fp16_12_64_1_64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +// Case: int32_16_32_1_32 +extern "C" __global__ AICORE void TCOLEXPANDSUB_int32_16_32_1_32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTCOLEXPANDSUB_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream) { + TCOLEXPANDSUB_int32_16_32_1_32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// Case: int16_16_64_1_64 +extern "C" __global__ AICORE void TCOLEXPANDSUB_int16_16_64_1_64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTCOLEXPANDSUB_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream) { + TCOLEXPANDSUB_int16_16_64_1_64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/main.cpp new file mode 100644 index 000000000..20b950896 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/main.cpp @@ -0,0 +1,167 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolexpandsub ST +// Test cases match PTO-ISA: /home/zhoushaofan/code/pto-isa/tests/npu/a5/src/st/testcase/tcolexpandsub/ +// TCOLEXPANDSUB: subtract src0 by expanded src1 (broadcast src1 first row) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLEXPANDSUB_fp32_6_128_1_128(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDSUB_fp32_18_32_1_32(float *src0, float *src1, float *dst, void *stream); +void LaunchTCOLEXPANDSUB_fp16_10_256_1_256(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDSUB_fp16_12_64_1_64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTCOLEXPANDSUB_int32_16_32_1_32(int32_t *src0, int32_t *src1, int32_t *dst, void *stream); +void LaunchTCOLEXPANDSUB_int16_16_64_1_64(int16_t *src0, int16_t *src1, int16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; + size_t dstRows; + size_t dstCols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"fp32_6_128_1_128", (LaunchFn)LaunchTCOLEXPANDSUB_fp32_6_128_1_128, 6, 128, 1, 128, 6, 128, 6, 128, sizeof(float)}, + {"fp32_18_32_1_32", (LaunchFn)LaunchTCOLEXPANDSUB_fp32_18_32_1_32, 18, 32, 1, 32,18, 32,18, 32, sizeof(float)}, + {"fp16_10_256_1_256", (LaunchFn)LaunchTCOLEXPANDSUB_fp16_10_256_1_256, 10, 256, 1, 256,10, 256,10, 256, sizeof(uint16_t)}, + {"fp16_12_64_1_64", (LaunchFn)LaunchTCOLEXPANDSUB_fp16_12_64_1_64, 12, 64, 1, 64,12, 64,12, 64, sizeof(uint16_t)}, + {"int32_16_32_1_32", (LaunchFn)LaunchTCOLEXPANDSUB_int32_16_32_1_32, 16, 32, 1, 32, 16, 32, 16, 32, sizeof(int32_t)}, + {"int16_16_64_1_64", (LaunchFn)LaunchTCOLEXPANDSUB_int16_16_64_1_64, 16, 64, 1, 64, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t src0ElemCount = tc.src0Rows * tc.src0Cols; + const size_t src1ElemCount = tc.src1Rows * tc.src1Cols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t src0FileSize = src0ElemCount * tc.elemSize; + const size_t src1FileSize = src1ElemCount * tc.elemSize; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu -> dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + size_t actualSrc0FileSize = src0FileSize; + size_t actualSrc1FileSize = src1FileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), actualSrc0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), actualSrc1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/tcolexpandsub.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/tcolexpandsub.pto new file mode 100644 index 000000000..c0858009e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolexpandsub/tcolexpandsub.pto @@ -0,0 +1,385 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file in the compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolexpandsub: subtract src0 by expanded src1. +// Matches PTO-ISA testcase parameters. +// Golden: src0 - np.tile(src1, (reps, 1))[:, :dst_col] + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 1: fp32_6_128_1_128 (float32, src0=(6,128), src1=(1,128), dst=(6,128)) + func.func @TCOLEXPANDSUB_fp32_6_128_1_128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c6 = arith.constant 6 : index + %c128 = arith.constant 128 : index + %c768 = arith.constant 768 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c6, %c128], + strides = [%c768, %c768, %c768, %c128, %c1] + : !pto.tensor_view<1x1x1x6x128xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c6, %c128], + strides = [%c768, %c768, %c768, %c128, %c1] + : !pto.tensor_view<1x1x1x6x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c6, %c128] + : !pto.tensor_view<1x1x1x6x128xf32> -> !pto.partition_tensor_view<1x1x1x6x128xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c6, %c128] + : !pto.tensor_view<1x1x1x6x128xf32> -> !pto.partition_tensor_view<1x1x1x6x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x6x128xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x6x128xf32>) + return + } + + // Case 2: fp32_18_32_1_32 (float32, src0=(18,32), src1=(1,32), dst=(18,32)) + func.func @TCOLEXPANDSUB_fp32_18_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c18 = arith.constant 18 : index + %c32 = arith.constant 32 : index + %c576 = arith.constant 576 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c18, %c32], + strides = [%c576, %c576, %c576, %c32, %c1] + : !pto.tensor_view<1x1x1x18x32xf32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c18, %c32], + strides = [%c576, %c576, %c576, %c32, %c1] + : !pto.tensor_view<1x1x1x18x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c18, %c32] + : !pto.tensor_view<1x1x1x18x32xf32> -> !pto.partition_tensor_view<1x1x1x18x32xf32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c18, %c32] + : !pto.tensor_view<1x1x1x18x32xf32> -> !pto.partition_tensor_view<1x1x1x18x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x18x32xf32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x18x32xf32>) + return + } + + // Case 3: fp16_10_256_1_256 (float16, src0=(10,256), src1=(1,256), dst=(10,256)) + func.func @TCOLEXPANDSUB_fp16_10_256_1_256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %c256 = arith.constant 256 : index + %c2560 = arith.constant 2560 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c10, %c256], + strides = [%c2560, %c2560, %c2560, %c256, %c1] + : !pto.tensor_view<1x1x1x10x256xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c10, %c256], + strides = [%c2560, %c2560, %c2560, %c256, %c1] + : !pto.tensor_view<1x1x1x10x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c256] + : !pto.tensor_view<1x1x1x10x256xf16> -> !pto.partition_tensor_view<1x1x1x10x256xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c10, %c256] + : !pto.tensor_view<1x1x1x10x256xf16> -> !pto.partition_tensor_view<1x1x1x10x256xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x10x256xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x10x256xf16>) + return + } + + // Case 4: fp16_12_64_1_64 (float16, src0=(12,64), src1=(1,64), dst=(12,64)) + func.func @TCOLEXPANDSUB_fp16_12_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c12, %c64], + strides = [%c768, %c768, %c768, %c64, %c1] + : !pto.tensor_view<1x1x1x12x64xf16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c12, %c64], + strides = [%c768, %c768, %c768, %c64, %c1] + : !pto.tensor_view<1x1x1x12x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c12, %c64] + : !pto.tensor_view<1x1x1x12x64xf16> -> !pto.partition_tensor_view<1x1x1x12x64xf16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf16> -> !pto.partition_tensor_view<1x1x1x1x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c12, %c64] + : !pto.tensor_view<1x1x1x12x64xf16> -> !pto.partition_tensor_view<1x1x1x12x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x12x64xf16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x12x64xf16>) + return + } + + // Case 5: int32_16_32_1_32 (int32, src0=(16,32), src1=(1,32), dst=(16,32)) + func.func @TCOLEXPANDSUB_int32_16_32_1_32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xi32> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xi32> -> !pto.partition_tensor_view<1x1x1x1x32xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x32xi32>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // Case 6: int16_16_64_1_64 (int16, src0=(16,64), src1=(1,64), dst=(16,64)) + func.func @TCOLEXPANDSUB_int16_16_64_1_64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + + %src1 = pto.alloc_tile + : !pto.tile_buf + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src1 : !pto.tile_buf) + + pto.tcolexpandsub ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/CMakeLists.txt new file mode 100644 index 000000000..5afae033c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/cases.py new file mode 100644 index 000000000..f6fa36d9e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/cases.py @@ -0,0 +1,245 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolmax ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_1x256", + "dtype": np.int16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_16x128", + "dtype": np.int16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i16_16x256", + "dtype": np.int16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_1x256", + "dtype": np.int32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_16x128", + "dtype": np.int32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i32_16x256", + "dtype": np.int32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_1x256", + "dtype": np.uint8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_16x128", + "dtype": np.uint8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui8_16x256", + "dtype": np.uint8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/gen_data.py new file mode 100644 index 000000000..4dcd83b95 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.max(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/launch.cpp new file mode 100644 index 000000000..cb5325a5b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/launch.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMAX_f32_1x256(float *dst, float *src, void *stream) { + TCOLMAX_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMAX_f32_16x128(float *dst, float *src, void *stream) { + TCOLMAX_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMAX_f32_16x256(float *dst, float *src, void *stream) { + TCOLMAX_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: f16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f16_1x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMAX_f16_1x256(void *dst, void *src, void *stream) { + TCOLMAX_f16_1x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 4: f16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_f16_16x128(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMAX_f16_16x128(void *dst, void *src, void *stream) { + TCOLMAX_f16_16x128<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 5: f16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_f16_16x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMAX_f16_16x256(void *dst, void *src, void *stream) { + TCOLMAX_f16_16x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 6: i8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i8_1x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMAX_i8_1x256(void *dst, void *src, void *stream) { + TCOLMAX_i8_1x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 7: i8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_i8_16x128(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMAX_i8_16x128(void *dst, void *src, void *stream) { + TCOLMAX_i8_16x128<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 8: i8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i8_16x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMAX_i8_16x256(void *dst, void *src, void *stream) { + TCOLMAX_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 9: i16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i16_1x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMAX_i16_1x256(void *dst, void *src, void *stream) { + TCOLMAX_i16_1x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 10: i16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_i16_16x128(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMAX_i16_16x128(void *dst, void *src, void *stream) { + TCOLMAX_i16_16x128<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 11: i16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i16_16x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMAX_i16_16x256(void *dst, void *src, void *stream) { + TCOLMAX_i16_16x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 12: i32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i32_1x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMAX_i32_1x256(void *dst, void *src, void *stream) { + TCOLMAX_i32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 13: i32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_i32_16x128(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMAX_i32_16x128(void *dst, void *src, void *stream) { + TCOLMAX_i32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 14: i32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_i32_16x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMAX_i32_16x256(void *dst, void *src, void *stream) { + TCOLMAX_i32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 15: ui8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui8_1x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMAX_ui8_1x256(void *dst, void *src, void *stream) { + TCOLMAX_ui8_1x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 16: ui8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_ui8_16x128(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMAX_ui8_16x128(void *dst, void *src, void *stream) { + TCOLMAX_ui8_16x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 17: ui8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui8_16x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMAX_ui8_16x256(void *dst, void *src, void *stream) { + TCOLMAX_ui8_16x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 18: ui16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui16_1x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMAX_ui16_1x256(void *dst, void *src, void *stream) { + TCOLMAX_ui16_1x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 19: ui16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_ui16_16x128(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMAX_ui16_16x128(void *dst, void *src, void *stream) { + TCOLMAX_ui16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 20: ui16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui16_16x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMAX_ui16_16x256(void *dst, void *src, void *stream) { + TCOLMAX_ui16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 21: ui32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui32_1x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMAX_ui32_1x256(void *dst, void *src, void *stream) { + TCOLMAX_ui32_1x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 22: ui32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMAX_ui32_16x128(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMAX_ui32_16x128(void *dst, void *src, void *stream) { + TCOLMAX_ui32_16x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 23: ui32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMAX_ui32_16x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMAX_ui32_16x256(void *dst, void *src, void *stream) { + TCOLMAX_ui32_16x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/main.cpp new file mode 100644 index 000000000..aaa0b9505 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/main.cpp @@ -0,0 +1,189 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolmax ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLMAX_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLMAX_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLMAX_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLMAX_f16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_f16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_f16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_i32_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMAX_ui32_16x256(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLMAX_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLMAX_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLMAX_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"f16_1x256", (void*)LaunchTCOLMAX_f16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"f16_16x128", (void*)LaunchTCOLMAX_f16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"f16_16x256", (void*)LaunchTCOLMAX_f16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i8_1x256", (void*)LaunchTCOLMAX_i8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"i8_16x128", (void*)LaunchTCOLMAX_i8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"i8_16x256", (void*)LaunchTCOLMAX_i8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"i16_1x256", (void*)LaunchTCOLMAX_i16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"i16_16x128", (void*)LaunchTCOLMAX_i16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"i16_16x256", (void*)LaunchTCOLMAX_i16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i32_1x256", (void*)LaunchTCOLMAX_i32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"i32_16x128", (void*)LaunchTCOLMAX_i32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"i32_16x256", (void*)LaunchTCOLMAX_i32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, + {"ui8_1x256", (void*)LaunchTCOLMAX_ui8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"ui8_16x128", (void*)LaunchTCOLMAX_ui8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"ui8_16x256", (void*)LaunchTCOLMAX_ui8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"ui16_1x256", (void*)LaunchTCOLMAX_ui16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"ui16_16x128", (void*)LaunchTCOLMAX_ui16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"ui16_16x256", (void*)LaunchTCOLMAX_ui16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"ui32_1x256", (void*)LaunchTCOLMAX_ui32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"ui32_16x128", (void*)LaunchTCOLMAX_ui32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"ui32_16x256", (void*)LaunchTCOLMAX_ui32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/tcolmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/tcolmax.pto new file mode 100644 index 000000000..864fde643 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmax/tcolmax.pto @@ -0,0 +1,1038 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolmax: tload(src) + tcolmax(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: f16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_f16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 4: f16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_f16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 5: f16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_f16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 6: i8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_i8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 7: i8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_i8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 8: i8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_i8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 9: i16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_i16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 10: i16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_i16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi16> -> !pto.partition_tensor_view<1x1x1x16x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x127xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi16>) + return + } + + // Case 11: i16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_i16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi16> -> !pto.partition_tensor_view<1x1x1x15x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 12: i32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_i32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 13: i32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_i32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi32> -> !pto.partition_tensor_view<1x1x1x16x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 14: i32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_i32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi32> -> !pto.partition_tensor_view<1x1x1x15x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 15: ui8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_ui8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 16: ui8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_ui8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x127xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui8>) + return + } + + // Case 17: ui8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_ui8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 18: ui16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_ui16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 19: ui16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_ui16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x127xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui16>) + return + } + + // Case 20: ui16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_ui16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 21: ui32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMAX_ui32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } + + // Case 22: ui32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMAX_ui32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x127xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui32>) + return + } + + // Case 23: ui32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMAX_ui32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmax ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/CMakeLists.txt new file mode 100644 index 000000000..6de952f0b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/cases.py new file mode 100644 index 000000000..ba16bbd3f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/cases.py @@ -0,0 +1,245 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolmin ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_1x256", + "dtype": np.int16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_16x128", + "dtype": np.int16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i16_16x256", + "dtype": np.int16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_1x256", + "dtype": np.int32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_16x128", + "dtype": np.int32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i32_16x256", + "dtype": np.int32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_1x256", + "dtype": np.uint8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui8_16x128", + "dtype": np.uint8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui8_16x256", + "dtype": np.uint8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/gen_data.py new file mode 100644 index 000000000..152c58370 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.min(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/launch.cpp new file mode 100644 index 000000000..7e43609cc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/launch.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMIN_f32_1x256(float *dst, float *src, void *stream) { + TCOLMIN_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMIN_f32_16x128(float *dst, float *src, void *stream) { + TCOLMIN_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLMIN_f32_16x256(float *dst, float *src, void *stream) { + TCOLMIN_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: f16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f16_1x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMIN_f16_1x256(void *dst, void *src, void *stream) { + TCOLMIN_f16_1x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 4: f16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_f16_16x128(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMIN_f16_16x128(void *dst, void *src, void *stream) { + TCOLMIN_f16_16x128<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 5: f16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_f16_16x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLMIN_f16_16x256(void *dst, void *src, void *stream) { + TCOLMIN_f16_16x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 6: i8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i8_1x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMIN_i8_1x256(void *dst, void *src, void *stream) { + TCOLMIN_i8_1x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 7: i8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_i8_16x128(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMIN_i8_16x128(void *dst, void *src, void *stream) { + TCOLMIN_i8_16x128<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 8: i8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i8_16x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLMIN_i8_16x256(void *dst, void *src, void *stream) { + TCOLMIN_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 9: i16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i16_1x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMIN_i16_1x256(void *dst, void *src, void *stream) { + TCOLMIN_i16_1x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 10: i16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_i16_16x128(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMIN_i16_16x128(void *dst, void *src, void *stream) { + TCOLMIN_i16_16x128<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 11: i16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i16_16x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLMIN_i16_16x256(void *dst, void *src, void *stream) { + TCOLMIN_i16_16x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 12: i32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i32_1x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMIN_i32_1x256(void *dst, void *src, void *stream) { + TCOLMIN_i32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 13: i32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_i32_16x128(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMIN_i32_16x128(void *dst, void *src, void *stream) { + TCOLMIN_i32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 14: i32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_i32_16x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLMIN_i32_16x256(void *dst, void *src, void *stream) { + TCOLMIN_i32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 15: ui8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui8_1x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMIN_ui8_1x256(void *dst, void *src, void *stream) { + TCOLMIN_ui8_1x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 16: ui8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_ui8_16x128(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMIN_ui8_16x128(void *dst, void *src, void *stream) { + TCOLMIN_ui8_16x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 17: ui8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui8_16x256(__gm__ uint8_t *dst, __gm__ uint8_t *src); + +void LaunchTCOLMIN_ui8_16x256(void *dst, void *src, void *stream) { + TCOLMIN_ui8_16x256<<<1, nullptr, stream>>>((__gm__ uint8_t *)dst, (__gm__ uint8_t *)src); +} + +// Case 18: ui16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui16_1x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMIN_ui16_1x256(void *dst, void *src, void *stream) { + TCOLMIN_ui16_1x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 19: ui16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_ui16_16x128(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMIN_ui16_16x128(void *dst, void *src, void *stream) { + TCOLMIN_ui16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 20: ui16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui16_16x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLMIN_ui16_16x256(void *dst, void *src, void *stream) { + TCOLMIN_ui16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 21: ui32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui32_1x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMIN_ui32_1x256(void *dst, void *src, void *stream) { + TCOLMIN_ui32_1x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 22: ui32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLMIN_ui32_16x128(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMIN_ui32_16x128(void *dst, void *src, void *stream) { + TCOLMIN_ui32_16x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 23: ui32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLMIN_ui32_16x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLMIN_ui32_16x256(void *dst, void *src, void *stream) { + TCOLMIN_ui32_16x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/main.cpp new file mode 100644 index 000000000..c24bc5316 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/main.cpp @@ -0,0 +1,189 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolmin ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLMIN_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLMIN_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLMIN_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLMIN_f16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_f16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_f16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_i32_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLMIN_ui32_16x256(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLMIN_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLMIN_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLMIN_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"f16_1x256", (void*)LaunchTCOLMIN_f16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"f16_16x128", (void*)LaunchTCOLMIN_f16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"f16_16x256", (void*)LaunchTCOLMIN_f16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i8_1x256", (void*)LaunchTCOLMIN_i8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"i8_16x128", (void*)LaunchTCOLMIN_i8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"i8_16x256", (void*)LaunchTCOLMIN_i8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"i16_1x256", (void*)LaunchTCOLMIN_i16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"i16_16x128", (void*)LaunchTCOLMIN_i16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"i16_16x256", (void*)LaunchTCOLMIN_i16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i32_1x256", (void*)LaunchTCOLMIN_i32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"i32_16x128", (void*)LaunchTCOLMIN_i32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"i32_16x256", (void*)LaunchTCOLMIN_i32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, + {"ui8_1x256", (void*)LaunchTCOLMIN_ui8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"ui8_16x128", (void*)LaunchTCOLMIN_ui8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"ui8_16x256", (void*)LaunchTCOLMIN_ui8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"ui16_1x256", (void*)LaunchTCOLMIN_ui16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"ui16_16x128", (void*)LaunchTCOLMIN_ui16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"ui16_16x256", (void*)LaunchTCOLMIN_ui16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"ui32_1x256", (void*)LaunchTCOLMIN_ui32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"ui32_16x128", (void*)LaunchTCOLMIN_ui32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"ui32_16x256", (void*)LaunchTCOLMIN_ui32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/tcolmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/tcolmin.pto new file mode 100644 index 000000000..14065059f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolmin/tcolmin.pto @@ -0,0 +1,1038 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolmin: tload(src) + tcolmin(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: f16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_f16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 4: f16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_f16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 5: f16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_f16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 6: i8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_i8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 7: i8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_i8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 8: i8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_i8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 9: i16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_i16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 10: i16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_i16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi16> -> !pto.partition_tensor_view<1x1x1x16x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x127xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi16>) + return + } + + // Case 11: i16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_i16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi16> -> !pto.partition_tensor_view<1x1x1x15x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 12: i32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_i32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 13: i32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_i32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi32> -> !pto.partition_tensor_view<1x1x1x16x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 14: i32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_i32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi32> -> !pto.partition_tensor_view<1x1x1x15x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 15: ui8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_ui8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 16: ui8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_ui8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui8> -> !pto.partition_tensor_view<1x1x1x16x127xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x127xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui8>) + return + } + + // Case 17: ui8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_ui8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui8> -> !pto.partition_tensor_view<1x1x1x15x255xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x255xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui8>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui8>) + return + } + + // Case 18: ui16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_ui16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 19: ui16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_ui16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x127xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui16>) + return + } + + // Case 20: ui16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_ui16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 21: ui32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLMIN_ui32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } + + // Case 22: ui32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLMIN_ui32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x127xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui32>) + return + } + + // Case 23: ui32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLMIN_ui32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolmin ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/CMakeLists.txt new file mode 100644 index 000000000..02b874532 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolprod) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/cases.py new file mode 100644 index 000000000..e95d300ec --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/cases.py @@ -0,0 +1,164 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolprod ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "i16_1x256", + "dtype": np.int16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i16_16x128", + "dtype": np.int16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i16_16x256", + "dtype": np.int16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_1x256", + "dtype": np.uint16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui16_16x128", + "dtype": np.uint16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui16_16x256", + "dtype": np.uint16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_1x256", + "dtype": np.int32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i32_16x128", + "dtype": np.int32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i32_16x256", + "dtype": np.int32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_1x256", + "dtype": np.uint32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "ui32_16x128", + "dtype": np.uint32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "ui32_16x256", + "dtype": np.uint32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/gen_data.py new file mode 100644 index 000000000..ff3b740eb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.prod(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/launch.cpp new file mode 100644 index 000000000..158d2a06e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/launch.cpp @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLPROD_f32_1x256(float *dst, float *src, void *stream) { + TCOLPROD_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLPROD_f32_16x128(float *dst, float *src, void *stream) { + TCOLPROD_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLPROD_f32_16x256(float *dst, float *src, void *stream) { + TCOLPROD_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: i16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i16_1x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLPROD_i16_1x256(void *dst, void *src, void *stream) { + TCOLPROD_i16_1x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 4: i16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_i16_16x128(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLPROD_i16_16x128(void *dst, void *src, void *stream) { + TCOLPROD_i16_16x128<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 5: i16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i16_16x256(__gm__ int16_t *dst, __gm__ int16_t *src); + +void LaunchTCOLPROD_i16_16x256(void *dst, void *src, void *stream) { + TCOLPROD_i16_16x256<<<1, nullptr, stream>>>((__gm__ int16_t *)dst, (__gm__ int16_t *)src); +} + +// Case 6: ui16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui16_1x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLPROD_ui16_1x256(void *dst, void *src, void *stream) { + TCOLPROD_ui16_1x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 7: ui16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_ui16_16x128(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLPROD_ui16_16x128(void *dst, void *src, void *stream) { + TCOLPROD_ui16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 8: ui16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui16_16x256(__gm__ uint16_t *dst, __gm__ uint16_t *src); + +void LaunchTCOLPROD_ui16_16x256(void *dst, void *src, void *stream) { + TCOLPROD_ui16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst, (__gm__ uint16_t *)src); +} + +// Case 9: i32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i32_1x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLPROD_i32_1x256(void *dst, void *src, void *stream) { + TCOLPROD_i32_1x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 10: i32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_i32_16x128(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLPROD_i32_16x128(void *dst, void *src, void *stream) { + TCOLPROD_i32_16x128<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 11: i32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_i32_16x256(__gm__ int32_t *dst, __gm__ int32_t *src); + +void LaunchTCOLPROD_i32_16x256(void *dst, void *src, void *stream) { + TCOLPROD_i32_16x256<<<1, nullptr, stream>>>((__gm__ int32_t *)dst, (__gm__ int32_t *)src); +} + +// Case 12: ui32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui32_1x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLPROD_ui32_1x256(void *dst, void *src, void *stream) { + TCOLPROD_ui32_1x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 13: ui32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLPROD_ui32_16x128(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLPROD_ui32_16x128(void *dst, void *src, void *stream) { + TCOLPROD_ui32_16x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} + +// Case 14: ui32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLPROD_ui32_16x256(__gm__ uint32_t *dst, __gm__ uint32_t *src); + +void LaunchTCOLPROD_ui32_16x256(void *dst, void *src, void *stream) { + TCOLPROD_ui32_16x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)dst, (__gm__ uint32_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/main.cpp new file mode 100644 index 000000000..6850592a1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/main.cpp @@ -0,0 +1,170 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +// Host driver for TileLang tcolprod ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLPROD_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLPROD_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLPROD_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLPROD_i16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_i32_16x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui32_1x256(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui32_16x128(void *dst, void *src, void *stream); +void LaunchTCOLPROD_ui32_16x256(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLPROD_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLPROD_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLPROD_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"i16_1x256", (void*)LaunchTCOLPROD_i16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"i16_16x128", (void*)LaunchTCOLPROD_i16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"i16_16x256", (void*)LaunchTCOLPROD_i16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"ui16_1x256", (void*)LaunchTCOLPROD_ui16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"ui16_16x128", (void*)LaunchTCOLPROD_ui16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"ui16_16x256", (void*)LaunchTCOLPROD_ui16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"i32_1x256", (void*)LaunchTCOLPROD_i32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"i32_16x128", (void*)LaunchTCOLPROD_i32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"i32_16x256", (void*)LaunchTCOLPROD_i32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, + {"ui32_1x256", (void*)LaunchTCOLPROD_ui32_1x256, 1, 256, 1, 255, 1, 256, 255, 4, true}, + {"ui32_16x128", (void*)LaunchTCOLPROD_ui32_16x128, 16, 128, 16, 127, 1, 128, 127, 4, true}, + {"ui32_16x256", (void*)LaunchTCOLPROD_ui32_16x256, 16, 256, 15, 255, 1, 256, 255, 4, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/tcolprod.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/tcolprod.pto new file mode 100644 index 000000000..9da594562 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolprod/tcolprod.pto @@ -0,0 +1,654 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolprod: tload(src) + tcolprod(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: i16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_i16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 4: i16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_i16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi16> -> !pto.partition_tensor_view<1x1x1x16x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x127xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi16>) + return + } + + // Case 5: i16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_i16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi16> -> !pto.partition_tensor_view<1x1x1x15x255xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x255xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi16>) + return + } + + // Case 6: ui16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_ui16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 7: ui16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_ui16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui16> -> !pto.partition_tensor_view<1x1x1x16x127xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x127xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui16>) + return + } + + // Case 8: ui16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_ui16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui16> -> !pto.partition_tensor_view<1x1x1x15x255xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x255xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui16>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui16>) + return + } + + // Case 9: i32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_i32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 10: i32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_i32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi32> -> !pto.partition_tensor_view<1x1x1x16x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x127xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi32>) + return + } + + // Case 11: i32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_i32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi32> -> !pto.partition_tensor_view<1x1x1x15x255xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x255xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi32>) + return + } + + // Case 12: ui32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLPROD_ui32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } + + // Case 13: ui32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLPROD_ui32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xui32> -> !pto.partition_tensor_view<1x1x1x16x127xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x127xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xui32>) + return + } + + // Case 14: ui32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLPROD_ui32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xui32> -> !pto.partition_tensor_view<1x1x1x15x255xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x255xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xui32>) + outs(%src : !pto.tile_buf) + + pto.tcolprod ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xui32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/CMakeLists.txt new file mode 100644 index 000000000..e59d778af --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcolsum) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/cases.py new file mode 100644 index 000000000..dfbcbddf3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/cases.py @@ -0,0 +1,173 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcolsum ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions for input. + - valid_shape: (valid_rows, valid_cols) — effective computation region for input. + - dst_shape: (1, cols) — allocated tile dimensions for output. + - dst_valid_shape: (1, valid_cols) — effective computation region for output. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_1x256", + "dtype": np.float32, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_16x128", + "dtype": np.float32, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_16x256", + "dtype": np.float32, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-6, + }, + { + "name": "f32_64x128_1", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (63, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-6, + }, + { + "name": "f32_64x128_2", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 128), + "eps": 1e-6, + }, + { + "name": "f32_1x512", + "dtype": np.float32, + "shape": (1, 512), + "valid_shape": (1, 511), + "dst_shape": (1, 512), + "dst_valid_shape": (1, 511), + "eps": 1e-6, + }, + { + "name": "f16_1x256", + "dtype": np.float16, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_16x128", + "dtype": np.float16, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 1e-3, + }, + { + "name": "f16_64x128_1", + "dtype": np.float16, + "shape": (64, 128), + "valid_shape": (63, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 1e-3, + }, + { + "name": "f16_64x128_2", + "dtype": np.float16, + "shape": (64, 128), + "valid_shape": (64, 128), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 128), + "eps": 1e-3, + }, + { + "name": "i8_1x256", + "dtype": np.int8, + "shape": (1, 256), + "valid_shape": (1, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_16x128", + "dtype": np.int8, + "shape": (16, 128), + "valid_shape": (16, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_16x256", + "dtype": np.int8, + "shape": (16, 256), + "valid_shape": (15, 255), + "dst_shape": (1, 256), + "dst_valid_shape": (1, 255), + "eps": 0, + }, + { + "name": "i8_64x128_1", + "dtype": np.int8, + "shape": (64, 128), + "valid_shape": (63, 127), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 127), + "eps": 0, + }, + { + "name": "i8_64x128_2", + "dtype": np.int8, + "shape": (64, 128), + "valid_shape": (64, 128), + "dst_shape": (1, 128), + "dst_valid_shape": (1, 128), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/compare.py new file mode 100644 index 000000000..06a17bbda --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/gen_data.py new file mode 100644 index 000000000..7c0e19eef --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/gen_data.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + vr, vc = valid_shape + golden = np.zeros(dst_shape, dtype=dtype) + golden_result = np.sum(input1[:vr, :vc], axis=0, keepdims=True).astype(dtype) + golden[:1, :vc] = golden_result + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/launch.cpp new file mode 100644 index 000000000..6c8af717f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/launch.cpp @@ -0,0 +1,125 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f32_1x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_1x256(float *dst, float *src, void *stream) { + TCOLSUM_f32_1x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 1: f32 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f32_16x128(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_16x128(float *dst, float *src, void *stream) { + TCOLSUM_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 2: f32 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f32_16x256(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_16x256(float *dst, float *src, void *stream) { + TCOLSUM_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 3: f32 64x128_1 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f32_64x128_1(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_64x128_1(float *dst, float *src, void *stream) { + TCOLSUM_f32_64x128_1<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 4: f32 64x128_2 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f32_64x128_2(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_64x128_2(float *dst, float *src, void *stream) { + TCOLSUM_f32_64x128_2<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 5: f32 1x512 (input: 1x512, output: 1x512) +extern "C" __global__ AICORE void TCOLSUM_f32_1x512(__gm__ float *dst, __gm__ float *src); + +void LaunchTCOLSUM_f32_1x512(float *dst, float *src, void *stream) { + TCOLSUM_f32_1x512<<<1, nullptr, stream>>>((__gm__ float *)dst, (__gm__ float *)src); +} + +// Case 6: f16 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f16_1x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_1x256(void *dst, void *src, void *stream) { + TCOLSUM_f16_1x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 7: f16 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f16_16x128(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_16x128(void *dst, void *src, void *stream) { + TCOLSUM_f16_16x128<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 8: f16 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_f16_16x256(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_16x256(void *dst, void *src, void *stream) { + TCOLSUM_f16_16x256<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 9: f16 64x128_1 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f16_64x128_1(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_64x128_1(void *dst, void *src, void *stream) { + TCOLSUM_f16_64x128_1<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 10: f16 64x128_2 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_f16_64x128_2(__gm__ half *dst, __gm__ half *src); + +void LaunchTCOLSUM_f16_64x128_2(void *dst, void *src, void *stream) { + TCOLSUM_f16_64x128_2<<<1, nullptr, stream>>>((__gm__ half *)dst, (__gm__ half *)src); +} + +// Case 11: i8 1x256 (input: 1x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_i8_1x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_1x256(void *dst, void *src, void *stream) { + TCOLSUM_i8_1x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 12: i8 16x128 (input: 16x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_i8_16x128(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_16x128(void *dst, void *src, void *stream) { + TCOLSUM_i8_16x128<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 13: i8 16x256 (input: 16x256, output: 1x256) +extern "C" __global__ AICORE void TCOLSUM_i8_16x256(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_16x256(void *dst, void *src, void *stream) { + TCOLSUM_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 14: i8 64x128_1 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_i8_64x128_1(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_64x128_1(void *dst, void *src, void *stream) { + TCOLSUM_i8_64x128_1<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} + +// Case 15: i8 64x128_2 (input: 64x128, output: 1x128) +extern "C" __global__ AICORE void TCOLSUM_i8_64x128_2(__gm__ int8_t *dst, __gm__ int8_t *src); + +void LaunchTCOLSUM_i8_64x128_2(void *dst, void *src, void *stream) { + TCOLSUM_i8_64x128_2<<<1, nullptr, stream>>>((__gm__ int8_t *)dst, (__gm__ int8_t *)src); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/main.cpp new file mode 100644 index 000000000..eaff88d45 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/main.cpp @@ -0,0 +1,173 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tcolsum ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTCOLSUM_f32_1x256(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_16x128(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_16x256(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_64x128_1(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_64x128_2(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f32_1x512(float *dst, float *src, void *stream); +void LaunchTCOLSUM_f16_1x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_16x128(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_16x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_64x128_1(void *dst, void *src, void *stream); +void LaunchTCOLSUM_f16_64x128_2(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_1x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_16x128(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_16x256(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_64x128_1(void *dst, void *src, void *stream); +void LaunchTCOLSUM_i8_64x128_2(void *dst, void *src, void *stream); + +using LaunchFnFloat = void (*)(float *, float *, void *); +using LaunchFnVoid = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + void *launch; + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidCols; + size_t elemSize; + bool isFp16; +}; + +static const TestCase kCases[] = { + {"f32_1x256", (void*)LaunchTCOLSUM_f32_1x256, 1, 256, 1, 255, 1, 256, 255, sizeof(float), false}, + {"f32_16x128", (void*)LaunchTCOLSUM_f32_16x128, 16, 128, 16, 127, 1, 128, 127, sizeof(float), false}, + {"f32_16x256", (void*)LaunchTCOLSUM_f32_16x256, 16, 256, 15, 255, 1, 256, 255, sizeof(float), false}, + {"f32_64x128_1", (void*)LaunchTCOLSUM_f32_64x128_1, 64, 128, 63, 127, 1, 128, 127, sizeof(float), false}, + {"f32_64x128_2", (void*)LaunchTCOLSUM_f32_64x128_2, 64, 128, 64, 128, 1, 128, 128, sizeof(float), false}, + {"f32_1x512", (void*)LaunchTCOLSUM_f32_1x512, 1, 512, 1, 511, 1, 512, 511, sizeof(float), false}, + {"f16_1x256", (void*)LaunchTCOLSUM_f16_1x256, 1, 256, 1, 255, 1, 256, 255, 2, true}, + {"f16_16x128", (void*)LaunchTCOLSUM_f16_16x128, 16, 128, 16, 127, 1, 128, 127, 2, true}, + {"f16_16x256", (void*)LaunchTCOLSUM_f16_16x256, 16, 256, 15, 255, 1, 256, 255, 2, true}, + {"f16_64x128_1", (void*)LaunchTCOLSUM_f16_64x128_1, 64, 128, 63, 127, 1, 128, 127, 2, true}, + {"f16_64x128_2", (void*)LaunchTCOLSUM_f16_64x128_2, 64, 128, 64, 128, 1, 128, 128, 2, true}, + {"i8_1x256", (void*)LaunchTCOLSUM_i8_1x256, 1, 256, 1, 255, 1, 256, 255, 1, true}, + {"i8_16x128", (void*)LaunchTCOLSUM_i8_16x128, 16, 128, 16, 127, 1, 128, 127, 1, true}, + {"i8_16x256", (void*)LaunchTCOLSUM_i8_16x256, 16, 256, 15, 255, 1, 256, 255, 1, true}, + {"i8_64x128_1", (void*)LaunchTCOLSUM_i8_64x128_1, 64, 128, 63, 127, 1, 128, 127, 1, true}, + {"i8_64x128_2", (void*)LaunchTCOLSUM_i8_64x128_2, 64, 128, 64, 128, 1, 128, 128, 1, true}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, fp16=%d) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.isFp16); + + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSizeVar = srcFileSize; + size_t dstFileSizeVar = dstFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSizeVar, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + if (tc.isFp16) { + LaunchFnVoid launch = (LaunchFnVoid)tc.launch; + launch(dstDevice, srcDevice, stream); + } else { + LaunchFnFloat launch = (LaunchFnFloat)tc.launch; + launch((float*)dstDevice, (float*)srcDevice, stream); + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSizeVar)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/tcolsum.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/tcolsum.pto new file mode 100644 index 000000000..b616c758a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcolsum/tcolsum.pto @@ -0,0 +1,697 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcolsum: tload(src) + tcolsum(dst, src) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 1x256 (input: 1x256, output: 1x256) + func.func @TCOLSUM_f32_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 1: f32 16x128 (input: 16x128, output: 1x128) + func.func @TCOLSUM_f32_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 2: f32 16x256 (input: 16x256, output: 1x256) + func.func @TCOLSUM_f32_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x15x255xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x255xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf32>) + return + } + + // Case 3: f32 64x128_1 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f32_64x128_1(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c127] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x63x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf32>) + return + } + + // Case 4: f32 64x128_2 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f32_64x128_2(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 5: f32 1x512 (input: 1x512, output: 1x512) + func.func @TCOLSUM_f32_1x512(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c512 = arith.constant 512 : index + %c511 = arith.constant 511 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c511] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x511xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c511] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x511xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x511xf32>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x511xf32>) + return + } + + // Case 6: f16 1x256 (input: 1x256, output: 1x256) + func.func @TCOLSUM_f16_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 7: f16 16x128 (input: 16x128, output: 1x128) + func.func @TCOLSUM_f16_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 8: f16 16x256 (input: 16x256, output: 1x256) + func.func @TCOLSUM_f16_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x15x255xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x255xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xf16>) + return + } + + // Case 9: f16 64x128_1 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f16_64x128_1(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c127] + : !pto.tensor_view<1x1x1x64x128xf16> -> !pto.partition_tensor_view<1x1x1x63x127xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x127xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x127xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xf16>) + return + } + + // Case 10: f16 64x128_2 (input: 64x128, output: 1x128) + func.func @TCOLSUM_f16_64x128_2(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf16> -> !pto.partition_tensor_view<1x1x1x64x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case 11: i8 1x256 (input: 1x256, output: 1x256) + func.func @TCOLSUM_i8_1x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 12: i8 16x128 (input: 16x128, output: 1x128) + func.func @TCOLSUM_i8_16x128(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xi8> -> !pto.partition_tensor_view<1x1x1x16x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 13: i8 16x256 (input: 16x256, output: 1x256) + func.func @TCOLSUM_i8_16x256(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c255 = arith.constant 255 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x15x255xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c255] + : !pto.tensor_view<1x1x1x1x256xi8> -> !pto.partition_tensor_view<1x1x1x1x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x255xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x255xi8>) + return + } + + // Case 14: i8 64x128_1 (input: 64x128, output: 1x128) + func.func @TCOLSUM_i8_64x128_1(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c127 = arith.constant 127 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c127] + : !pto.tensor_view<1x1x1x64x128xi8> -> !pto.partition_tensor_view<1x1x1x63x127xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c127] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x127xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x127xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x127xi8>) + return + } + + // Case 15: i8 64x128_2 (input: 64x128, output: 1x128) + func.func @TCOLSUM_i8_64x128_2(%dst_ptr: !pto.ptr, %src_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xi8> -> !pto.partition_tensor_view<1x1x1x64x128xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi8> -> !pto.partition_tensor_view<1x1x1x1x128xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xi8>) + outs(%src : !pto.tile_buf) + + pto.tcolsum ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/CMakeLists.txt new file mode 100644 index 000000000..b117e9a27 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tcvt) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py new file mode 100644 index 000000000..9ad28dac9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/cases.py @@ -0,0 +1,166 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tcvt ST test cases. + +`dtype` is kept for shared validation compatibility. +Actual data generation and comparison use `src_dtype` / `dst_dtype`. +""" + +import numpy as np +from ml_dtypes import bfloat16 + +# 7 shapes (aligning with C++ INSTANTIATE_TCVT) +SHAPES = [ + (1, 128, 1, 128), + (2, 64, 2, 64), + (4, 32, 4, 32), + (2, 128, 2, 128), + (4, 128, 4, 65), # Partial tiles + (4, 256, 4, 200), # Partial tiles + (1, 256, 1, 129), # Partial tiles +] + +_DTYPE_NAME = { + np.float32: "f32", + np.float16: "f16", + bfloat16: "bf16", + np.int8: "si8", + np.uint8: "ui8", + np.int16: "i16", + "si16": "si16", + np.uint16: "ui16", + np.int32: "i32", + np.uint32: "ui32", + np.int64: "i64", + np.uint64: "ui64", +} + + +def _make_cases(src_dtype, dst_dtype): + """Generate cases of 7 test shapes for src_dtype -> dst_dtype""" + src_name = _DTYPE_NAME.get(src_dtype, src_dtype) + dst_name = _DTYPE_NAME.get(dst_dtype, dst_dtype) + + # eps: f32=1e-6; f16/bf16=1e-3; others=0 + eps_map = {np.float32: 1e-6, np.float16: 1e-3, bfloat16: 1e-3} + eps = eps_map.get(dst_dtype, 0.0) + + cases = [] + for rows, cols, v_rows, v_cols in SHAPES: + shape_name = f"{rows}x{cols}" if v_cols == cols else f"{v_rows}x{v_cols}" + cases.append({ + "name": f"{src_name}_to_{dst_name}_{shape_name}", + "dtype": dst_dtype, + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + "shape": (rows, cols), + "valid_shape": (v_rows, v_cols), + "eps": eps, + }) + return cases + + +CASES = [ + { + "name": "f32_to_i32_rint_16x64", + "dtype": np.int32, + "src_dtype": np.float32, + "dst_dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 0.0, + }, + { + "name": "f32_to_i32_round_16x64", + "dtype": np.int32, + "src_dtype": np.float32, + "dst_dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "ROUND", + "eps": 0.0, + }, + { + "name": "i32_to_f32_rint_16x64", + "dtype": np.float32, + "src_dtype": np.int32, + "dst_dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 1e-6, + }, + { + "name": "f32_to_f16_rint_16x64", + "dtype": np.float16, + "src_dtype": np.float32, + "dst_dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 1e-3, + }, + { + "name": "f16_to_f32_rint_16x64", + "dtype": np.float32, + "src_dtype": np.float16, + "dst_dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "round_mode": "RINT", + "eps": 1e-6, + }, + # f32 → f16, bf16, i16, i32, i64, f32 + *_make_cases(np.float32, np.float16), + *_make_cases(np.float32, bfloat16), + *_make_cases(np.float32, np.int16), + *_make_cases(np.float32, np.int32), + *_make_cases(np.float32, np.int64), + *_make_cases(np.float32, np.float32), + # f16 → f32, i32, i16, si8, ui8 + *_make_cases(np.float16, np.float32), + *_make_cases(np.float16, np.int32), + *_make_cases(np.float16, np.int16), + *_make_cases(np.float16, np.int8), + *_make_cases(np.float16, np.uint8), + # bf16 → f32, f16, i32 + *_make_cases(bfloat16, np.float32), + *_make_cases(bfloat16, np.float16), + *_make_cases(bfloat16, np.int32), + # ui8 → f16, ui16 + *_make_cases(np.uint8, np.float16), + *_make_cases(np.uint8, np.uint16), + # si8 → f16, si16, i32 + *_make_cases(np.int8, np.float16), + *_make_cases(np.int8, "si16"), + *_make_cases(np.int8, np.int32), + # i16 → ui8, f16, f32, ui32, i32 + *_make_cases(np.int16, np.uint8), + *_make_cases(np.int16, np.float16), + *_make_cases(np.int16, np.float32), + *_make_cases(np.int16, np.uint32), + *_make_cases(np.int16, np.int32), + # i32 → f32, i16, i64, ui8, ui16 + *_make_cases(np.int32, np.float32), + *_make_cases(np.int32, np.int16), + *_make_cases(np.int32, np.int64), + *_make_cases(np.int32, np.uint8), + *_make_cases(np.int32, np.uint16), + # ui32 → i16, ui16, ui8 + *_make_cases(np.uint32, np.int16), + *_make_cases(np.uint32, np.uint16), + *_make_cases(np.uint32, np.uint8), + # i64 → f32, i32 + *_make_cases(np.int64, np.float32), + *_make_cases(np.int64, np.int32), +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/compare.py new file mode 100644 index 000000000..f3468d12b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/compare.py @@ -0,0 +1,57 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +_STR_DTYPE_MAP = {"si16": np.int16} + +def normalize_dtype(dtype): + return _STR_DTYPE_MAP.get(dtype, dtype) + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + dst_dtype = case["dst_dtype"] + dst_dtype = normalize_dtype(dst_dtype) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dst_dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dst_dtype).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_cpp.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_cpp.py new file mode 100644 index 000000000..0cc4d4d68 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_cpp.py @@ -0,0 +1,234 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Script to generate launch.cpp & main.cpp""" + +import numpy as np +import cases +from cases import bfloat16 + +_DTYPE_TO_CPP = { + np.float32: "float", + np.float16: "uint16_t", + bfloat16: "uint16_t", + np.int8: "int8_t", + np.uint8: "uint8_t", + np.int16: "int16_t", + "si16": "int16_t", + np.uint16: "uint16_t", + np.int32: "int32_t", + np.uint32: "uint32_t", + np.int64: "int64_t", + np.uint64: "uint64_t", +} + +def gen_launch(): + lines = [ + "// Copyright (c) 2026 Huawei Technologies Co., Ltd.", + "// This program is free software, you can redistribute it and/or modify it under the terms and conditions of", + '// CANN Open Software License Agreement Version 2.0 (the "License").', + "// Please refer to the License for details. You may not use this file except in compliance with the License.", + '// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,', + "// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.", + "// See LICENSE in the root of the software repository for the full text of the License.", + "", + "#include ", + "", + "#ifndef AICORE", + "#define AICORE [aicore]", + "#endif", + "", + ] + + extern_decls = [] + launch_funcs = [] + + for c in cases.CASES: + name = c["name"] + src_cpp = _DTYPE_TO_CPP.get(c["src_dtype"], "float") + dst_cpp = _DTYPE_TO_CPP.get(c["dst_dtype"], "float") + + extern_decls.append(f'extern "C" __global__ AICORE void TCVT_{name}(__gm__ {src_cpp} *src, __gm__ {dst_cpp} *dst);') + launch_funcs.append(f"void LaunchTCVT_{name}(void *src, void *dst, void *stream) {{") + launch_funcs.append(f" TCVT_{name}<<<1, nullptr, stream>>>((__gm__ {src_cpp} *)src, (__gm__ {dst_cpp} *)dst);") + launch_funcs.append("}") + launch_funcs.append("") + + lines.extend(extern_decls) + lines.append("") + lines.extend(launch_funcs) + + return "\n".join(lines) + +def gen_main(): + lines = [ + "// Copyright (c) 2026 Huawei Technologies Co., Ltd.", + "// This program is free software, you can redistribute it and/or modify it under the terms and conditions of", + '// CANN Open Software License Agreement Version 2.0 (the "License").', + "// Please refer to the License for details. You may not use this file except in compliance with the License.", + '// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,', + "// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.", + "// See LICENSE in the root of the software repository for the full text of the License.", + "", + '#include "acl/acl.h"', + '#include "test_common.h"', + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "", + "using namespace PtoTestCommon;", + "", + ] + + decls = [] + for c in cases.CASES: + decls.append(f"void LaunchTCVT_{c['name']}(void *src, void *dst, void *stream);") + + lines.extend(decls) + lines.extend([ + "", + "using LaunchFn = void (*)(void *, void *, void *);", + "", + "struct TestCase {", + " const char *name;", + " LaunchFn launch;", + " size_t srcRows;", + " size_t srcCols;", + " size_t dstRows;", + " size_t dstCols;", + " size_t srcElemSize;", + " size_t dstElemSize;", + "};", + "", + "static const TestCase kCases[] = {", + ]) + + case_entries = [] + for c in cases.CASES: + name = c["name"] + src_cpp = _DTYPE_TO_CPP.get(c["src_dtype"], "float") + dst_cpp = _DTYPE_TO_CPP.get(c["dst_dtype"], "float") + rows, cols = c["shape"] + case_entries.append(f' {{"{name}", LaunchTCVT_{name}, {rows}, {cols}, {rows}, {cols}, sizeof({src_cpp}), sizeof({dst_cpp})}},') + + lines.extend(case_entries) + lines.extend([ + "};", + "static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]);", + "", + ]) + + # RunCase 和 main 函数保持不变 + lines.extend([ + "static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) {", + " (void)deviceId;", + " int rc = 0;", + " const size_t srcElemCount = tc.srcRows * tc.srcCols;", + " const size_t dstElemCount = tc.dstRows * tc.dstCols;", + " size_t srcFileSize = srcElemCount * tc.srcElemSize;", + " size_t dstFileSize = dstElemCount * tc.dstElemSize;", + "", + ' std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu) ===\\n",', + " tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols);", + "", + ' std::string caseDir = std::string("./") + tc.name;', + "", + " void *srcHost = nullptr;", + " void *dstHost = nullptr;", + " void *srcDevice = nullptr;", + " void *dstDevice = nullptr;", + "", + " aclrtMallocHost(&srcHost, srcFileSize);", + " aclrtMallocHost(&dstHost, dstFileSize);", + "", + " aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST);", + " aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST);", + "", + ' if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, srcFileSize)) {', + ' std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\\n", caseDir.c_str());', + " rc = 1;", + " }", + "", + " if (rc == 0) {", + " aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE);", + " tc.launch(srcDevice, dstDevice, stream);", + " aclrtSynchronizeStream(stream);", + " aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST);", + " }", + "", + ' if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) {', + ' std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\\n", caseDir.c_str());', + " rc = 1;", + " }", + "", + " if (srcDevice != nullptr)", + " aclrtFree(srcDevice);", + " if (dstDevice != nullptr)", + " aclrtFree(dstDevice);", + " if (srcHost != nullptr)", + " aclrtFreeHost(srcHost);", + " if (dstHost != nullptr)", + " aclrtFreeHost(dstHost);", + "", + " if (rc == 0)", + ' std::printf("[INFO] case %s done\\n", tc.name);', + " return rc;", + "}", + "", + "int main(int argc, char *argv[]) {", + " const char *caseFilter = (argc > 1) ? argv[1] : nullptr;", + "", + " int rc = 0;", + " int deviceId = 0;", + " aclrtStream stream = nullptr;", + "", + " aclInit(nullptr);", + ' if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) {', + " deviceId = std::atoi(envDevice);", + " }", + " aclrtSetDevice(deviceId);", + " aclrtCreateStream(&stream);", + "", + " for (size_t i = 0; i < kNumCases; ++i) {", + " if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) {", + " continue;", + " }", + " int ret = RunCase(kCases[i], deviceId, stream);", + " if (ret != 0) {", + ' std::fprintf(stderr, "[ERROR] case %s failed\\n", kCases[i].name);', + " rc = 1;", + " break;", + " }", + " }", + "", + " if (stream != nullptr)", + " aclrtDestroyStream(stream);", + " aclrtResetDevice(deviceId);", + " aclFinalize();", + "", + " return rc;", + "}", + "" + ]) + + return "\n".join(lines) + +if __name__ == "__main__": + from pathlib import Path + HERE = Path(__file__).parent + + with open(HERE / "launch.cpp", "w") as f: + f.write(gen_launch()) + print(f"Generated {(HERE / 'launch.cpp').as_posix()!r}") + + with open(HERE / "main.cpp", "w") as f: + f.write(gen_main()) + print(f"Generated {(HERE / 'main.cpp').as_posix()!r}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py new file mode 100644 index 000000000..8772ae584 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_data.py @@ -0,0 +1,180 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import ml_dtypes + +from cases import CASES +from compare import normalize_dtype +from st_common import save_case_data, setup_case_rng, validate_cases + + +def is_sub_float(dtype): + return np.issubdtype(dtype, np.floating) or dtype == ml_dtypes.bfloat16 + + +def is_sub_int(dtype): + return np.issubdtype(dtype, np.integer) + + +def _make_input_inner(src_dtype, shape): + total = int(np.prod(shape)) + float_types = (np.float32, np.float16, ml_dtypes.bfloat16) + int8_like_types = (np.int8, ) + + # Generate input data + if src_dtype in float_types: + return (np.random.random([total]) * 200 - 100) + elif src_dtype in int8_like_types: + return np.random.randint(-128, 128, [total]) + elif src_dtype == np.uint8: + return np.random.randint(0, 256, [total]) + elif src_dtype == np.int16: + return np.random.randint(-1000, 1000, [total]) + elif src_dtype == np.uint16: + return np.random.randint(0, 10000, [total]) + elif src_dtype in (np.int32, np.int64): + return np.random.randint(-10000, 10000, [total]) + elif src_dtype == np.uint32: + return np.random.randint(0, 10000, [total]) + else: + return np.random.randint(-10000, 10000, [total]) + + +def make_input(src_dtype, shape): + return _make_input_inner(src_dtype, shape).astype(normalize_dtype(src_dtype)).reshape(shape) + + +def round_half_away_from_zero(values): + return np.copysign(np.floor(np.abs(values) + 0.5), values) + + +def default_saturation_off(src_dtype, dst_dtype): + """Mirror the current A5 default saturation policy for supported pairs.""" + return ( + (src_dtype is np.float16 and dst_dtype is np.uint8) + or (src_dtype is np.float16 and dst_dtype is np.int8) + or (src_dtype is np.float32 and dst_dtype is np.int16) + or (src_dtype is np.float16 and dst_dtype is np.int16) + or (src_dtype is np.int64 and dst_dtype is np.int32) + or (src_dtype is np.int32 and dst_dtype is np.int16) + ) + + +def apply_round_mode(values, round_mode): + rounding_funcs = { + "RINT": np.rint, + "ROUND": round_half_away_from_zero, + "FLOOR": np.floor, + "CEIL": np.ceil, + "TRUNC": np.trunc, + } + return rounding_funcs.get(round_mode, np.rint)(values) + + +def convert(values: np.ndarray, src_dtype, dst_dtype, round_mode=None): + is_float_src = is_sub_float(src_dtype) + is_int_dst = is_sub_int(dst_dtype) + is_f32_to_f32 = src_dtype == np.float32 and dst_dtype == np.float32 + needs_rounding = is_float_src and (is_int_dst or is_f32_to_f32) + + if needs_rounding: + values = apply_round_mode(values, round_mode or "RINT") + + if is_int_dst: + # Determine if this conversion has default saturation OFF (truncation) or ON (clamping) + if default_saturation_off(src_dtype, dst_dtype): + # OFF (truncation): bit extraction - wrap around using modulo + return truncate_to_int(values, dst_dtype) + else: + # Saturation ON: clamp to range (widen to int64/float64 to preserve sign) + return clamp_to_range_int(values, dst_dtype) + elif is_sub_float(dst_dtype): + return clamp_to_range_float(values, dst_dtype) + else: + return values.astype(dst_dtype) + + +def truncate_to_int(values: np.ndarray, dst_dtype): + golden_list = [] + for val in values.flat: + int_val = 0 if np.isnan(val) or np.isinf(val) else int(np.int64(val)) + + if dst_dtype == np.int8: + byte_val = int_val & 0xFF + truncated_val = byte_val if byte_val < 128 else byte_val - 256 + elif dst_dtype == np.uint8: + truncated_val = int_val & 0xFF + elif dst_dtype == np.int16: + word_val = int_val & 0xFFFF + truncated_val = word_val if word_val < 32768 else word_val - 65536 + elif dst_dtype == np.int32: + dword_val = int_val & 0xFFFFFFFF + truncated_val = dword_val if dword_val < 2147483648 else dword_val - 4294967296 + else: + truncated_val = int_val + golden_list.append(truncated_val) + return np.array(golden_list, dtype=dst_dtype).reshape(values.shape) + + +def clamp_to_range_int(values: np.ndarray, dst_dtype): + info = ml_dtypes.iinfo(dst_dtype) + is_int_type = is_sub_int(values.dtype) + temp_dtype = np.int64 if is_int_type else np.float64 + widened = values.astype(temp_dtype, copy=False) + return np.clip(widened, info.min, info.max).astype(dst_dtype) + + +def clamp_to_range_float(values: np.ndarray, dst_dtype): + info = ml_dtypes.finfo(dst_dtype) + return np.clip(values, info.min, info.max).astype(dst_dtype) + + +def apply_valid_shape(values: np.ndarray, valid_shape, dst_dtype): + vr, vc = valid_shape + masked = np.zeros_like(values, dtype=dst_dtype) + masked[:vr, :vc] = values[:vr, :vc] + return masked + +def generate_golden(case): + src_dtype = case["src_dtype"] + dst_dtype = case["dst_dtype"] + src_dtype_norm = normalize_dtype(src_dtype) + dst_dtype_norm = normalize_dtype(dst_dtype) + shape = case["shape"] + round_mode = case.get("round_mode") + + input_arr = make_input(src_dtype, shape) + converted = convert(input_arr, src_dtype_norm, dst_dtype_norm, round_mode) + golden = apply_valid_shape(converted, case["valid_shape"], dst_dtype_norm) + + return input_arr, golden + + +if __name__ == "__main__": + np.random.seed(19) + + validate_cases(CASES) + + for case in CASES: + setup_case_rng(case) + input_arr, golden = generate_golden(case) + + save_case_data(case["name"], {"input": input_arr, "golden": golden}) + src_dtype = case["src_dtype"] + dst_dtype = case["dst_dtype"] + src_name = src_dtype.__name__ if isinstance(src_dtype, type) else src_dtype + dst_name = dst_dtype.__name__ if isinstance(dst_dtype, type) else dst_dtype + print( + f"[INFO] gen_data: {case['name']} shape={case['shape']} " + f"src_dtype={src_name} dst_dtype={dst_name} " + f"round_mode={case.get('round_mode')}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_tcvt_pto.py b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_tcvt_pto.py new file mode 100644 index 000000000..988097bf0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/gen_tcvt_pto.py @@ -0,0 +1,114 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Script to generate tcvt.pto""" + +import cases + +def gen_rmode_attr(rmode): + return f"#pto" + +def gen_kernel(case, idx=0): + src_dtype = cases._DTYPE_NAME.get(case["src_dtype"], case["src_dtype"]) + dst_dtype = cases._DTYPE_NAME.get(case["dst_dtype"], case["dst_dtype"]) + rows, cols = case["shape"] + v_rows, v_cols = case["valid_shape"] + + shape_suffix = f"{rows}x{cols}" if v_cols == cols else f"{v_rows}x{v_cols}" + kernel_name = f"TCVT_{src_dtype}_to_{dst_dtype}_{shape_suffix}" + + rmode = "RINT" + rmode_command = "default RINT" + if "round_mode" in case: + rmode = case['round_mode'] + kernel_name = f"TCVT_{src_dtype}_to_{dst_dtype}_{rmode.lower()}_{shape_suffix}" + if rmode != "RINT": + rmode_command = f"explicit {rmode}" + + stride = rows * cols + + tile_valid = "" if v_rows == rows and v_cols == cols else f", valid={v_rows}x{v_cols}" + tile_src = f"!pto.tile_buf" + tile_dst = f"!pto.tile_buf" + + const_vals = sorted(set([0, 1, rows, cols, v_rows, v_cols, stride])) + longest_const = len(str(const_vals[-1])) + const_defs = [f" %c{i:<{longest_const}} = arith.constant {i:<{longest_const}} : index" for i in const_vals] + + lines = [ + f" // Case {idx}: {src_dtype} -> {dst_dtype}, {rmode_command}", + f" func.func @{kernel_name}(%src_ptr: !pto.ptr<{src_dtype}>, %dst_ptr: !pto.ptr<{dst_dtype}>) attributes {{ pto.entry }} {{", + ] + lines.extend(const_defs) + lines.extend([ + "", + f" %src_view = pto.make_tensor_view %src_ptr,", + f" shape = [%c1, %c1, %c1, %c{rows}, %c{cols}],", + f" strides = [%c{stride}, %c{stride}, %c{stride}, %c{cols}, %c1]", + f" : !pto.tensor_view<1x1x1x{rows}x{cols}x{src_dtype}>", + f" %dst_view = pto.make_tensor_view %dst_ptr,", + f" shape = [%c1, %c1, %c1, %c{rows}, %c{cols}],", + f" strides = [%c{stride}, %c{stride}, %c{stride}, %c{cols}, %c1]", + f" : !pto.tensor_view<1x1x1x{rows}x{cols}x{dst_dtype}>", + "", + f" %src_part = pto.partition_view %src_view,", + f" offsets = [%c0, %c0, %c0, %c0, %c0],", + f" sizes = [%c1, %c1, %c1, %c{v_rows}, %c{v_cols}]", + f" : !pto.tensor_view<1x1x1x{rows}x{cols}x{src_dtype}> -> !pto.partition_tensor_view<1x1x1x{v_rows}x{v_cols}x{src_dtype}>", + f" %dst_part = pto.partition_view %dst_view,", + f" offsets = [%c0, %c0, %c0, %c0, %c0],", + f" sizes = [%c1, %c1, %c1, %c{v_rows}, %c{v_cols}]", + f" : !pto.tensor_view<1x1x1x{rows}x{cols}x{dst_dtype}> -> !pto.partition_tensor_view<1x1x1x{v_rows}x{v_cols}x{dst_dtype}>", + "", + f" %src = pto.alloc_tile", + f" : {tile_src}", + f" %dst = pto.alloc_tile", + f" : {tile_dst}", + "", + f" pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x{v_rows}x{v_cols}x{src_dtype}>)", + f" outs(%src : {tile_src})", + "", + f" pto.tcvt ins(%src {{rmode = {gen_rmode_attr(rmode)}}} : {tile_src})" if rmode != "RINT" else f" pto.tcvt ins(%src : {tile_src})", + f" outs(%dst : {tile_dst})", + "", + f" pto.tstore ins(%dst : {tile_dst})", + f" outs(%dst_part : !pto.partition_tensor_view<1x1x1x{v_rows}x{v_cols}x{dst_dtype}>)", + f" return", + f" }}", + "" + ]) + return "\n".join(lines) + +header = """// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcvt. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. +// Generated by gen_tcvt_pto.py from cases.py. + +module { +""" + +footer = "\n}\n" + +if __name__ == "__main__": + from pathlib import Path + HERE = Path(__file__).parent + + with open(HERE / "tcvt.pto", "w") as f: + f.write(header) + f.write("\n".join(gen_kernel(case, idx) for idx, case in enumerate(cases.CASES))) + f.write(footer) + print(f"Generated {(HERE / 'tcvt.pto').as_posix()!r}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp new file mode 100644 index 000000000..7cf802406 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/launch.cpp @@ -0,0 +1,1229 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TCVT_f32_to_i32_rint_16x64(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_round_16x64(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_rint_16x64(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_rint_16x64(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_rint_16x64(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_1x128(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_2x64(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_4x32(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_2x128(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_4x65(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_4x200(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f16_1x129(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_bf16_1x128(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_bf16_2x64(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_bf16_4x32(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_bf16_2x128(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_bf16_4x65(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_bf16_4x200(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_bf16_1x129(__gm__ float *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i16_1x128(__gm__ float *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i16_2x64(__gm__ float *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i16_4x32(__gm__ float *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i16_2x128(__gm__ float *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i16_4x65(__gm__ float *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i16_4x200(__gm__ float *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i16_1x129(__gm__ float *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_1x128(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_2x64(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_4x32(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_2x128(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_4x65(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_4x200(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i32_1x129(__gm__ float *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i64_1x128(__gm__ float *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i64_2x64(__gm__ float *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i64_4x32(__gm__ float *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i64_2x128(__gm__ float *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i64_4x65(__gm__ float *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i64_4x200(__gm__ float *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_i64_1x129(__gm__ float *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f32_1x128(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f32_2x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f32_4x32(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f32_2x128(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f32_4x65(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f32_4x200(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f32_to_f32_1x129(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_1x128(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_2x64(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_4x32(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_2x128(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_4x65(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_4x200(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_f32_1x129(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i32_1x128(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i32_2x64(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i32_4x32(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i32_2x128(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i32_4x65(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i32_4x200(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i32_1x129(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i16_1x128(__gm__ uint16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i16_2x64(__gm__ uint16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i16_4x32(__gm__ uint16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i16_2x128(__gm__ uint16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i16_4x65(__gm__ uint16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i16_4x200(__gm__ uint16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_i16_1x129(__gm__ uint16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_si8_1x128(__gm__ uint16_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_si8_2x64(__gm__ uint16_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_si8_4x32(__gm__ uint16_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_si8_2x128(__gm__ uint16_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_si8_4x65(__gm__ uint16_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_si8_4x200(__gm__ uint16_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_si8_1x129(__gm__ uint16_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_ui8_1x128(__gm__ uint16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_ui8_2x64(__gm__ uint16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_ui8_4x32(__gm__ uint16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_ui8_2x128(__gm__ uint16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_ui8_4x65(__gm__ uint16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_ui8_4x200(__gm__ uint16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_f16_to_ui8_1x129(__gm__ uint16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f32_1x128(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f32_2x64(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f32_4x32(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f32_2x128(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f32_4x65(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f32_4x200(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f32_1x129(__gm__ uint16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f16_1x128(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f16_2x64(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f16_4x32(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f16_2x128(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f16_4x65(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f16_4x200(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_f16_1x129(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_i32_1x128(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_i32_2x64(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_i32_4x32(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_i32_2x128(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_i32_4x65(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_i32_4x200(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_bf16_to_i32_1x129(__gm__ uint16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_f16_1x128(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_f16_2x64(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_f16_4x32(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_f16_2x128(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_f16_4x65(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_f16_4x200(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_f16_1x129(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_ui16_1x128(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_ui16_2x64(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_ui16_4x32(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_ui16_2x128(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_ui16_4x65(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_ui16_4x200(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui8_to_ui16_1x129(__gm__ uint8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_f16_1x128(__gm__ int8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_f16_2x64(__gm__ int8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_f16_4x32(__gm__ int8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_f16_2x128(__gm__ int8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_f16_4x65(__gm__ int8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_f16_4x200(__gm__ int8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_f16_1x129(__gm__ int8_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_si16_1x128(__gm__ int8_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_si16_2x64(__gm__ int8_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_si16_4x32(__gm__ int8_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_si16_2x128(__gm__ int8_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_si16_4x65(__gm__ int8_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_si16_4x200(__gm__ int8_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_si16_1x129(__gm__ int8_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_i32_1x128(__gm__ int8_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_i32_2x64(__gm__ int8_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_i32_4x32(__gm__ int8_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_i32_2x128(__gm__ int8_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_i32_4x65(__gm__ int8_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_i32_4x200(__gm__ int8_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_si8_to_i32_1x129(__gm__ int8_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui8_1x128(__gm__ int16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui8_2x64(__gm__ int16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui8_4x32(__gm__ int16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui8_2x128(__gm__ int16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui8_4x65(__gm__ int16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui8_4x200(__gm__ int16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui8_1x129(__gm__ int16_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f16_1x128(__gm__ int16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f16_2x64(__gm__ int16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f16_4x32(__gm__ int16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f16_2x128(__gm__ int16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f16_4x65(__gm__ int16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f16_4x200(__gm__ int16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f16_1x129(__gm__ int16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f32_1x128(__gm__ int16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f32_2x64(__gm__ int16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f32_4x32(__gm__ int16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f32_2x128(__gm__ int16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f32_4x65(__gm__ int16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f32_4x200(__gm__ int16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i16_to_f32_1x129(__gm__ int16_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui32_1x128(__gm__ int16_t *src, __gm__ uint32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui32_2x64(__gm__ int16_t *src, __gm__ uint32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui32_4x32(__gm__ int16_t *src, __gm__ uint32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui32_2x128(__gm__ int16_t *src, __gm__ uint32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui32_4x65(__gm__ int16_t *src, __gm__ uint32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui32_4x200(__gm__ int16_t *src, __gm__ uint32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_ui32_1x129(__gm__ int16_t *src, __gm__ uint32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_i32_1x128(__gm__ int16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_i32_2x64(__gm__ int16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_i32_4x32(__gm__ int16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_i32_2x128(__gm__ int16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_i32_4x65(__gm__ int16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_i32_4x200(__gm__ int16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i16_to_i32_1x129(__gm__ int16_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_1x128(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_2x64(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_4x32(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_2x128(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_4x65(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_4x200(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i32_to_f32_1x129(__gm__ int32_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i16_1x128(__gm__ int32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i16_2x64(__gm__ int32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i16_4x32(__gm__ int32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i16_2x128(__gm__ int32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i16_4x65(__gm__ int32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i16_4x200(__gm__ int32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i16_1x129(__gm__ int32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i64_1x128(__gm__ int32_t *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i64_2x64(__gm__ int32_t *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i64_4x32(__gm__ int32_t *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i64_2x128(__gm__ int32_t *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i64_4x65(__gm__ int32_t *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i64_4x200(__gm__ int32_t *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_i64_1x129(__gm__ int32_t *src, __gm__ int64_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui8_1x128(__gm__ int32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui8_2x64(__gm__ int32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui8_4x32(__gm__ int32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui8_2x128(__gm__ int32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui8_4x65(__gm__ int32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui8_4x200(__gm__ int32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui8_1x129(__gm__ int32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui16_1x128(__gm__ int32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui16_2x64(__gm__ int32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui16_4x32(__gm__ int32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui16_2x128(__gm__ int32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui16_4x65(__gm__ int32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui16_4x200(__gm__ int32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_i32_to_ui16_1x129(__gm__ int32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_i16_1x128(__gm__ uint32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_i16_2x64(__gm__ uint32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_i16_4x32(__gm__ uint32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_i16_2x128(__gm__ uint32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_i16_4x65(__gm__ uint32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_i16_4x200(__gm__ uint32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_i16_1x129(__gm__ uint32_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui16_1x128(__gm__ uint32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui16_2x64(__gm__ uint32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui16_4x32(__gm__ uint32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui16_2x128(__gm__ uint32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui16_4x65(__gm__ uint32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui16_4x200(__gm__ uint32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui16_1x129(__gm__ uint32_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui8_1x128(__gm__ uint32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui8_2x64(__gm__ uint32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui8_4x32(__gm__ uint32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui8_2x128(__gm__ uint32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui8_4x65(__gm__ uint32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui8_4x200(__gm__ uint32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_ui32_to_ui8_1x129(__gm__ uint32_t *src, __gm__ uint8_t *dst); +extern "C" __global__ AICORE void TCVT_i64_to_f32_1x128(__gm__ int64_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i64_to_f32_2x64(__gm__ int64_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i64_to_f32_4x32(__gm__ int64_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i64_to_f32_2x128(__gm__ int64_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i64_to_f32_4x65(__gm__ int64_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i64_to_f32_4x200(__gm__ int64_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i64_to_f32_1x129(__gm__ int64_t *src, __gm__ float *dst); +extern "C" __global__ AICORE void TCVT_i64_to_i32_1x128(__gm__ int64_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i64_to_i32_2x64(__gm__ int64_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i64_to_i32_4x32(__gm__ int64_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i64_to_i32_2x128(__gm__ int64_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i64_to_i32_4x65(__gm__ int64_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i64_to_i32_4x200(__gm__ int64_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TCVT_i64_to_i32_1x129(__gm__ int64_t *src, __gm__ int32_t *dst); + +void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_rint_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_round_16x64(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_round_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_rint_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f16_rint_16x64(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_rint_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f16_to_f32_rint_16x64(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_rint_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f16_1x128(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_f16_2x64(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_2x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_f16_4x32(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_4x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_f16_2x128(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_2x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_f16_4x65(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_4x65<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_f16_4x200(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_4x200<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_f16_1x129(void *src, void *dst, void *stream) { + TCVT_f32_to_f16_1x129<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_bf16_1x128(void *src, void *dst, void *stream) { + TCVT_f32_to_bf16_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_bf16_2x64(void *src, void *dst, void *stream) { + TCVT_f32_to_bf16_2x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_bf16_4x32(void *src, void *dst, void *stream) { + TCVT_f32_to_bf16_4x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_bf16_2x128(void *src, void *dst, void *stream) { + TCVT_f32_to_bf16_2x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_bf16_4x65(void *src, void *dst, void *stream) { + TCVT_f32_to_bf16_4x65<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_bf16_4x200(void *src, void *dst, void *stream) { + TCVT_f32_to_bf16_4x200<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_bf16_1x129(void *src, void *dst, void *stream) { + TCVT_f32_to_bf16_1x129<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_f32_to_i16_1x128(void *src, void *dst, void *stream) { + TCVT_f32_to_i16_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f32_to_i16_2x64(void *src, void *dst, void *stream) { + TCVT_f32_to_i16_2x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f32_to_i16_4x32(void *src, void *dst, void *stream) { + TCVT_f32_to_i16_4x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f32_to_i16_2x128(void *src, void *dst, void *stream) { + TCVT_f32_to_i16_2x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f32_to_i16_4x65(void *src, void *dst, void *stream) { + TCVT_f32_to_i16_4x65<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f32_to_i16_4x200(void *src, void *dst, void *stream) { + TCVT_f32_to_i16_4x200<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f32_to_i16_1x129(void *src, void *dst, void *stream) { + TCVT_f32_to_i16_1x129<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f32_to_i32_1x128(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_2x64(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_2x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_4x32(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_4x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_2x128(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_2x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_4x65(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_4x65<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_4x200(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_4x200<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i32_1x129(void *src, void *dst, void *stream) { + TCVT_f32_to_i32_1x129<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f32_to_i64_1x128(void *src, void *dst, void *stream) { + TCVT_f32_to_i64_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_f32_to_i64_2x64(void *src, void *dst, void *stream) { + TCVT_f32_to_i64_2x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_f32_to_i64_4x32(void *src, void *dst, void *stream) { + TCVT_f32_to_i64_4x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_f32_to_i64_2x128(void *src, void *dst, void *stream) { + TCVT_f32_to_i64_2x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_f32_to_i64_4x65(void *src, void *dst, void *stream) { + TCVT_f32_to_i64_4x65<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_f32_to_i64_4x200(void *src, void *dst, void *stream) { + TCVT_f32_to_i64_4x200<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_f32_to_i64_1x129(void *src, void *dst, void *stream) { + TCVT_f32_to_i64_1x129<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_f32_to_f32_1x128(void *src, void *dst, void *stream) { + TCVT_f32_to_f32_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f32_2x64(void *src, void *dst, void *stream) { + TCVT_f32_to_f32_2x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f32_4x32(void *src, void *dst, void *stream) { + TCVT_f32_to_f32_4x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f32_2x128(void *src, void *dst, void *stream) { + TCVT_f32_to_f32_2x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f32_4x65(void *src, void *dst, void *stream) { + TCVT_f32_to_f32_4x65<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f32_4x200(void *src, void *dst, void *stream) { + TCVT_f32_to_f32_4x200<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f32_to_f32_1x129(void *src, void *dst, void *stream) { + TCVT_f32_to_f32_1x129<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_f32_1x128(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_f32_2x64(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_f32_4x32(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_f32_2x128(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_f32_4x65(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_f32_4x200(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_f32_1x129(void *src, void *dst, void *stream) { + TCVT_f16_to_f32_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_f16_to_i32_1x128(void *src, void *dst, void *stream) { + TCVT_f16_to_i32_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f16_to_i32_2x64(void *src, void *dst, void *stream) { + TCVT_f16_to_i32_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f16_to_i32_4x32(void *src, void *dst, void *stream) { + TCVT_f16_to_i32_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f16_to_i32_2x128(void *src, void *dst, void *stream) { + TCVT_f16_to_i32_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f16_to_i32_4x65(void *src, void *dst, void *stream) { + TCVT_f16_to_i32_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f16_to_i32_4x200(void *src, void *dst, void *stream) { + TCVT_f16_to_i32_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f16_to_i32_1x129(void *src, void *dst, void *stream) { + TCVT_f16_to_i32_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_f16_to_i16_1x128(void *src, void *dst, void *stream) { + TCVT_f16_to_i16_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f16_to_i16_2x64(void *src, void *dst, void *stream) { + TCVT_f16_to_i16_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f16_to_i16_4x32(void *src, void *dst, void *stream) { + TCVT_f16_to_i16_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f16_to_i16_2x128(void *src, void *dst, void *stream) { + TCVT_f16_to_i16_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f16_to_i16_4x65(void *src, void *dst, void *stream) { + TCVT_f16_to_i16_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f16_to_i16_4x200(void *src, void *dst, void *stream) { + TCVT_f16_to_i16_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f16_to_i16_1x129(void *src, void *dst, void *stream) { + TCVT_f16_to_i16_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_f16_to_si8_1x128(void *src, void *dst, void *stream) { + TCVT_f16_to_si8_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int8_t *)dst); +} + +void LaunchTCVT_f16_to_si8_2x64(void *src, void *dst, void *stream) { + TCVT_f16_to_si8_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int8_t *)dst); +} + +void LaunchTCVT_f16_to_si8_4x32(void *src, void *dst, void *stream) { + TCVT_f16_to_si8_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int8_t *)dst); +} + +void LaunchTCVT_f16_to_si8_2x128(void *src, void *dst, void *stream) { + TCVT_f16_to_si8_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int8_t *)dst); +} + +void LaunchTCVT_f16_to_si8_4x65(void *src, void *dst, void *stream) { + TCVT_f16_to_si8_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int8_t *)dst); +} + +void LaunchTCVT_f16_to_si8_4x200(void *src, void *dst, void *stream) { + TCVT_f16_to_si8_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int8_t *)dst); +} + +void LaunchTCVT_f16_to_si8_1x129(void *src, void *dst, void *stream) { + TCVT_f16_to_si8_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int8_t *)dst); +} + +void LaunchTCVT_f16_to_ui8_1x128(void *src, void *dst, void *stream) { + TCVT_f16_to_ui8_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_f16_to_ui8_2x64(void *src, void *dst, void *stream) { + TCVT_f16_to_ui8_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_f16_to_ui8_4x32(void *src, void *dst, void *stream) { + TCVT_f16_to_ui8_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_f16_to_ui8_2x128(void *src, void *dst, void *stream) { + TCVT_f16_to_ui8_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_f16_to_ui8_4x65(void *src, void *dst, void *stream) { + TCVT_f16_to_ui8_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_f16_to_ui8_4x200(void *src, void *dst, void *stream) { + TCVT_f16_to_ui8_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_f16_to_ui8_1x129(void *src, void *dst, void *stream) { + TCVT_f16_to_ui8_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_bf16_to_f32_1x128(void *src, void *dst, void *stream) { + TCVT_bf16_to_f32_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_bf16_to_f32_2x64(void *src, void *dst, void *stream) { + TCVT_bf16_to_f32_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_bf16_to_f32_4x32(void *src, void *dst, void *stream) { + TCVT_bf16_to_f32_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_bf16_to_f32_2x128(void *src, void *dst, void *stream) { + TCVT_bf16_to_f32_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_bf16_to_f32_4x65(void *src, void *dst, void *stream) { + TCVT_bf16_to_f32_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_bf16_to_f32_4x200(void *src, void *dst, void *stream) { + TCVT_bf16_to_f32_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_bf16_to_f32_1x129(void *src, void *dst, void *stream) { + TCVT_bf16_to_f32_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_bf16_to_f16_1x128(void *src, void *dst, void *stream) { + TCVT_bf16_to_f16_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_bf16_to_f16_2x64(void *src, void *dst, void *stream) { + TCVT_bf16_to_f16_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_bf16_to_f16_4x32(void *src, void *dst, void *stream) { + TCVT_bf16_to_f16_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_bf16_to_f16_2x128(void *src, void *dst, void *stream) { + TCVT_bf16_to_f16_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_bf16_to_f16_4x65(void *src, void *dst, void *stream) { + TCVT_bf16_to_f16_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_bf16_to_f16_4x200(void *src, void *dst, void *stream) { + TCVT_bf16_to_f16_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_bf16_to_f16_1x129(void *src, void *dst, void *stream) { + TCVT_bf16_to_f16_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_bf16_to_i32_1x128(void *src, void *dst, void *stream) { + TCVT_bf16_to_i32_1x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_bf16_to_i32_2x64(void *src, void *dst, void *stream) { + TCVT_bf16_to_i32_2x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_bf16_to_i32_4x32(void *src, void *dst, void *stream) { + TCVT_bf16_to_i32_4x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_bf16_to_i32_2x128(void *src, void *dst, void *stream) { + TCVT_bf16_to_i32_2x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_bf16_to_i32_4x65(void *src, void *dst, void *stream) { + TCVT_bf16_to_i32_4x65<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_bf16_to_i32_4x200(void *src, void *dst, void *stream) { + TCVT_bf16_to_i32_4x200<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_bf16_to_i32_1x129(void *src, void *dst, void *stream) { + TCVT_bf16_to_i32_1x129<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_ui8_to_f16_1x128(void *src, void *dst, void *stream) { + TCVT_ui8_to_f16_1x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_f16_2x64(void *src, void *dst, void *stream) { + TCVT_ui8_to_f16_2x64<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_f16_4x32(void *src, void *dst, void *stream) { + TCVT_ui8_to_f16_4x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_f16_2x128(void *src, void *dst, void *stream) { + TCVT_ui8_to_f16_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_f16_4x65(void *src, void *dst, void *stream) { + TCVT_ui8_to_f16_4x65<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_f16_4x200(void *src, void *dst, void *stream) { + TCVT_ui8_to_f16_4x200<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_f16_1x129(void *src, void *dst, void *stream) { + TCVT_ui8_to_f16_1x129<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_ui16_1x128(void *src, void *dst, void *stream) { + TCVT_ui8_to_ui16_1x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_ui16_2x64(void *src, void *dst, void *stream) { + TCVT_ui8_to_ui16_2x64<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_ui16_4x32(void *src, void *dst, void *stream) { + TCVT_ui8_to_ui16_4x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_ui16_2x128(void *src, void *dst, void *stream) { + TCVT_ui8_to_ui16_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_ui16_4x65(void *src, void *dst, void *stream) { + TCVT_ui8_to_ui16_4x65<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_ui16_4x200(void *src, void *dst, void *stream) { + TCVT_ui8_to_ui16_4x200<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui8_to_ui16_1x129(void *src, void *dst, void *stream) { + TCVT_ui8_to_ui16_1x129<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_f16_1x128(void *src, void *dst, void *stream) { + TCVT_si8_to_f16_1x128<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_f16_2x64(void *src, void *dst, void *stream) { + TCVT_si8_to_f16_2x64<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_f16_4x32(void *src, void *dst, void *stream) { + TCVT_si8_to_f16_4x32<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_f16_2x128(void *src, void *dst, void *stream) { + TCVT_si8_to_f16_2x128<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_f16_4x65(void *src, void *dst, void *stream) { + TCVT_si8_to_f16_4x65<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_f16_4x200(void *src, void *dst, void *stream) { + TCVT_si8_to_f16_4x200<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_f16_1x129(void *src, void *dst, void *stream) { + TCVT_si8_to_f16_1x129<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_si8_to_si16_1x128(void *src, void *dst, void *stream) { + TCVT_si8_to_si16_1x128<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_si8_to_si16_2x64(void *src, void *dst, void *stream) { + TCVT_si8_to_si16_2x64<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_si8_to_si16_4x32(void *src, void *dst, void *stream) { + TCVT_si8_to_si16_4x32<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_si8_to_si16_2x128(void *src, void *dst, void *stream) { + TCVT_si8_to_si16_2x128<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_si8_to_si16_4x65(void *src, void *dst, void *stream) { + TCVT_si8_to_si16_4x65<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_si8_to_si16_4x200(void *src, void *dst, void *stream) { + TCVT_si8_to_si16_4x200<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_si8_to_si16_1x129(void *src, void *dst, void *stream) { + TCVT_si8_to_si16_1x129<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_si8_to_i32_1x128(void *src, void *dst, void *stream) { + TCVT_si8_to_i32_1x128<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_si8_to_i32_2x64(void *src, void *dst, void *stream) { + TCVT_si8_to_i32_2x64<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_si8_to_i32_4x32(void *src, void *dst, void *stream) { + TCVT_si8_to_i32_4x32<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_si8_to_i32_2x128(void *src, void *dst, void *stream) { + TCVT_si8_to_i32_2x128<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_si8_to_i32_4x65(void *src, void *dst, void *stream) { + TCVT_si8_to_i32_4x65<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_si8_to_i32_4x200(void *src, void *dst, void *stream) { + TCVT_si8_to_i32_4x200<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_si8_to_i32_1x129(void *src, void *dst, void *stream) { + TCVT_si8_to_i32_1x129<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i16_to_ui8_1x128(void *src, void *dst, void *stream) { + TCVT_i16_to_ui8_1x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i16_to_ui8_2x64(void *src, void *dst, void *stream) { + TCVT_i16_to_ui8_2x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i16_to_ui8_4x32(void *src, void *dst, void *stream) { + TCVT_i16_to_ui8_4x32<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i16_to_ui8_2x128(void *src, void *dst, void *stream) { + TCVT_i16_to_ui8_2x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i16_to_ui8_4x65(void *src, void *dst, void *stream) { + TCVT_i16_to_ui8_4x65<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i16_to_ui8_4x200(void *src, void *dst, void *stream) { + TCVT_i16_to_ui8_4x200<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i16_to_ui8_1x129(void *src, void *dst, void *stream) { + TCVT_i16_to_ui8_1x129<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i16_to_f16_1x128(void *src, void *dst, void *stream) { + TCVT_i16_to_f16_1x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i16_to_f16_2x64(void *src, void *dst, void *stream) { + TCVT_i16_to_f16_2x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i16_to_f16_4x32(void *src, void *dst, void *stream) { + TCVT_i16_to_f16_4x32<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i16_to_f16_2x128(void *src, void *dst, void *stream) { + TCVT_i16_to_f16_2x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i16_to_f16_4x65(void *src, void *dst, void *stream) { + TCVT_i16_to_f16_4x65<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i16_to_f16_4x200(void *src, void *dst, void *stream) { + TCVT_i16_to_f16_4x200<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i16_to_f16_1x129(void *src, void *dst, void *stream) { + TCVT_i16_to_f16_1x129<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i16_to_f32_1x128(void *src, void *dst, void *stream) { + TCVT_i16_to_f32_1x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i16_to_f32_2x64(void *src, void *dst, void *stream) { + TCVT_i16_to_f32_2x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i16_to_f32_4x32(void *src, void *dst, void *stream) { + TCVT_i16_to_f32_4x32<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i16_to_f32_2x128(void *src, void *dst, void *stream) { + TCVT_i16_to_f32_2x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i16_to_f32_4x65(void *src, void *dst, void *stream) { + TCVT_i16_to_f32_4x65<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i16_to_f32_4x200(void *src, void *dst, void *stream) { + TCVT_i16_to_f32_4x200<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i16_to_f32_1x129(void *src, void *dst, void *stream) { + TCVT_i16_to_f32_1x129<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i16_to_ui32_1x128(void *src, void *dst, void *stream) { + TCVT_i16_to_ui32_1x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint32_t *)dst); +} + +void LaunchTCVT_i16_to_ui32_2x64(void *src, void *dst, void *stream) { + TCVT_i16_to_ui32_2x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint32_t *)dst); +} + +void LaunchTCVT_i16_to_ui32_4x32(void *src, void *dst, void *stream) { + TCVT_i16_to_ui32_4x32<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint32_t *)dst); +} + +void LaunchTCVT_i16_to_ui32_2x128(void *src, void *dst, void *stream) { + TCVT_i16_to_ui32_2x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint32_t *)dst); +} + +void LaunchTCVT_i16_to_ui32_4x65(void *src, void *dst, void *stream) { + TCVT_i16_to_ui32_4x65<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint32_t *)dst); +} + +void LaunchTCVT_i16_to_ui32_4x200(void *src, void *dst, void *stream) { + TCVT_i16_to_ui32_4x200<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint32_t *)dst); +} + +void LaunchTCVT_i16_to_ui32_1x129(void *src, void *dst, void *stream) { + TCVT_i16_to_ui32_1x129<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ uint32_t *)dst); +} + +void LaunchTCVT_i16_to_i32_1x128(void *src, void *dst, void *stream) { + TCVT_i16_to_i32_1x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i16_to_i32_2x64(void *src, void *dst, void *stream) { + TCVT_i16_to_i32_2x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i16_to_i32_4x32(void *src, void *dst, void *stream) { + TCVT_i16_to_i32_4x32<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i16_to_i32_2x128(void *src, void *dst, void *stream) { + TCVT_i16_to_i32_2x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i16_to_i32_4x65(void *src, void *dst, void *stream) { + TCVT_i16_to_i32_4x65<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i16_to_i32_4x200(void *src, void *dst, void *stream) { + TCVT_i16_to_i32_4x200<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i16_to_i32_1x129(void *src, void *dst, void *stream) { + TCVT_i16_to_i32_1x129<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i32_to_f32_1x128(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_1x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i32_to_f32_2x64(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_2x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i32_to_f32_4x32(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_4x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i32_to_f32_2x128(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_2x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i32_to_f32_4x65(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_4x65<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i32_to_f32_4x200(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_4x200<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i32_to_f32_1x129(void *src, void *dst, void *stream) { + TCVT_i32_to_f32_1x129<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i32_to_i16_1x128(void *src, void *dst, void *stream) { + TCVT_i32_to_i16_1x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_i32_to_i16_2x64(void *src, void *dst, void *stream) { + TCVT_i32_to_i16_2x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_i32_to_i16_4x32(void *src, void *dst, void *stream) { + TCVT_i32_to_i16_4x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_i32_to_i16_2x128(void *src, void *dst, void *stream) { + TCVT_i32_to_i16_2x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_i32_to_i16_4x65(void *src, void *dst, void *stream) { + TCVT_i32_to_i16_4x65<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_i32_to_i16_4x200(void *src, void *dst, void *stream) { + TCVT_i32_to_i16_4x200<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_i32_to_i16_1x129(void *src, void *dst, void *stream) { + TCVT_i32_to_i16_1x129<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_i32_to_i64_1x128(void *src, void *dst, void *stream) { + TCVT_i32_to_i64_1x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_i32_to_i64_2x64(void *src, void *dst, void *stream) { + TCVT_i32_to_i64_2x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_i32_to_i64_4x32(void *src, void *dst, void *stream) { + TCVT_i32_to_i64_4x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_i32_to_i64_2x128(void *src, void *dst, void *stream) { + TCVT_i32_to_i64_2x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_i32_to_i64_4x65(void *src, void *dst, void *stream) { + TCVT_i32_to_i64_4x65<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_i32_to_i64_4x200(void *src, void *dst, void *stream) { + TCVT_i32_to_i64_4x200<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_i32_to_i64_1x129(void *src, void *dst, void *stream) { + TCVT_i32_to_i64_1x129<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int64_t *)dst); +} + +void LaunchTCVT_i32_to_ui8_1x128(void *src, void *dst, void *stream) { + TCVT_i32_to_ui8_1x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i32_to_ui8_2x64(void *src, void *dst, void *stream) { + TCVT_i32_to_ui8_2x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i32_to_ui8_4x32(void *src, void *dst, void *stream) { + TCVT_i32_to_ui8_4x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i32_to_ui8_2x128(void *src, void *dst, void *stream) { + TCVT_i32_to_ui8_2x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i32_to_ui8_4x65(void *src, void *dst, void *stream) { + TCVT_i32_to_ui8_4x65<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i32_to_ui8_4x200(void *src, void *dst, void *stream) { + TCVT_i32_to_ui8_4x200<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i32_to_ui8_1x129(void *src, void *dst, void *stream) { + TCVT_i32_to_ui8_1x129<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i32_to_ui16_1x128(void *src, void *dst, void *stream) { + TCVT_i32_to_ui16_1x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i32_to_ui16_2x64(void *src, void *dst, void *stream) { + TCVT_i32_to_ui16_2x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i32_to_ui16_4x32(void *src, void *dst, void *stream) { + TCVT_i32_to_ui16_4x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i32_to_ui16_2x128(void *src, void *dst, void *stream) { + TCVT_i32_to_ui16_2x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i32_to_ui16_4x65(void *src, void *dst, void *stream) { + TCVT_i32_to_ui16_4x65<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i32_to_ui16_4x200(void *src, void *dst, void *stream) { + TCVT_i32_to_ui16_4x200<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_i32_to_ui16_1x129(void *src, void *dst, void *stream) { + TCVT_i32_to_ui16_1x129<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_i16_1x128(void *src, void *dst, void *stream) { + TCVT_ui32_to_i16_1x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_ui32_to_i16_2x64(void *src, void *dst, void *stream) { + TCVT_ui32_to_i16_2x64<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_ui32_to_i16_4x32(void *src, void *dst, void *stream) { + TCVT_ui32_to_i16_4x32<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_ui32_to_i16_2x128(void *src, void *dst, void *stream) { + TCVT_ui32_to_i16_2x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_ui32_to_i16_4x65(void *src, void *dst, void *stream) { + TCVT_ui32_to_i16_4x65<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_ui32_to_i16_4x200(void *src, void *dst, void *stream) { + TCVT_ui32_to_i16_4x200<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_ui32_to_i16_1x129(void *src, void *dst, void *stream) { + TCVT_ui32_to_i16_1x129<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ int16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui16_1x128(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui16_1x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui16_2x64(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui16_2x64<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui16_4x32(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui16_4x32<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui16_2x128(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui16_2x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui16_4x65(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui16_4x65<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui16_4x200(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui16_4x200<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui16_1x129(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui16_1x129<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint16_t *)dst); +} + +void LaunchTCVT_ui32_to_ui8_1x128(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui8_1x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_ui32_to_ui8_2x64(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui8_2x64<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_ui32_to_ui8_4x32(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui8_4x32<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_ui32_to_ui8_2x128(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui8_2x128<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_ui32_to_ui8_4x65(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui8_4x65<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_ui32_to_ui8_4x200(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui8_4x200<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_ui32_to_ui8_1x129(void *src, void *dst, void *stream) { + TCVT_ui32_to_ui8_1x129<<<1, nullptr, stream>>>((__gm__ uint32_t *)src, (__gm__ uint8_t *)dst); +} + +void LaunchTCVT_i64_to_f32_1x128(void *src, void *dst, void *stream) { + TCVT_i64_to_f32_1x128<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i64_to_f32_2x64(void *src, void *dst, void *stream) { + TCVT_i64_to_f32_2x64<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i64_to_f32_4x32(void *src, void *dst, void *stream) { + TCVT_i64_to_f32_4x32<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i64_to_f32_2x128(void *src, void *dst, void *stream) { + TCVT_i64_to_f32_2x128<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i64_to_f32_4x65(void *src, void *dst, void *stream) { + TCVT_i64_to_f32_4x65<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i64_to_f32_4x200(void *src, void *dst, void *stream) { + TCVT_i64_to_f32_4x200<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i64_to_f32_1x129(void *src, void *dst, void *stream) { + TCVT_i64_to_f32_1x129<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ float *)dst); +} + +void LaunchTCVT_i64_to_i32_1x128(void *src, void *dst, void *stream) { + TCVT_i64_to_i32_1x128<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i64_to_i32_2x64(void *src, void *dst, void *stream) { + TCVT_i64_to_i32_2x64<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i64_to_i32_4x32(void *src, void *dst, void *stream) { + TCVT_i64_to_i32_4x32<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i64_to_i32_2x128(void *src, void *dst, void *stream) { + TCVT_i64_to_i32_2x128<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i64_to_i32_4x65(void *src, void *dst, void *stream) { + TCVT_i64_to_i32_4x65<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i64_to_i32_4x200(void *src, void *dst, void *stream) { + TCVT_i64_to_i32_4x200<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ int32_t *)dst); +} + +void LaunchTCVT_i64_to_i32_1x129(void *src, void *dst, void *stream) { + TCVT_i64_to_i32_1x129<<<1, nullptr, stream>>>((__gm__ int64_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp new file mode 100644 index 000000000..db000cd2d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/main.cpp @@ -0,0 +1,610 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTCVT_f32_to_i32_rint_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_round_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_rint_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_rint_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_rint_16x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_bf16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_bf16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_bf16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_bf16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_bf16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_bf16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_bf16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i64_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i64_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i64_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i64_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i64_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i64_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_i64_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f32_to_f32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_f32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_i16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_si8_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_si8_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_si8_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_si8_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_si8_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_si8_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_si8_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_ui8_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_ui8_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_ui8_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_ui8_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_ui8_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_ui8_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_f16_to_ui8_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_f16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_i32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_i32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_i32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_i32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_i32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_i32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_bf16_to_i32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_f16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_f16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_f16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_f16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_f16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_f16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_f16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_ui16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_ui16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_ui16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_ui16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_ui16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_ui16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_ui8_to_ui16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_f16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_f16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_f16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_f16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_f16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_f16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_f16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_si16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_si16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_si16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_si16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_si16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_si16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_si16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_i32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_i32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_i32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_i32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_i32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_i32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_si8_to_i32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui8_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui8_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui8_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui8_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui8_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui8_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui8_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_f32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_ui32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_i32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_i32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_i32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_i32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_i32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_i32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i16_to_i32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_f32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i64_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i64_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i64_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i64_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i64_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i64_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_i64_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui8_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui8_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui8_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui8_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui8_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui8_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui8_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i32_to_ui16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_i16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_i16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_i16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_i16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_i16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_i16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_i16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui16_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui16_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui16_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui16_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui16_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui16_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui16_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui8_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui8_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui8_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui8_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui8_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui8_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_ui32_to_ui8_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_f32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_f32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_f32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_f32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_f32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_f32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_f32_1x129(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_i32_1x128(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_i32_2x64(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_i32_4x32(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_i32_2x128(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_i32_4x65(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_i32_4x200(void *src, void *dst, void *stream); +void LaunchTCVT_i64_to_i32_1x129(void *src, void *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; + size_t srcCols; + size_t dstRows; + size_t dstCols; + size_t srcElemSize; + size_t dstElemSize; +}; + +static const TestCase kCases[] = { + {"f32_to_i32_rint_16x64", LaunchTCVT_f32_to_i32_rint_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_round_16x64", LaunchTCVT_f32_to_i32_round_16x64, 16, 64, 16, 64, sizeof(float), sizeof(int32_t)}, + {"i32_to_f32_rint_16x64", LaunchTCVT_i32_to_f32_rint_16x64, 16, 64, 16, 64, sizeof(int32_t), sizeof(float)}, + {"f32_to_f16_rint_16x64", LaunchTCVT_f32_to_f16_rint_16x64, 16, 64, 16, 64, sizeof(float), sizeof(uint16_t)}, + {"f16_to_f32_rint_16x64", LaunchTCVT_f16_to_f32_rint_16x64, 16, 64, 16, 64, sizeof(uint16_t), sizeof(float)}, + {"f32_to_f16_1x128", LaunchTCVT_f32_to_f16_1x128, 1, 128, 1, 128, sizeof(float), sizeof(uint16_t)}, + {"f32_to_f16_2x64", LaunchTCVT_f32_to_f16_2x64, 2, 64, 2, 64, sizeof(float), sizeof(uint16_t)}, + {"f32_to_f16_4x32", LaunchTCVT_f32_to_f16_4x32, 4, 32, 4, 32, sizeof(float), sizeof(uint16_t)}, + {"f32_to_f16_2x128", LaunchTCVT_f32_to_f16_2x128, 2, 128, 2, 128, sizeof(float), sizeof(uint16_t)}, + {"f32_to_f16_4x65", LaunchTCVT_f32_to_f16_4x65, 4, 128, 4, 128, sizeof(float), sizeof(uint16_t)}, + {"f32_to_f16_4x200", LaunchTCVT_f32_to_f16_4x200, 4, 256, 4, 256, sizeof(float), sizeof(uint16_t)}, + {"f32_to_f16_1x129", LaunchTCVT_f32_to_f16_1x129, 1, 256, 1, 256, sizeof(float), sizeof(uint16_t)}, + {"f32_to_bf16_1x128", LaunchTCVT_f32_to_bf16_1x128, 1, 128, 1, 128, sizeof(float), sizeof(uint16_t)}, + {"f32_to_bf16_2x64", LaunchTCVT_f32_to_bf16_2x64, 2, 64, 2, 64, sizeof(float), sizeof(uint16_t)}, + {"f32_to_bf16_4x32", LaunchTCVT_f32_to_bf16_4x32, 4, 32, 4, 32, sizeof(float), sizeof(uint16_t)}, + {"f32_to_bf16_2x128", LaunchTCVT_f32_to_bf16_2x128, 2, 128, 2, 128, sizeof(float), sizeof(uint16_t)}, + {"f32_to_bf16_4x65", LaunchTCVT_f32_to_bf16_4x65, 4, 128, 4, 128, sizeof(float), sizeof(uint16_t)}, + {"f32_to_bf16_4x200", LaunchTCVT_f32_to_bf16_4x200, 4, 256, 4, 256, sizeof(float), sizeof(uint16_t)}, + {"f32_to_bf16_1x129", LaunchTCVT_f32_to_bf16_1x129, 1, 256, 1, 256, sizeof(float), sizeof(uint16_t)}, + {"f32_to_i16_1x128", LaunchTCVT_f32_to_i16_1x128, 1, 128, 1, 128, sizeof(float), sizeof(int16_t)}, + {"f32_to_i16_2x64", LaunchTCVT_f32_to_i16_2x64, 2, 64, 2, 64, sizeof(float), sizeof(int16_t)}, + {"f32_to_i16_4x32", LaunchTCVT_f32_to_i16_4x32, 4, 32, 4, 32, sizeof(float), sizeof(int16_t)}, + {"f32_to_i16_2x128", LaunchTCVT_f32_to_i16_2x128, 2, 128, 2, 128, sizeof(float), sizeof(int16_t)}, + {"f32_to_i16_4x65", LaunchTCVT_f32_to_i16_4x65, 4, 128, 4, 128, sizeof(float), sizeof(int16_t)}, + {"f32_to_i16_4x200", LaunchTCVT_f32_to_i16_4x200, 4, 256, 4, 256, sizeof(float), sizeof(int16_t)}, + {"f32_to_i16_1x129", LaunchTCVT_f32_to_i16_1x129, 1, 256, 1, 256, sizeof(float), sizeof(int16_t)}, + {"f32_to_i32_1x128", LaunchTCVT_f32_to_i32_1x128, 1, 128, 1, 128, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_2x64", LaunchTCVT_f32_to_i32_2x64, 2, 64, 2, 64, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_4x32", LaunchTCVT_f32_to_i32_4x32, 4, 32, 4, 32, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_2x128", LaunchTCVT_f32_to_i32_2x128, 2, 128, 2, 128, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_4x65", LaunchTCVT_f32_to_i32_4x65, 4, 128, 4, 128, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_4x200", LaunchTCVT_f32_to_i32_4x200, 4, 256, 4, 256, sizeof(float), sizeof(int32_t)}, + {"f32_to_i32_1x129", LaunchTCVT_f32_to_i32_1x129, 1, 256, 1, 256, sizeof(float), sizeof(int32_t)}, + {"f32_to_i64_1x128", LaunchTCVT_f32_to_i64_1x128, 1, 128, 1, 128, sizeof(float), sizeof(int64_t)}, + {"f32_to_i64_2x64", LaunchTCVT_f32_to_i64_2x64, 2, 64, 2, 64, sizeof(float), sizeof(int64_t)}, + {"f32_to_i64_4x32", LaunchTCVT_f32_to_i64_4x32, 4, 32, 4, 32, sizeof(float), sizeof(int64_t)}, + {"f32_to_i64_2x128", LaunchTCVT_f32_to_i64_2x128, 2, 128, 2, 128, sizeof(float), sizeof(int64_t)}, + {"f32_to_i64_4x65", LaunchTCVT_f32_to_i64_4x65, 4, 128, 4, 128, sizeof(float), sizeof(int64_t)}, + {"f32_to_i64_4x200", LaunchTCVT_f32_to_i64_4x200, 4, 256, 4, 256, sizeof(float), sizeof(int64_t)}, + {"f32_to_i64_1x129", LaunchTCVT_f32_to_i64_1x129, 1, 256, 1, 256, sizeof(float), sizeof(int64_t)}, + {"f32_to_f32_1x128", LaunchTCVT_f32_to_f32_1x128, 1, 128, 1, 128, sizeof(float), sizeof(float)}, + {"f32_to_f32_2x64", LaunchTCVT_f32_to_f32_2x64, 2, 64, 2, 64, sizeof(float), sizeof(float)}, + {"f32_to_f32_4x32", LaunchTCVT_f32_to_f32_4x32, 4, 32, 4, 32, sizeof(float), sizeof(float)}, + {"f32_to_f32_2x128", LaunchTCVT_f32_to_f32_2x128, 2, 128, 2, 128, sizeof(float), sizeof(float)}, + {"f32_to_f32_4x65", LaunchTCVT_f32_to_f32_4x65, 4, 128, 4, 128, sizeof(float), sizeof(float)}, + {"f32_to_f32_4x200", LaunchTCVT_f32_to_f32_4x200, 4, 256, 4, 256, sizeof(float), sizeof(float)}, + {"f32_to_f32_1x129", LaunchTCVT_f32_to_f32_1x129, 1, 256, 1, 256, sizeof(float), sizeof(float)}, + {"f16_to_f32_1x128", LaunchTCVT_f16_to_f32_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(float)}, + {"f16_to_f32_2x64", LaunchTCVT_f16_to_f32_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(float)}, + {"f16_to_f32_4x32", LaunchTCVT_f16_to_f32_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(float)}, + {"f16_to_f32_2x128", LaunchTCVT_f16_to_f32_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(float)}, + {"f16_to_f32_4x65", LaunchTCVT_f16_to_f32_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(float)}, + {"f16_to_f32_4x200", LaunchTCVT_f16_to_f32_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(float)}, + {"f16_to_f32_1x129", LaunchTCVT_f16_to_f32_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(float)}, + {"f16_to_i32_1x128", LaunchTCVT_f16_to_i32_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(int32_t)}, + {"f16_to_i32_2x64", LaunchTCVT_f16_to_i32_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(int32_t)}, + {"f16_to_i32_4x32", LaunchTCVT_f16_to_i32_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(int32_t)}, + {"f16_to_i32_2x128", LaunchTCVT_f16_to_i32_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(int32_t)}, + {"f16_to_i32_4x65", LaunchTCVT_f16_to_i32_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(int32_t)}, + {"f16_to_i32_4x200", LaunchTCVT_f16_to_i32_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(int32_t)}, + {"f16_to_i32_1x129", LaunchTCVT_f16_to_i32_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(int32_t)}, + {"f16_to_i16_1x128", LaunchTCVT_f16_to_i16_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(int16_t)}, + {"f16_to_i16_2x64", LaunchTCVT_f16_to_i16_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(int16_t)}, + {"f16_to_i16_4x32", LaunchTCVT_f16_to_i16_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(int16_t)}, + {"f16_to_i16_2x128", LaunchTCVT_f16_to_i16_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(int16_t)}, + {"f16_to_i16_4x65", LaunchTCVT_f16_to_i16_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(int16_t)}, + {"f16_to_i16_4x200", LaunchTCVT_f16_to_i16_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(int16_t)}, + {"f16_to_i16_1x129", LaunchTCVT_f16_to_i16_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(int16_t)}, + {"f16_to_si8_1x128", LaunchTCVT_f16_to_si8_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(int8_t)}, + {"f16_to_si8_2x64", LaunchTCVT_f16_to_si8_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(int8_t)}, + {"f16_to_si8_4x32", LaunchTCVT_f16_to_si8_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(int8_t)}, + {"f16_to_si8_2x128", LaunchTCVT_f16_to_si8_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(int8_t)}, + {"f16_to_si8_4x65", LaunchTCVT_f16_to_si8_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(int8_t)}, + {"f16_to_si8_4x200", LaunchTCVT_f16_to_si8_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(int8_t)}, + {"f16_to_si8_1x129", LaunchTCVT_f16_to_si8_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(int8_t)}, + {"f16_to_ui8_1x128", LaunchTCVT_f16_to_ui8_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(uint8_t)}, + {"f16_to_ui8_2x64", LaunchTCVT_f16_to_ui8_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(uint8_t)}, + {"f16_to_ui8_4x32", LaunchTCVT_f16_to_ui8_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(uint8_t)}, + {"f16_to_ui8_2x128", LaunchTCVT_f16_to_ui8_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(uint8_t)}, + {"f16_to_ui8_4x65", LaunchTCVT_f16_to_ui8_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(uint8_t)}, + {"f16_to_ui8_4x200", LaunchTCVT_f16_to_ui8_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(uint8_t)}, + {"f16_to_ui8_1x129", LaunchTCVT_f16_to_ui8_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(uint8_t)}, + {"bf16_to_f32_1x128", LaunchTCVT_bf16_to_f32_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(float)}, + {"bf16_to_f32_2x64", LaunchTCVT_bf16_to_f32_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(float)}, + {"bf16_to_f32_4x32", LaunchTCVT_bf16_to_f32_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(float)}, + {"bf16_to_f32_2x128", LaunchTCVT_bf16_to_f32_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(float)}, + {"bf16_to_f32_4x65", LaunchTCVT_bf16_to_f32_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(float)}, + {"bf16_to_f32_4x200", LaunchTCVT_bf16_to_f32_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(float)}, + {"bf16_to_f32_1x129", LaunchTCVT_bf16_to_f32_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(float)}, + {"bf16_to_f16_1x128", LaunchTCVT_bf16_to_f16_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(uint16_t)}, + {"bf16_to_f16_2x64", LaunchTCVT_bf16_to_f16_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(uint16_t)}, + {"bf16_to_f16_4x32", LaunchTCVT_bf16_to_f16_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(uint16_t)}, + {"bf16_to_f16_2x128", LaunchTCVT_bf16_to_f16_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(uint16_t)}, + {"bf16_to_f16_4x65", LaunchTCVT_bf16_to_f16_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(uint16_t)}, + {"bf16_to_f16_4x200", LaunchTCVT_bf16_to_f16_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(uint16_t)}, + {"bf16_to_f16_1x129", LaunchTCVT_bf16_to_f16_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(uint16_t)}, + {"bf16_to_i32_1x128", LaunchTCVT_bf16_to_i32_1x128, 1, 128, 1, 128, sizeof(uint16_t), sizeof(int32_t)}, + {"bf16_to_i32_2x64", LaunchTCVT_bf16_to_i32_2x64, 2, 64, 2, 64, sizeof(uint16_t), sizeof(int32_t)}, + {"bf16_to_i32_4x32", LaunchTCVT_bf16_to_i32_4x32, 4, 32, 4, 32, sizeof(uint16_t), sizeof(int32_t)}, + {"bf16_to_i32_2x128", LaunchTCVT_bf16_to_i32_2x128, 2, 128, 2, 128, sizeof(uint16_t), sizeof(int32_t)}, + {"bf16_to_i32_4x65", LaunchTCVT_bf16_to_i32_4x65, 4, 128, 4, 128, sizeof(uint16_t), sizeof(int32_t)}, + {"bf16_to_i32_4x200", LaunchTCVT_bf16_to_i32_4x200, 4, 256, 4, 256, sizeof(uint16_t), sizeof(int32_t)}, + {"bf16_to_i32_1x129", LaunchTCVT_bf16_to_i32_1x129, 1, 256, 1, 256, sizeof(uint16_t), sizeof(int32_t)}, + {"ui8_to_f16_1x128", LaunchTCVT_ui8_to_f16_1x128, 1, 128, 1, 128, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_f16_2x64", LaunchTCVT_ui8_to_f16_2x64, 2, 64, 2, 64, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_f16_4x32", LaunchTCVT_ui8_to_f16_4x32, 4, 32, 4, 32, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_f16_2x128", LaunchTCVT_ui8_to_f16_2x128, 2, 128, 2, 128, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_f16_4x65", LaunchTCVT_ui8_to_f16_4x65, 4, 128, 4, 128, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_f16_4x200", LaunchTCVT_ui8_to_f16_4x200, 4, 256, 4, 256, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_f16_1x129", LaunchTCVT_ui8_to_f16_1x129, 1, 256, 1, 256, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_ui16_1x128", LaunchTCVT_ui8_to_ui16_1x128, 1, 128, 1, 128, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_ui16_2x64", LaunchTCVT_ui8_to_ui16_2x64, 2, 64, 2, 64, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_ui16_4x32", LaunchTCVT_ui8_to_ui16_4x32, 4, 32, 4, 32, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_ui16_2x128", LaunchTCVT_ui8_to_ui16_2x128, 2, 128, 2, 128, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_ui16_4x65", LaunchTCVT_ui8_to_ui16_4x65, 4, 128, 4, 128, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_ui16_4x200", LaunchTCVT_ui8_to_ui16_4x200, 4, 256, 4, 256, sizeof(uint8_t), sizeof(uint16_t)}, + {"ui8_to_ui16_1x129", LaunchTCVT_ui8_to_ui16_1x129, 1, 256, 1, 256, sizeof(uint8_t), sizeof(uint16_t)}, + {"si8_to_f16_1x128", LaunchTCVT_si8_to_f16_1x128, 1, 128, 1, 128, sizeof(int8_t), sizeof(uint16_t)}, + {"si8_to_f16_2x64", LaunchTCVT_si8_to_f16_2x64, 2, 64, 2, 64, sizeof(int8_t), sizeof(uint16_t)}, + {"si8_to_f16_4x32", LaunchTCVT_si8_to_f16_4x32, 4, 32, 4, 32, sizeof(int8_t), sizeof(uint16_t)}, + {"si8_to_f16_2x128", LaunchTCVT_si8_to_f16_2x128, 2, 128, 2, 128, sizeof(int8_t), sizeof(uint16_t)}, + {"si8_to_f16_4x65", LaunchTCVT_si8_to_f16_4x65, 4, 128, 4, 128, sizeof(int8_t), sizeof(uint16_t)}, + {"si8_to_f16_4x200", LaunchTCVT_si8_to_f16_4x200, 4, 256, 4, 256, sizeof(int8_t), sizeof(uint16_t)}, + {"si8_to_f16_1x129", LaunchTCVT_si8_to_f16_1x129, 1, 256, 1, 256, sizeof(int8_t), sizeof(uint16_t)}, + {"si8_to_si16_1x128", LaunchTCVT_si8_to_si16_1x128, 1, 128, 1, 128, sizeof(int8_t), sizeof(int16_t)}, + {"si8_to_si16_2x64", LaunchTCVT_si8_to_si16_2x64, 2, 64, 2, 64, sizeof(int8_t), sizeof(int16_t)}, + {"si8_to_si16_4x32", LaunchTCVT_si8_to_si16_4x32, 4, 32, 4, 32, sizeof(int8_t), sizeof(int16_t)}, + {"si8_to_si16_2x128", LaunchTCVT_si8_to_si16_2x128, 2, 128, 2, 128, sizeof(int8_t), sizeof(int16_t)}, + {"si8_to_si16_4x65", LaunchTCVT_si8_to_si16_4x65, 4, 128, 4, 128, sizeof(int8_t), sizeof(int16_t)}, + {"si8_to_si16_4x200", LaunchTCVT_si8_to_si16_4x200, 4, 256, 4, 256, sizeof(int8_t), sizeof(int16_t)}, + {"si8_to_si16_1x129", LaunchTCVT_si8_to_si16_1x129, 1, 256, 1, 256, sizeof(int8_t), sizeof(int16_t)}, + {"si8_to_i32_1x128", LaunchTCVT_si8_to_i32_1x128, 1, 128, 1, 128, sizeof(int8_t), sizeof(int32_t)}, + {"si8_to_i32_2x64", LaunchTCVT_si8_to_i32_2x64, 2, 64, 2, 64, sizeof(int8_t), sizeof(int32_t)}, + {"si8_to_i32_4x32", LaunchTCVT_si8_to_i32_4x32, 4, 32, 4, 32, sizeof(int8_t), sizeof(int32_t)}, + {"si8_to_i32_2x128", LaunchTCVT_si8_to_i32_2x128, 2, 128, 2, 128, sizeof(int8_t), sizeof(int32_t)}, + {"si8_to_i32_4x65", LaunchTCVT_si8_to_i32_4x65, 4, 128, 4, 128, sizeof(int8_t), sizeof(int32_t)}, + {"si8_to_i32_4x200", LaunchTCVT_si8_to_i32_4x200, 4, 256, 4, 256, sizeof(int8_t), sizeof(int32_t)}, + {"si8_to_i32_1x129", LaunchTCVT_si8_to_i32_1x129, 1, 256, 1, 256, sizeof(int8_t), sizeof(int32_t)}, + {"i16_to_ui8_1x128", LaunchTCVT_i16_to_ui8_1x128, 1, 128, 1, 128, sizeof(int16_t), sizeof(uint8_t)}, + {"i16_to_ui8_2x64", LaunchTCVT_i16_to_ui8_2x64, 2, 64, 2, 64, sizeof(int16_t), sizeof(uint8_t)}, + {"i16_to_ui8_4x32", LaunchTCVT_i16_to_ui8_4x32, 4, 32, 4, 32, sizeof(int16_t), sizeof(uint8_t)}, + {"i16_to_ui8_2x128", LaunchTCVT_i16_to_ui8_2x128, 2, 128, 2, 128, sizeof(int16_t), sizeof(uint8_t)}, + {"i16_to_ui8_4x65", LaunchTCVT_i16_to_ui8_4x65, 4, 128, 4, 128, sizeof(int16_t), sizeof(uint8_t)}, + {"i16_to_ui8_4x200", LaunchTCVT_i16_to_ui8_4x200, 4, 256, 4, 256, sizeof(int16_t), sizeof(uint8_t)}, + {"i16_to_ui8_1x129", LaunchTCVT_i16_to_ui8_1x129, 1, 256, 1, 256, sizeof(int16_t), sizeof(uint8_t)}, + {"i16_to_f16_1x128", LaunchTCVT_i16_to_f16_1x128, 1, 128, 1, 128, sizeof(int16_t), sizeof(uint16_t)}, + {"i16_to_f16_2x64", LaunchTCVT_i16_to_f16_2x64, 2, 64, 2, 64, sizeof(int16_t), sizeof(uint16_t)}, + {"i16_to_f16_4x32", LaunchTCVT_i16_to_f16_4x32, 4, 32, 4, 32, sizeof(int16_t), sizeof(uint16_t)}, + {"i16_to_f16_2x128", LaunchTCVT_i16_to_f16_2x128, 2, 128, 2, 128, sizeof(int16_t), sizeof(uint16_t)}, + {"i16_to_f16_4x65", LaunchTCVT_i16_to_f16_4x65, 4, 128, 4, 128, sizeof(int16_t), sizeof(uint16_t)}, + {"i16_to_f16_4x200", LaunchTCVT_i16_to_f16_4x200, 4, 256, 4, 256, sizeof(int16_t), sizeof(uint16_t)}, + {"i16_to_f16_1x129", LaunchTCVT_i16_to_f16_1x129, 1, 256, 1, 256, sizeof(int16_t), sizeof(uint16_t)}, + {"i16_to_f32_1x128", LaunchTCVT_i16_to_f32_1x128, 1, 128, 1, 128, sizeof(int16_t), sizeof(float)}, + {"i16_to_f32_2x64", LaunchTCVT_i16_to_f32_2x64, 2, 64, 2, 64, sizeof(int16_t), sizeof(float)}, + {"i16_to_f32_4x32", LaunchTCVT_i16_to_f32_4x32, 4, 32, 4, 32, sizeof(int16_t), sizeof(float)}, + {"i16_to_f32_2x128", LaunchTCVT_i16_to_f32_2x128, 2, 128, 2, 128, sizeof(int16_t), sizeof(float)}, + {"i16_to_f32_4x65", LaunchTCVT_i16_to_f32_4x65, 4, 128, 4, 128, sizeof(int16_t), sizeof(float)}, + {"i16_to_f32_4x200", LaunchTCVT_i16_to_f32_4x200, 4, 256, 4, 256, sizeof(int16_t), sizeof(float)}, + {"i16_to_f32_1x129", LaunchTCVT_i16_to_f32_1x129, 1, 256, 1, 256, sizeof(int16_t), sizeof(float)}, + {"i16_to_ui32_1x128", LaunchTCVT_i16_to_ui32_1x128, 1, 128, 1, 128, sizeof(int16_t), sizeof(uint32_t)}, + {"i16_to_ui32_2x64", LaunchTCVT_i16_to_ui32_2x64, 2, 64, 2, 64, sizeof(int16_t), sizeof(uint32_t)}, + {"i16_to_ui32_4x32", LaunchTCVT_i16_to_ui32_4x32, 4, 32, 4, 32, sizeof(int16_t), sizeof(uint32_t)}, + {"i16_to_ui32_2x128", LaunchTCVT_i16_to_ui32_2x128, 2, 128, 2, 128, sizeof(int16_t), sizeof(uint32_t)}, + {"i16_to_ui32_4x65", LaunchTCVT_i16_to_ui32_4x65, 4, 128, 4, 128, sizeof(int16_t), sizeof(uint32_t)}, + {"i16_to_ui32_4x200", LaunchTCVT_i16_to_ui32_4x200, 4, 256, 4, 256, sizeof(int16_t), sizeof(uint32_t)}, + {"i16_to_ui32_1x129", LaunchTCVT_i16_to_ui32_1x129, 1, 256, 1, 256, sizeof(int16_t), sizeof(uint32_t)}, + {"i16_to_i32_1x128", LaunchTCVT_i16_to_i32_1x128, 1, 128, 1, 128, sizeof(int16_t), sizeof(int32_t)}, + {"i16_to_i32_2x64", LaunchTCVT_i16_to_i32_2x64, 2, 64, 2, 64, sizeof(int16_t), sizeof(int32_t)}, + {"i16_to_i32_4x32", LaunchTCVT_i16_to_i32_4x32, 4, 32, 4, 32, sizeof(int16_t), sizeof(int32_t)}, + {"i16_to_i32_2x128", LaunchTCVT_i16_to_i32_2x128, 2, 128, 2, 128, sizeof(int16_t), sizeof(int32_t)}, + {"i16_to_i32_4x65", LaunchTCVT_i16_to_i32_4x65, 4, 128, 4, 128, sizeof(int16_t), sizeof(int32_t)}, + {"i16_to_i32_4x200", LaunchTCVT_i16_to_i32_4x200, 4, 256, 4, 256, sizeof(int16_t), sizeof(int32_t)}, + {"i16_to_i32_1x129", LaunchTCVT_i16_to_i32_1x129, 1, 256, 1, 256, sizeof(int16_t), sizeof(int32_t)}, + {"i32_to_f32_1x128", LaunchTCVT_i32_to_f32_1x128, 1, 128, 1, 128, sizeof(int32_t), sizeof(float)}, + {"i32_to_f32_2x64", LaunchTCVT_i32_to_f32_2x64, 2, 64, 2, 64, sizeof(int32_t), sizeof(float)}, + {"i32_to_f32_4x32", LaunchTCVT_i32_to_f32_4x32, 4, 32, 4, 32, sizeof(int32_t), sizeof(float)}, + {"i32_to_f32_2x128", LaunchTCVT_i32_to_f32_2x128, 2, 128, 2, 128, sizeof(int32_t), sizeof(float)}, + {"i32_to_f32_4x65", LaunchTCVT_i32_to_f32_4x65, 4, 128, 4, 128, sizeof(int32_t), sizeof(float)}, + {"i32_to_f32_4x200", LaunchTCVT_i32_to_f32_4x200, 4, 256, 4, 256, sizeof(int32_t), sizeof(float)}, + {"i32_to_f32_1x129", LaunchTCVT_i32_to_f32_1x129, 1, 256, 1, 256, sizeof(int32_t), sizeof(float)}, + {"i32_to_i16_1x128", LaunchTCVT_i32_to_i16_1x128, 1, 128, 1, 128, sizeof(int32_t), sizeof(int16_t)}, + {"i32_to_i16_2x64", LaunchTCVT_i32_to_i16_2x64, 2, 64, 2, 64, sizeof(int32_t), sizeof(int16_t)}, + {"i32_to_i16_4x32", LaunchTCVT_i32_to_i16_4x32, 4, 32, 4, 32, sizeof(int32_t), sizeof(int16_t)}, + {"i32_to_i16_2x128", LaunchTCVT_i32_to_i16_2x128, 2, 128, 2, 128, sizeof(int32_t), sizeof(int16_t)}, + {"i32_to_i16_4x65", LaunchTCVT_i32_to_i16_4x65, 4, 128, 4, 128, sizeof(int32_t), sizeof(int16_t)}, + {"i32_to_i16_4x200", LaunchTCVT_i32_to_i16_4x200, 4, 256, 4, 256, sizeof(int32_t), sizeof(int16_t)}, + {"i32_to_i16_1x129", LaunchTCVT_i32_to_i16_1x129, 1, 256, 1, 256, sizeof(int32_t), sizeof(int16_t)}, + {"i32_to_i64_1x128", LaunchTCVT_i32_to_i64_1x128, 1, 128, 1, 128, sizeof(int32_t), sizeof(int64_t)}, + {"i32_to_i64_2x64", LaunchTCVT_i32_to_i64_2x64, 2, 64, 2, 64, sizeof(int32_t), sizeof(int64_t)}, + {"i32_to_i64_4x32", LaunchTCVT_i32_to_i64_4x32, 4, 32, 4, 32, sizeof(int32_t), sizeof(int64_t)}, + {"i32_to_i64_2x128", LaunchTCVT_i32_to_i64_2x128, 2, 128, 2, 128, sizeof(int32_t), sizeof(int64_t)}, + {"i32_to_i64_4x65", LaunchTCVT_i32_to_i64_4x65, 4, 128, 4, 128, sizeof(int32_t), sizeof(int64_t)}, + {"i32_to_i64_4x200", LaunchTCVT_i32_to_i64_4x200, 4, 256, 4, 256, sizeof(int32_t), sizeof(int64_t)}, + {"i32_to_i64_1x129", LaunchTCVT_i32_to_i64_1x129, 1, 256, 1, 256, sizeof(int32_t), sizeof(int64_t)}, + {"i32_to_ui8_1x128", LaunchTCVT_i32_to_ui8_1x128, 1, 128, 1, 128, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_to_ui8_2x64", LaunchTCVT_i32_to_ui8_2x64, 2, 64, 2, 64, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_to_ui8_4x32", LaunchTCVT_i32_to_ui8_4x32, 4, 32, 4, 32, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_to_ui8_2x128", LaunchTCVT_i32_to_ui8_2x128, 2, 128, 2, 128, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_to_ui8_4x65", LaunchTCVT_i32_to_ui8_4x65, 4, 128, 4, 128, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_to_ui8_4x200", LaunchTCVT_i32_to_ui8_4x200, 4, 256, 4, 256, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_to_ui8_1x129", LaunchTCVT_i32_to_ui8_1x129, 1, 256, 1, 256, sizeof(int32_t), sizeof(uint8_t)}, + {"i32_to_ui16_1x128", LaunchTCVT_i32_to_ui16_1x128, 1, 128, 1, 128, sizeof(int32_t), sizeof(uint16_t)}, + {"i32_to_ui16_2x64", LaunchTCVT_i32_to_ui16_2x64, 2, 64, 2, 64, sizeof(int32_t), sizeof(uint16_t)}, + {"i32_to_ui16_4x32", LaunchTCVT_i32_to_ui16_4x32, 4, 32, 4, 32, sizeof(int32_t), sizeof(uint16_t)}, + {"i32_to_ui16_2x128", LaunchTCVT_i32_to_ui16_2x128, 2, 128, 2, 128, sizeof(int32_t), sizeof(uint16_t)}, + {"i32_to_ui16_4x65", LaunchTCVT_i32_to_ui16_4x65, 4, 128, 4, 128, sizeof(int32_t), sizeof(uint16_t)}, + {"i32_to_ui16_4x200", LaunchTCVT_i32_to_ui16_4x200, 4, 256, 4, 256, sizeof(int32_t), sizeof(uint16_t)}, + {"i32_to_ui16_1x129", LaunchTCVT_i32_to_ui16_1x129, 1, 256, 1, 256, sizeof(int32_t), sizeof(uint16_t)}, + {"ui32_to_i16_1x128", LaunchTCVT_ui32_to_i16_1x128, 1, 128, 1, 128, sizeof(uint32_t), sizeof(int16_t)}, + {"ui32_to_i16_2x64", LaunchTCVT_ui32_to_i16_2x64, 2, 64, 2, 64, sizeof(uint32_t), sizeof(int16_t)}, + {"ui32_to_i16_4x32", LaunchTCVT_ui32_to_i16_4x32, 4, 32, 4, 32, sizeof(uint32_t), sizeof(int16_t)}, + {"ui32_to_i16_2x128", LaunchTCVT_ui32_to_i16_2x128, 2, 128, 2, 128, sizeof(uint32_t), sizeof(int16_t)}, + {"ui32_to_i16_4x65", LaunchTCVT_ui32_to_i16_4x65, 4, 128, 4, 128, sizeof(uint32_t), sizeof(int16_t)}, + {"ui32_to_i16_4x200", LaunchTCVT_ui32_to_i16_4x200, 4, 256, 4, 256, sizeof(uint32_t), sizeof(int16_t)}, + {"ui32_to_i16_1x129", LaunchTCVT_ui32_to_i16_1x129, 1, 256, 1, 256, sizeof(uint32_t), sizeof(int16_t)}, + {"ui32_to_ui16_1x128", LaunchTCVT_ui32_to_ui16_1x128, 1, 128, 1, 128, sizeof(uint32_t), sizeof(uint16_t)}, + {"ui32_to_ui16_2x64", LaunchTCVT_ui32_to_ui16_2x64, 2, 64, 2, 64, sizeof(uint32_t), sizeof(uint16_t)}, + {"ui32_to_ui16_4x32", LaunchTCVT_ui32_to_ui16_4x32, 4, 32, 4, 32, sizeof(uint32_t), sizeof(uint16_t)}, + {"ui32_to_ui16_2x128", LaunchTCVT_ui32_to_ui16_2x128, 2, 128, 2, 128, sizeof(uint32_t), sizeof(uint16_t)}, + {"ui32_to_ui16_4x65", LaunchTCVT_ui32_to_ui16_4x65, 4, 128, 4, 128, sizeof(uint32_t), sizeof(uint16_t)}, + {"ui32_to_ui16_4x200", LaunchTCVT_ui32_to_ui16_4x200, 4, 256, 4, 256, sizeof(uint32_t), sizeof(uint16_t)}, + {"ui32_to_ui16_1x129", LaunchTCVT_ui32_to_ui16_1x129, 1, 256, 1, 256, sizeof(uint32_t), sizeof(uint16_t)}, + {"ui32_to_ui8_1x128", LaunchTCVT_ui32_to_ui8_1x128, 1, 128, 1, 128, sizeof(uint32_t), sizeof(uint8_t)}, + {"ui32_to_ui8_2x64", LaunchTCVT_ui32_to_ui8_2x64, 2, 64, 2, 64, sizeof(uint32_t), sizeof(uint8_t)}, + {"ui32_to_ui8_4x32", LaunchTCVT_ui32_to_ui8_4x32, 4, 32, 4, 32, sizeof(uint32_t), sizeof(uint8_t)}, + {"ui32_to_ui8_2x128", LaunchTCVT_ui32_to_ui8_2x128, 2, 128, 2, 128, sizeof(uint32_t), sizeof(uint8_t)}, + {"ui32_to_ui8_4x65", LaunchTCVT_ui32_to_ui8_4x65, 4, 128, 4, 128, sizeof(uint32_t), sizeof(uint8_t)}, + {"ui32_to_ui8_4x200", LaunchTCVT_ui32_to_ui8_4x200, 4, 256, 4, 256, sizeof(uint32_t), sizeof(uint8_t)}, + {"ui32_to_ui8_1x129", LaunchTCVT_ui32_to_ui8_1x129, 1, 256, 1, 256, sizeof(uint32_t), sizeof(uint8_t)}, + {"i64_to_f32_1x128", LaunchTCVT_i64_to_f32_1x128, 1, 128, 1, 128, sizeof(int64_t), sizeof(float)}, + {"i64_to_f32_2x64", LaunchTCVT_i64_to_f32_2x64, 2, 64, 2, 64, sizeof(int64_t), sizeof(float)}, + {"i64_to_f32_4x32", LaunchTCVT_i64_to_f32_4x32, 4, 32, 4, 32, sizeof(int64_t), sizeof(float)}, + {"i64_to_f32_2x128", LaunchTCVT_i64_to_f32_2x128, 2, 128, 2, 128, sizeof(int64_t), sizeof(float)}, + {"i64_to_f32_4x65", LaunchTCVT_i64_to_f32_4x65, 4, 128, 4, 128, sizeof(int64_t), sizeof(float)}, + {"i64_to_f32_4x200", LaunchTCVT_i64_to_f32_4x200, 4, 256, 4, 256, sizeof(int64_t), sizeof(float)}, + {"i64_to_f32_1x129", LaunchTCVT_i64_to_f32_1x129, 1, 256, 1, 256, sizeof(int64_t), sizeof(float)}, + {"i64_to_i32_1x128", LaunchTCVT_i64_to_i32_1x128, 1, 128, 1, 128, sizeof(int64_t), sizeof(int32_t)}, + {"i64_to_i32_2x64", LaunchTCVT_i64_to_i32_2x64, 2, 64, 2, 64, sizeof(int64_t), sizeof(int32_t)}, + {"i64_to_i32_4x32", LaunchTCVT_i64_to_i32_4x32, 4, 32, 4, 32, sizeof(int64_t), sizeof(int32_t)}, + {"i64_to_i32_2x128", LaunchTCVT_i64_to_i32_2x128, 2, 128, 2, 128, sizeof(int64_t), sizeof(int32_t)}, + {"i64_to_i32_4x65", LaunchTCVT_i64_to_i32_4x65, 4, 128, 4, 128, sizeof(int64_t), sizeof(int32_t)}, + {"i64_to_i32_4x200", LaunchTCVT_i64_to_i32_4x200, 4, 256, 4, 256, sizeof(int64_t), sizeof(int32_t)}, + {"i64_to_i32_1x129", LaunchTCVT_i64_to_i32_1x129, 1, 256, 1, 256, sizeof(int64_t), sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t srcElemCount = tc.srcRows * tc.srcCols; + const size_t dstElemCount = tc.dstRows * tc.dstCols; + size_t srcFileSize = srcElemCount * tc.srcElemSize; + size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr; + void *dstHost = nullptr; + void *srcDevice = nullptr; + void *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(srcDevice, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto new file mode 100644 index 000000000..a807cf284 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tcvt/tcvt.pto @@ -0,0 +1,10187 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tcvt. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. +// Generated by gen_tcvt_pto.py from cases.py. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: f32 -> i32, explicit ROUND + func.func @TCVT_f32_to_i32_round_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src {rmode = #pto} : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 2: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 3: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 4: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_rint_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 5: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case 6: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + return + } + + // Case 7: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + return + } + + // Case 8: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + return + } + + // Case 9: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + return + } + + // Case 10: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + return + } + + // Case 11: f32 -> f16, default RINT + func.func @TCVT_f32_to_f16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + return + } + + // Case 12: f32 -> bf16, default RINT + func.func @TCVT_f32_to_bf16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xbf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xbf16> -> !pto.partition_tensor_view<1x1x1x1x128xbf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xbf16>) + return + } + + // Case 13: f32 -> bf16, default RINT + func.func @TCVT_f32_to_bf16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xbf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xbf16> -> !pto.partition_tensor_view<1x1x1x2x64xbf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xbf16>) + return + } + + // Case 14: f32 -> bf16, default RINT + func.func @TCVT_f32_to_bf16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xbf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xbf16> -> !pto.partition_tensor_view<1x1x1x4x32xbf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xbf16>) + return + } + + // Case 15: f32 -> bf16, default RINT + func.func @TCVT_f32_to_bf16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xbf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xbf16> -> !pto.partition_tensor_view<1x1x1x2x128xbf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xbf16>) + return + } + + // Case 16: f32 -> bf16, default RINT + func.func @TCVT_f32_to_bf16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xbf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xbf16> -> !pto.partition_tensor_view<1x1x1x4x65xbf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xbf16>) + return + } + + // Case 17: f32 -> bf16, default RINT + func.func @TCVT_f32_to_bf16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xbf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xbf16> -> !pto.partition_tensor_view<1x1x1x4x200xbf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xbf16>) + return + } + + // Case 18: f32 -> bf16, default RINT + func.func @TCVT_f32_to_bf16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xbf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xbf16> -> !pto.partition_tensor_view<1x1x1x1x129xbf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xbf16>) + return + } + + // Case 19: f32 -> i16, default RINT + func.func @TCVT_f32_to_i16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + return + } + + // Case 20: f32 -> i16, default RINT + func.func @TCVT_f32_to_i16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + return + } + + // Case 21: f32 -> i16, default RINT + func.func @TCVT_f32_to_i16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + return + } + + // Case 22: f32 -> i16, default RINT + func.func @TCVT_f32_to_i16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + return + } + + // Case 23: f32 -> i16, default RINT + func.func @TCVT_f32_to_i16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + return + } + + // Case 24: f32 -> i16, default RINT + func.func @TCVT_f32_to_i16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + return + } + + // Case 25: f32 -> i16, default RINT + func.func @TCVT_f32_to_i16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + return + } + + // Case 26: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + return + } + + // Case 27: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + return + } + + // Case 28: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + return + } + + // Case 29: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + return + } + + // Case 30: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + return + } + + // Case 31: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + return + } + + // Case 32: f32 -> i32, default RINT + func.func @TCVT_f32_to_i32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + return + } + + // Case 33: f32 -> i64, default RINT + func.func @TCVT_f32_to_i64_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi64> -> !pto.partition_tensor_view<1x1x1x1x128xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi64>) + return + } + + // Case 34: f32 -> i64, default RINT + func.func @TCVT_f32_to_i64_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi64> -> !pto.partition_tensor_view<1x1x1x2x64xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi64>) + return + } + + // Case 35: f32 -> i64, default RINT + func.func @TCVT_f32_to_i64_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi64> -> !pto.partition_tensor_view<1x1x1x4x32xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi64>) + return + } + + // Case 36: f32 -> i64, default RINT + func.func @TCVT_f32_to_i64_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi64> -> !pto.partition_tensor_view<1x1x1x2x128xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi64>) + return + } + + // Case 37: f32 -> i64, default RINT + func.func @TCVT_f32_to_i64_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi64> -> !pto.partition_tensor_view<1x1x1x4x65xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi64>) + return + } + + // Case 38: f32 -> i64, default RINT + func.func @TCVT_f32_to_i64_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi64> -> !pto.partition_tensor_view<1x1x1x4x200xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi64>) + return + } + + // Case 39: f32 -> i64, default RINT + func.func @TCVT_f32_to_i64_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi64> -> !pto.partition_tensor_view<1x1x1x1x129xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi64>) + return + } + + // Case 40: f32 -> f32, default RINT + func.func @TCVT_f32_to_f32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 41: f32 -> f32, default RINT + func.func @TCVT_f32_to_f32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + return + } + + // Case 42: f32 -> f32, default RINT + func.func @TCVT_f32_to_f32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + return + } + + // Case 43: f32 -> f32, default RINT + func.func @TCVT_f32_to_f32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case 44: f32 -> f32, default RINT + func.func @TCVT_f32_to_f32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + return + } + + // Case 45: f32 -> f32, default RINT + func.func @TCVT_f32_to_f32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + return + } + + // Case 46: f32 -> f32, default RINT + func.func @TCVT_f32_to_f32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + return + } + + // Case 47: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 48: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + return + } + + // Case 49: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + return + } + + // Case 50: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case 51: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + return + } + + // Case 52: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + return + } + + // Case 53: f16 -> f32, default RINT + func.func @TCVT_f16_to_f32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + return + } + + // Case 54: f16 -> i32, default RINT + func.func @TCVT_f16_to_i32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + return + } + + // Case 55: f16 -> i32, default RINT + func.func @TCVT_f16_to_i32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + return + } + + // Case 56: f16 -> i32, default RINT + func.func @TCVT_f16_to_i32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + return + } + + // Case 57: f16 -> i32, default RINT + func.func @TCVT_f16_to_i32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + return + } + + // Case 58: f16 -> i32, default RINT + func.func @TCVT_f16_to_i32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + return + } + + // Case 59: f16 -> i32, default RINT + func.func @TCVT_f16_to_i32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + return + } + + // Case 60: f16 -> i32, default RINT + func.func @TCVT_f16_to_i32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + return + } + + // Case 61: f16 -> i16, default RINT + func.func @TCVT_f16_to_i16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + return + } + + // Case 62: f16 -> i16, default RINT + func.func @TCVT_f16_to_i16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + return + } + + // Case 63: f16 -> i16, default RINT + func.func @TCVT_f16_to_i16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + return + } + + // Case 64: f16 -> i16, default RINT + func.func @TCVT_f16_to_i16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + return + } + + // Case 65: f16 -> i16, default RINT + func.func @TCVT_f16_to_i16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + return + } + + // Case 66: f16 -> i16, default RINT + func.func @TCVT_f16_to_i16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + return + } + + // Case 67: f16 -> i16, default RINT + func.func @TCVT_f16_to_i16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + return + } + + // Case 68: f16 -> si8, default RINT + func.func @TCVT_f16_to_si8_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xsi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xsi8> -> !pto.partition_tensor_view<1x1x1x1x128xsi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xsi8>) + return + } + + // Case 69: f16 -> si8, default RINT + func.func @TCVT_f16_to_si8_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xsi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xsi8> -> !pto.partition_tensor_view<1x1x1x2x64xsi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xsi8>) + return + } + + // Case 70: f16 -> si8, default RINT + func.func @TCVT_f16_to_si8_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xsi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xsi8> -> !pto.partition_tensor_view<1x1x1x4x32xsi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xsi8>) + return + } + + // Case 71: f16 -> si8, default RINT + func.func @TCVT_f16_to_si8_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xsi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xsi8> -> !pto.partition_tensor_view<1x1x1x2x128xsi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xsi8>) + return + } + + // Case 72: f16 -> si8, default RINT + func.func @TCVT_f16_to_si8_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xsi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xsi8> -> !pto.partition_tensor_view<1x1x1x4x65xsi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xsi8>) + return + } + + // Case 73: f16 -> si8, default RINT + func.func @TCVT_f16_to_si8_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xsi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xsi8> -> !pto.partition_tensor_view<1x1x1x4x200xsi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xsi8>) + return + } + + // Case 74: f16 -> si8, default RINT + func.func @TCVT_f16_to_si8_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xsi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xsi8> -> !pto.partition_tensor_view<1x1x1x1x129xsi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xsi8>) + return + } + + // Case 75: f16 -> ui8, default RINT + func.func @TCVT_f16_to_ui8_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui8>) + return + } + + // Case 76: f16 -> ui8, default RINT + func.func @TCVT_f16_to_ui8_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui8> -> !pto.partition_tensor_view<1x1x1x2x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui8>) + return + } + + // Case 77: f16 -> ui8, default RINT + func.func @TCVT_f16_to_ui8_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui8> -> !pto.partition_tensor_view<1x1x1x4x32xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui8>) + return + } + + // Case 78: f16 -> ui8, default RINT + func.func @TCVT_f16_to_ui8_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui8> -> !pto.partition_tensor_view<1x1x1x2x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui8>) + return + } + + // Case 79: f16 -> ui8, default RINT + func.func @TCVT_f16_to_ui8_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui8> -> !pto.partition_tensor_view<1x1x1x4x65xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui8>) + return + } + + // Case 80: f16 -> ui8, default RINT + func.func @TCVT_f16_to_ui8_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui8> -> !pto.partition_tensor_view<1x1x1x4x200xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui8>) + return + } + + // Case 81: f16 -> ui8, default RINT + func.func @TCVT_f16_to_ui8_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x129xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui8>) + return + } + + // Case 82: bf16 -> f32, default RINT + func.func @TCVT_bf16_to_f32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xbf16> -> !pto.partition_tensor_view<1x1x1x1x128xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 83: bf16 -> f32, default RINT + func.func @TCVT_bf16_to_f32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xbf16> -> !pto.partition_tensor_view<1x1x1x2x64xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + return + } + + // Case 84: bf16 -> f32, default RINT + func.func @TCVT_bf16_to_f32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xbf16> -> !pto.partition_tensor_view<1x1x1x4x32xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + return + } + + // Case 85: bf16 -> f32, default RINT + func.func @TCVT_bf16_to_f32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xbf16> -> !pto.partition_tensor_view<1x1x1x2x128xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case 86: bf16 -> f32, default RINT + func.func @TCVT_bf16_to_f32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xbf16> -> !pto.partition_tensor_view<1x1x1x4x65xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + return + } + + // Case 87: bf16 -> f32, default RINT + func.func @TCVT_bf16_to_f32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xbf16> -> !pto.partition_tensor_view<1x1x1x4x200xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + return + } + + // Case 88: bf16 -> f32, default RINT + func.func @TCVT_bf16_to_f32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xbf16> -> !pto.partition_tensor_view<1x1x1x1x129xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + return + } + + // Case 89: bf16 -> f16, default RINT + func.func @TCVT_bf16_to_f16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xbf16> -> !pto.partition_tensor_view<1x1x1x1x128xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case 90: bf16 -> f16, default RINT + func.func @TCVT_bf16_to_f16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xbf16> -> !pto.partition_tensor_view<1x1x1x2x64xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + return + } + + // Case 91: bf16 -> f16, default RINT + func.func @TCVT_bf16_to_f16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xbf16> -> !pto.partition_tensor_view<1x1x1x4x32xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + return + } + + // Case 92: bf16 -> f16, default RINT + func.func @TCVT_bf16_to_f16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xbf16> -> !pto.partition_tensor_view<1x1x1x2x128xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + return + } + + // Case 93: bf16 -> f16, default RINT + func.func @TCVT_bf16_to_f16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xbf16> -> !pto.partition_tensor_view<1x1x1x4x65xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + return + } + + // Case 94: bf16 -> f16, default RINT + func.func @TCVT_bf16_to_f16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xbf16> -> !pto.partition_tensor_view<1x1x1x4x200xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + return + } + + // Case 95: bf16 -> f16, default RINT + func.func @TCVT_bf16_to_f16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xbf16> -> !pto.partition_tensor_view<1x1x1x1x129xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + return + } + + // Case 96: bf16 -> i32, default RINT + func.func @TCVT_bf16_to_i32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xbf16> -> !pto.partition_tensor_view<1x1x1x1x128xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + return + } + + // Case 97: bf16 -> i32, default RINT + func.func @TCVT_bf16_to_i32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xbf16> -> !pto.partition_tensor_view<1x1x1x2x64xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + return + } + + // Case 98: bf16 -> i32, default RINT + func.func @TCVT_bf16_to_i32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xbf16> -> !pto.partition_tensor_view<1x1x1x4x32xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + return + } + + // Case 99: bf16 -> i32, default RINT + func.func @TCVT_bf16_to_i32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xbf16> -> !pto.partition_tensor_view<1x1x1x2x128xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + return + } + + // Case 100: bf16 -> i32, default RINT + func.func @TCVT_bf16_to_i32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xbf16> -> !pto.partition_tensor_view<1x1x1x4x65xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + return + } + + // Case 101: bf16 -> i32, default RINT + func.func @TCVT_bf16_to_i32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xbf16> -> !pto.partition_tensor_view<1x1x1x4x200xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + return + } + + // Case 102: bf16 -> i32, default RINT + func.func @TCVT_bf16_to_i32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xbf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xbf16> -> !pto.partition_tensor_view<1x1x1x1x129xbf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xbf16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + return + } + + // Case 103: ui8 -> f16, default RINT + func.func @TCVT_ui8_to_f16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x128xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case 104: ui8 -> f16, default RINT + func.func @TCVT_ui8_to_f16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui8> -> !pto.partition_tensor_view<1x1x1x2x64xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + return + } + + // Case 105: ui8 -> f16, default RINT + func.func @TCVT_ui8_to_f16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui8> -> !pto.partition_tensor_view<1x1x1x4x32xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + return + } + + // Case 106: ui8 -> f16, default RINT + func.func @TCVT_ui8_to_f16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui8> -> !pto.partition_tensor_view<1x1x1x2x128xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + return + } + + // Case 107: ui8 -> f16, default RINT + func.func @TCVT_ui8_to_f16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui8> -> !pto.partition_tensor_view<1x1x1x4x65xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + return + } + + // Case 108: ui8 -> f16, default RINT + func.func @TCVT_ui8_to_f16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui8> -> !pto.partition_tensor_view<1x1x1x4x200xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + return + } + + // Case 109: ui8 -> f16, default RINT + func.func @TCVT_ui8_to_f16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x129xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + return + } + + // Case 110: ui8 -> ui16, default RINT + func.func @TCVT_ui8_to_ui16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x128xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x128xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui16>) + return + } + + // Case 111: ui8 -> ui16, default RINT + func.func @TCVT_ui8_to_ui16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui8> -> !pto.partition_tensor_view<1x1x1x2x64xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui16> -> !pto.partition_tensor_view<1x1x1x2x64xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui16>) + return + } + + // Case 112: ui8 -> ui16, default RINT + func.func @TCVT_ui8_to_ui16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui8> -> !pto.partition_tensor_view<1x1x1x4x32xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui16> -> !pto.partition_tensor_view<1x1x1x4x32xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui16>) + return + } + + // Case 113: ui8 -> ui16, default RINT + func.func @TCVT_ui8_to_ui16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui8> -> !pto.partition_tensor_view<1x1x1x2x128xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui16> -> !pto.partition_tensor_view<1x1x1x2x128xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui16>) + return + } + + // Case 114: ui8 -> ui16, default RINT + func.func @TCVT_ui8_to_ui16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui8> -> !pto.partition_tensor_view<1x1x1x4x65xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui16> -> !pto.partition_tensor_view<1x1x1x4x65xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui16>) + return + } + + // Case 115: ui8 -> ui16, default RINT + func.func @TCVT_ui8_to_ui16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui8> -> !pto.partition_tensor_view<1x1x1x4x200xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui16> -> !pto.partition_tensor_view<1x1x1x4x200xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui16>) + return + } + + // Case 116: ui8 -> ui16, default RINT + func.func @TCVT_ui8_to_ui16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x129xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x129xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xui8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui16>) + return + } + + // Case 117: si8 -> f16, default RINT + func.func @TCVT_si8_to_f16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xsi8> -> !pto.partition_tensor_view<1x1x1x1x128xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case 118: si8 -> f16, default RINT + func.func @TCVT_si8_to_f16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xsi8> -> !pto.partition_tensor_view<1x1x1x2x64xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + return + } + + // Case 119: si8 -> f16, default RINT + func.func @TCVT_si8_to_f16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xsi8> -> !pto.partition_tensor_view<1x1x1x4x32xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + return + } + + // Case 120: si8 -> f16, default RINT + func.func @TCVT_si8_to_f16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xsi8> -> !pto.partition_tensor_view<1x1x1x2x128xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + return + } + + // Case 121: si8 -> f16, default RINT + func.func @TCVT_si8_to_f16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xsi8> -> !pto.partition_tensor_view<1x1x1x4x65xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + return + } + + // Case 122: si8 -> f16, default RINT + func.func @TCVT_si8_to_f16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xsi8> -> !pto.partition_tensor_view<1x1x1x4x200xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + return + } + + // Case 123: si8 -> f16, default RINT + func.func @TCVT_si8_to_f16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xsi8> -> !pto.partition_tensor_view<1x1x1x1x129xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + return + } + + // Case 124: si8 -> si16, default RINT + func.func @TCVT_si8_to_si16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xsi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xsi8> -> !pto.partition_tensor_view<1x1x1x1x128xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xsi16> -> !pto.partition_tensor_view<1x1x1x1x128xsi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xsi16>) + return + } + + // Case 125: si8 -> si16, default RINT + func.func @TCVT_si8_to_si16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xsi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xsi8> -> !pto.partition_tensor_view<1x1x1x2x64xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xsi16> -> !pto.partition_tensor_view<1x1x1x2x64xsi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xsi16>) + return + } + + // Case 126: si8 -> si16, default RINT + func.func @TCVT_si8_to_si16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xsi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xsi8> -> !pto.partition_tensor_view<1x1x1x4x32xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xsi16> -> !pto.partition_tensor_view<1x1x1x4x32xsi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xsi16>) + return + } + + // Case 127: si8 -> si16, default RINT + func.func @TCVT_si8_to_si16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xsi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xsi8> -> !pto.partition_tensor_view<1x1x1x2x128xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xsi16> -> !pto.partition_tensor_view<1x1x1x2x128xsi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xsi16>) + return + } + + // Case 128: si8 -> si16, default RINT + func.func @TCVT_si8_to_si16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xsi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xsi8> -> !pto.partition_tensor_view<1x1x1x4x65xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xsi16> -> !pto.partition_tensor_view<1x1x1x4x65xsi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xsi16>) + return + } + + // Case 129: si8 -> si16, default RINT + func.func @TCVT_si8_to_si16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xsi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xsi8> -> !pto.partition_tensor_view<1x1x1x4x200xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xsi16> -> !pto.partition_tensor_view<1x1x1x4x200xsi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xsi16>) + return + } + + // Case 130: si8 -> si16, default RINT + func.func @TCVT_si8_to_si16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xsi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xsi8> -> !pto.partition_tensor_view<1x1x1x1x129xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xsi16> -> !pto.partition_tensor_view<1x1x1x1x129xsi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xsi16>) + return + } + + // Case 131: si8 -> i32, default RINT + func.func @TCVT_si8_to_i32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xsi8> -> !pto.partition_tensor_view<1x1x1x1x128xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + return + } + + // Case 132: si8 -> i32, default RINT + func.func @TCVT_si8_to_i32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xsi8> -> !pto.partition_tensor_view<1x1x1x2x64xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + return + } + + // Case 133: si8 -> i32, default RINT + func.func @TCVT_si8_to_i32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xsi8> -> !pto.partition_tensor_view<1x1x1x4x32xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + return + } + + // Case 134: si8 -> i32, default RINT + func.func @TCVT_si8_to_i32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xsi8> -> !pto.partition_tensor_view<1x1x1x2x128xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + return + } + + // Case 135: si8 -> i32, default RINT + func.func @TCVT_si8_to_i32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xsi8> -> !pto.partition_tensor_view<1x1x1x4x65xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + return + } + + // Case 136: si8 -> i32, default RINT + func.func @TCVT_si8_to_i32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xsi8> -> !pto.partition_tensor_view<1x1x1x4x200xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + return + } + + // Case 137: si8 -> i32, default RINT + func.func @TCVT_si8_to_i32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xsi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xsi8> -> !pto.partition_tensor_view<1x1x1x1x129xsi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xsi8>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + return + } + + // Case 138: i16 -> ui8, default RINT + func.func @TCVT_i16_to_ui8_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui8>) + return + } + + // Case 139: i16 -> ui8, default RINT + func.func @TCVT_i16_to_ui8_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui8> -> !pto.partition_tensor_view<1x1x1x2x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui8>) + return + } + + // Case 140: i16 -> ui8, default RINT + func.func @TCVT_i16_to_ui8_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui8> -> !pto.partition_tensor_view<1x1x1x4x32xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui8>) + return + } + + // Case 141: i16 -> ui8, default RINT + func.func @TCVT_i16_to_ui8_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui8> -> !pto.partition_tensor_view<1x1x1x2x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui8>) + return + } + + // Case 142: i16 -> ui8, default RINT + func.func @TCVT_i16_to_ui8_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui8> -> !pto.partition_tensor_view<1x1x1x4x65xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui8>) + return + } + + // Case 143: i16 -> ui8, default RINT + func.func @TCVT_i16_to_ui8_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui8> -> !pto.partition_tensor_view<1x1x1x4x200xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui8>) + return + } + + // Case 144: i16 -> ui8, default RINT + func.func @TCVT_i16_to_ui8_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x129xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui8>) + return + } + + // Case 145: i16 -> f16, default RINT + func.func @TCVT_i16_to_f16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case 146: i16 -> f16, default RINT + func.func @TCVT_i16_to_f16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf16> -> !pto.partition_tensor_view<1x1x1x2x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf16>) + return + } + + // Case 147: i16 -> f16, default RINT + func.func @TCVT_i16_to_f16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf16> -> !pto.partition_tensor_view<1x1x1x4x32xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf16>) + return + } + + // Case 148: i16 -> f16, default RINT + func.func @TCVT_i16_to_f16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + return + } + + // Case 149: i16 -> f16, default RINT + func.func @TCVT_i16_to_f16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf16> -> !pto.partition_tensor_view<1x1x1x4x65xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf16>) + return + } + + // Case 150: i16 -> f16, default RINT + func.func @TCVT_i16_to_f16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x200xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf16>) + return + } + + // Case 151: i16 -> f16, default RINT + func.func @TCVT_i16_to_f16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x129xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf16>) + return + } + + // Case 152: i16 -> f32, default RINT + func.func @TCVT_i16_to_f32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 153: i16 -> f32, default RINT + func.func @TCVT_i16_to_f32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + return + } + + // Case 154: i16 -> f32, default RINT + func.func @TCVT_i16_to_f32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + return + } + + // Case 155: i16 -> f32, default RINT + func.func @TCVT_i16_to_f32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case 156: i16 -> f32, default RINT + func.func @TCVT_i16_to_f32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + return + } + + // Case 157: i16 -> f32, default RINT + func.func @TCVT_i16_to_f32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + return + } + + // Case 158: i16 -> f32, default RINT + func.func @TCVT_i16_to_f32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + return + } + + // Case 159: i16 -> ui32, default RINT + func.func @TCVT_i16_to_ui32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x128xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui32>) + return + } + + // Case 160: i16 -> ui32, default RINT + func.func @TCVT_i16_to_ui32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui32> -> !pto.partition_tensor_view<1x1x1x2x64xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui32>) + return + } + + // Case 161: i16 -> ui32, default RINT + func.func @TCVT_i16_to_ui32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui32> -> !pto.partition_tensor_view<1x1x1x4x32xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui32>) + return + } + + // Case 162: i16 -> ui32, default RINT + func.func @TCVT_i16_to_ui32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui32> -> !pto.partition_tensor_view<1x1x1x2x128xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui32>) + return + } + + // Case 163: i16 -> ui32, default RINT + func.func @TCVT_i16_to_ui32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui32> -> !pto.partition_tensor_view<1x1x1x4x65xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui32>) + return + } + + // Case 164: i16 -> ui32, default RINT + func.func @TCVT_i16_to_ui32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui32> -> !pto.partition_tensor_view<1x1x1x4x200xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui32>) + return + } + + // Case 165: i16 -> ui32, default RINT + func.func @TCVT_i16_to_ui32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x129xui32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui32>) + return + } + + // Case 166: i16 -> i32, default RINT + func.func @TCVT_i16_to_i32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + return + } + + // Case 167: i16 -> i32, default RINT + func.func @TCVT_i16_to_i32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + return + } + + // Case 168: i16 -> i32, default RINT + func.func @TCVT_i16_to_i32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + return + } + + // Case 169: i16 -> i32, default RINT + func.func @TCVT_i16_to_i32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + return + } + + // Case 170: i16 -> i32, default RINT + func.func @TCVT_i16_to_i32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + return + } + + // Case 171: i16 -> i32, default RINT + func.func @TCVT_i16_to_i32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + return + } + + // Case 172: i16 -> i32, default RINT + func.func @TCVT_i16_to_i32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + return + } + + // Case 173: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 174: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + return + } + + // Case 175: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + return + } + + // Case 176: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case 177: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + return + } + + // Case 178: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + return + } + + // Case 179: i32 -> f32, default RINT + func.func @TCVT_i32_to_f32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + return + } + + // Case 180: i32 -> i16, default RINT + func.func @TCVT_i32_to_i16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + return + } + + // Case 181: i32 -> i16, default RINT + func.func @TCVT_i32_to_i16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + return + } + + // Case 182: i32 -> i16, default RINT + func.func @TCVT_i32_to_i16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + return + } + + // Case 183: i32 -> i16, default RINT + func.func @TCVT_i32_to_i16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + return + } + + // Case 184: i32 -> i16, default RINT + func.func @TCVT_i32_to_i16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + return + } + + // Case 185: i32 -> i16, default RINT + func.func @TCVT_i32_to_i16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + return + } + + // Case 186: i32 -> i16, default RINT + func.func @TCVT_i32_to_i16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + return + } + + // Case 187: i32 -> i64, default RINT + func.func @TCVT_i32_to_i64_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi64> -> !pto.partition_tensor_view<1x1x1x1x128xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi64>) + return + } + + // Case 188: i32 -> i64, default RINT + func.func @TCVT_i32_to_i64_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi64> -> !pto.partition_tensor_view<1x1x1x2x64xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi64>) + return + } + + // Case 189: i32 -> i64, default RINT + func.func @TCVT_i32_to_i64_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi64> -> !pto.partition_tensor_view<1x1x1x4x32xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi64>) + return + } + + // Case 190: i32 -> i64, default RINT + func.func @TCVT_i32_to_i64_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi64> -> !pto.partition_tensor_view<1x1x1x2x128xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi64>) + return + } + + // Case 191: i32 -> i64, default RINT + func.func @TCVT_i32_to_i64_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi64> -> !pto.partition_tensor_view<1x1x1x4x65xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi64>) + return + } + + // Case 192: i32 -> i64, default RINT + func.func @TCVT_i32_to_i64_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi64> -> !pto.partition_tensor_view<1x1x1x4x200xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi64>) + return + } + + // Case 193: i32 -> i64, default RINT + func.func @TCVT_i32_to_i64_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi64> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi64> -> !pto.partition_tensor_view<1x1x1x1x129xi64> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi64>) + return + } + + // Case 194: i32 -> ui8, default RINT + func.func @TCVT_i32_to_ui8_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui8>) + return + } + + // Case 195: i32 -> ui8, default RINT + func.func @TCVT_i32_to_ui8_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui8> -> !pto.partition_tensor_view<1x1x1x2x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui8>) + return + } + + // Case 196: i32 -> ui8, default RINT + func.func @TCVT_i32_to_ui8_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui8> -> !pto.partition_tensor_view<1x1x1x4x32xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui8>) + return + } + + // Case 197: i32 -> ui8, default RINT + func.func @TCVT_i32_to_ui8_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui8> -> !pto.partition_tensor_view<1x1x1x2x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui8>) + return + } + + // Case 198: i32 -> ui8, default RINT + func.func @TCVT_i32_to_ui8_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui8> -> !pto.partition_tensor_view<1x1x1x4x65xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui8>) + return + } + + // Case 199: i32 -> ui8, default RINT + func.func @TCVT_i32_to_ui8_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui8> -> !pto.partition_tensor_view<1x1x1x4x200xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui8>) + return + } + + // Case 200: i32 -> ui8, default RINT + func.func @TCVT_i32_to_ui8_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x129xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui8>) + return + } + + // Case 201: i32 -> ui16, default RINT + func.func @TCVT_i32_to_ui16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x128xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui16>) + return + } + + // Case 202: i32 -> ui16, default RINT + func.func @TCVT_i32_to_ui16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui16> -> !pto.partition_tensor_view<1x1x1x2x64xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui16>) + return + } + + // Case 203: i32 -> ui16, default RINT + func.func @TCVT_i32_to_ui16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui16> -> !pto.partition_tensor_view<1x1x1x4x32xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui16>) + return + } + + // Case 204: i32 -> ui16, default RINT + func.func @TCVT_i32_to_ui16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui16> -> !pto.partition_tensor_view<1x1x1x2x128xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui16>) + return + } + + // Case 205: i32 -> ui16, default RINT + func.func @TCVT_i32_to_ui16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui16> -> !pto.partition_tensor_view<1x1x1x4x65xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui16>) + return + } + + // Case 206: i32 -> ui16, default RINT + func.func @TCVT_i32_to_ui16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui16> -> !pto.partition_tensor_view<1x1x1x4x200xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui16>) + return + } + + // Case 207: i32 -> ui16, default RINT + func.func @TCVT_i32_to_ui16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x129xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui16>) + return + } + + // Case 208: ui32 -> i16, default RINT + func.func @TCVT_ui32_to_i16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x128xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi16> -> !pto.partition_tensor_view<1x1x1x1x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi16>) + return + } + + // Case 209: ui32 -> i16, default RINT + func.func @TCVT_ui32_to_i16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui32> -> !pto.partition_tensor_view<1x1x1x2x64xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi16> -> !pto.partition_tensor_view<1x1x1x2x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi16>) + return + } + + // Case 210: ui32 -> i16, default RINT + func.func @TCVT_ui32_to_i16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui32> -> !pto.partition_tensor_view<1x1x1x4x32xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi16> -> !pto.partition_tensor_view<1x1x1x4x32xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi16>) + return + } + + // Case 211: ui32 -> i16, default RINT + func.func @TCVT_ui32_to_i16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui32> -> !pto.partition_tensor_view<1x1x1x2x128xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) + return + } + + // Case 212: ui32 -> i16, default RINT + func.func @TCVT_ui32_to_i16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui32> -> !pto.partition_tensor_view<1x1x1x4x65xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi16> -> !pto.partition_tensor_view<1x1x1x4x65xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi16>) + return + } + + // Case 213: ui32 -> i16, default RINT + func.func @TCVT_ui32_to_i16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui32> -> !pto.partition_tensor_view<1x1x1x4x200xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi16> -> !pto.partition_tensor_view<1x1x1x4x200xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi16>) + return + } + + // Case 214: ui32 -> i16, default RINT + func.func @TCVT_ui32_to_i16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x129xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi16> -> !pto.partition_tensor_view<1x1x1x1x129xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi16>) + return + } + + // Case 215: ui32 -> ui16, default RINT + func.func @TCVT_ui32_to_ui16_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x128xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui16> -> !pto.partition_tensor_view<1x1x1x1x128xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui16>) + return + } + + // Case 216: ui32 -> ui16, default RINT + func.func @TCVT_ui32_to_ui16_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui32> -> !pto.partition_tensor_view<1x1x1x2x64xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui16> -> !pto.partition_tensor_view<1x1x1x2x64xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui16>) + return + } + + // Case 217: ui32 -> ui16, default RINT + func.func @TCVT_ui32_to_ui16_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui32> -> !pto.partition_tensor_view<1x1x1x4x32xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui16> -> !pto.partition_tensor_view<1x1x1x4x32xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui16>) + return + } + + // Case 218: ui32 -> ui16, default RINT + func.func @TCVT_ui32_to_ui16_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui32> -> !pto.partition_tensor_view<1x1x1x2x128xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui16> -> !pto.partition_tensor_view<1x1x1x2x128xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui16>) + return + } + + // Case 219: ui32 -> ui16, default RINT + func.func @TCVT_ui32_to_ui16_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui32> -> !pto.partition_tensor_view<1x1x1x4x65xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui16> -> !pto.partition_tensor_view<1x1x1x4x65xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui16>) + return + } + + // Case 220: ui32 -> ui16, default RINT + func.func @TCVT_ui32_to_ui16_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui32> -> !pto.partition_tensor_view<1x1x1x4x200xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui16> -> !pto.partition_tensor_view<1x1x1x4x200xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui16>) + return + } + + // Case 221: ui32 -> ui16, default RINT + func.func @TCVT_ui32_to_ui16_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x129xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui16> -> !pto.partition_tensor_view<1x1x1x1x129xui16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui16>) + return + } + + // Case 222: ui32 -> ui8, default RINT + func.func @TCVT_ui32_to_ui8_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui32> -> !pto.partition_tensor_view<1x1x1x1x128xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xui8> -> !pto.partition_tensor_view<1x1x1x1x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xui8>) + return + } + + // Case 223: ui32 -> ui8, default RINT + func.func @TCVT_ui32_to_ui8_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui32> -> !pto.partition_tensor_view<1x1x1x2x64xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xui8> -> !pto.partition_tensor_view<1x1x1x2x64xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xui8>) + return + } + + // Case 224: ui32 -> ui8, default RINT + func.func @TCVT_ui32_to_ui8_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui32> -> !pto.partition_tensor_view<1x1x1x4x32xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xui8> -> !pto.partition_tensor_view<1x1x1x4x32xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xui8>) + return + } + + // Case 225: ui32 -> ui8, default RINT + func.func @TCVT_ui32_to_ui8_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui32> -> !pto.partition_tensor_view<1x1x1x2x128xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xui8> -> !pto.partition_tensor_view<1x1x1x2x128xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xui8>) + return + } + + // Case 226: ui32 -> ui8, default RINT + func.func @TCVT_ui32_to_ui8_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui32> -> !pto.partition_tensor_view<1x1x1x4x65xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xui8> -> !pto.partition_tensor_view<1x1x1x4x65xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xui8>) + return + } + + // Case 227: ui32 -> ui8, default RINT + func.func @TCVT_ui32_to_ui8_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui32> -> !pto.partition_tensor_view<1x1x1x4x200xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xui8> -> !pto.partition_tensor_view<1x1x1x4x200xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xui8>) + return + } + + // Case 228: ui32 -> ui8, default RINT + func.func @TCVT_ui32_to_ui8_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui32> -> !pto.partition_tensor_view<1x1x1x1x129xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xui8> -> !pto.partition_tensor_view<1x1x1x1x129xui8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xui32>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xui8>) + return + } + + // Case 229: i64 -> f32, default RINT + func.func @TCVT_i64_to_f32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi64> -> !pto.partition_tensor_view<1x1x1x1x128xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 230: i64 -> f32, default RINT + func.func @TCVT_i64_to_f32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi64> -> !pto.partition_tensor_view<1x1x1x2x64xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + return + } + + // Case 231: i64 -> f32, default RINT + func.func @TCVT_i64_to_f32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi64> -> !pto.partition_tensor_view<1x1x1x4x32xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xf32> -> !pto.partition_tensor_view<1x1x1x4x32xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xf32>) + return + } + + // Case 232: i64 -> f32, default RINT + func.func @TCVT_i64_to_f32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi64> -> !pto.partition_tensor_view<1x1x1x2x128xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case 233: i64 -> f32, default RINT + func.func @TCVT_i64_to_f32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi64> -> !pto.partition_tensor_view<1x1x1x4x65xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xf32> -> !pto.partition_tensor_view<1x1x1x4x65xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xf32>) + return + } + + // Case 234: i64 -> f32, default RINT + func.func @TCVT_i64_to_f32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi64> -> !pto.partition_tensor_view<1x1x1x4x200xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xf32> -> !pto.partition_tensor_view<1x1x1x4x200xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xf32>) + return + } + + // Case 235: i64 -> f32, default RINT + func.func @TCVT_i64_to_f32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi64> -> !pto.partition_tensor_view<1x1x1x1x129xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x129xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xf32>) + return + } + + // Case 236: i64 -> i32, default RINT + func.func @TCVT_i64_to_i32_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi64> -> !pto.partition_tensor_view<1x1x1x1x128xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xi32> -> !pto.partition_tensor_view<1x1x1x1x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xi32>) + return + } + + // Case 237: i64 -> i32, default RINT + func.func @TCVT_i64_to_i32_2x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi64> -> !pto.partition_tensor_view<1x1x1x2x64xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi32> -> !pto.partition_tensor_view<1x1x1x2x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x64xi32>) + return + } + + // Case 238: i64 -> i32, default RINT + func.func @TCVT_i64_to_i32_4x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c32], + strides = [%c128, %c128, %c128, %c32, %c1] + : !pto.tensor_view<1x1x1x4x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi64> -> !pto.partition_tensor_view<1x1x1x4x32xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c32] + : !pto.tensor_view<1x1x1x4x32xi32> -> !pto.partition_tensor_view<1x1x1x4x32xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x32xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x32xi32>) + return + } + + // Case 239: i64 -> i32, default RINT + func.func @TCVT_i64_to_i32_2x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi64> -> !pto.partition_tensor_view<1x1x1x2x128xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi32> -> !pto.partition_tensor_view<1x1x1x2x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi32>) + return + } + + // Case 240: i64 -> i32, default RINT + func.func @TCVT_i64_to_i32_4x65(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c65 = arith.constant 65 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c128], + strides = [%c512, %c512, %c512, %c128, %c1] + : !pto.tensor_view<1x1x1x4x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi64> -> !pto.partition_tensor_view<1x1x1x4x65xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c65] + : !pto.tensor_view<1x1x1x4x128xi32> -> !pto.partition_tensor_view<1x1x1x4x65xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x65xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x65xi32>) + return + } + + // Case 241: i64 -> i32, default RINT + func.func @TCVT_i64_to_i32_4x200(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c200 = arith.constant 200 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi64> -> !pto.partition_tensor_view<1x1x1x4x200xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c200] + : !pto.tensor_view<1x1x1x4x256xi32> -> !pto.partition_tensor_view<1x1x1x4x200xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x200xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x200xi32>) + return + } + + // Case 242: i64 -> i32, default RINT + func.func @TCVT_i64_to_i32_1x129(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c129 = arith.constant 129 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi64> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi64> -> !pto.partition_tensor_view<1x1x1x1x129xi64> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c129] + : !pto.tensor_view<1x1x1x1x256xi32> -> !pto.partition_tensor_view<1x1x1x1x129xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x129xi64>) + outs(%src : !pto.tile_buf) + + pto.tcvt ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x129xi32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/CMakeLists.txt new file mode 100644 index 000000000..506774c21 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tdiv) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py new file mode 100644 index 000000000..34e22f633 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/cases.py @@ -0,0 +1,206 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tdiv ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + - precision_mode: optional, "DEFAULT" or "HIGH_PRECISION". + - test_pattern: optional, "normal", "boundary", "subnormal", "overflow", "nan_inf" + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # ============================================================ + # Normal cases - basic functionality (DEFAULT precision mode) + # ============================================================ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "test_pattern": "normal", + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + "test_pattern": "normal", + }, + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + "test_pattern": "normal", + }, + { + "name": "f16_16x256", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (16, 256), + "eps": 1e-3, + "test_pattern": "normal", + }, + + # ============================================================ + # HIGH_PRECISION mode - comprehensive boundary tests + # ============================================================ + # Precision-sensitive ratios (1/3, 1/7, 7/3) - tests three-candidate search + { + "name": "f32_16x64_hp_precision", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, # Allow ±1 ULP for high-precision algorithm + }, + { + "name": "f16_16x64_hp_precision", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, + }, + + # Subnormal numbers - tests denormal normalization and compensation + { + "name": "f32_16x64_hp_subnormal", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "subnormal", + "ulp_tolerance": 2, # Subnormal handling may have ±2 ULP variance + }, + { + "name": "f16_16x64_hp_subnormal", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "subnormal", + "ulp_tolerance": 2, + }, + +# Overflow/Underflow boundaries - tests exponent handling + { + "name": "f32_16x64_hp_overflow", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "overflow", + }, + { + "name": "f16_16x64_hp_overflow", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "overflow", + }, + + # Different shapes - test tile size variations + { + "name": "f32_32x32_hp", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-5, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 2, + }, + { + "name": "f32_64x64_hp", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-5, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 2, + }, + { + "name": "f16_16x256_hp", + "dtype": np.float16, + "shape": (16, 256), + "valid_shape": (16, 256), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 2, + }, + + # Partial valid shape - test masked computation + { + "name": "f32_16x64_hp_partial", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 31), + "eps": 1e-5, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 2, + }, + { + "name": "f16_16x64_hp_partial", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 63), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 2, + }, + + # Small shape HP tests - aligned with pto-isa (case_float_hp_2x16, case_half_hp_2x32) + { + "name": "f32_2x16_hp", + "dtype": np.float32, + "shape": (2, 16), + "valid_shape": (2, 16), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, + }, + { + "name": "f16_2x32_hp", + "dtype": np.float16, + "shape": (2, 32), + "valid_shape": (2, 32), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py new file mode 100644 index 000000000..06d7fcc66 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/compare.py @@ -0,0 +1,295 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np +from pathlib import Path + +# Add current directory to path for standalone execution +script_dir = Path(__file__).parent +if script_dir not in sys.path: + sys.path.insert(0, str(script_dir)) + +# Add st_common directory +st_common_dir = script_dir.parent +if st_common_dir not in sys.path: + sys.path.insert(0, str(st_common_dir)) + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def compute_ulp_difference(golden, output, dtype): + """Compute ULP (Unit in the Last Place) difference between two arrays. + + ULP difference measures how many representable floating-point values + are between golden and output. + + Note: Only computes ULP for normal values (not NaN/Inf/zero). + + Args: + golden: numpy array of golden values + output: numpy array of output values + dtype: numpy dtype (float32 or float16) + + Returns: + Maximum ULP difference across all normal elements, or None if no normal values + """ + if dtype == np.float32: + int_dtype = np.uint32 + elif dtype == np.float16: + int_dtype = np.uint16 + else: + return None # ULP not applicable for integer types + + # Filter out NaN, Inf, and zero values (ULP not meaningful for these) + golden_normal = np.isfinite(golden) & (golden != 0) + output_normal = np.isfinite(output) & (output != 0) + normal_mask = golden_normal & output_normal + + if not np.any(normal_mask): + return None # No normal values to compare + + golden_filtered = golden[normal_mask] + output_filtered = output[normal_mask] + + # Convert to integer representation for ULP calculation + golden_int = golden_filtered.view(int_dtype) + output_int = output_filtered.view(int_dtype) + + # Handle sign difference: ULP counts across zero + # For same sign: simple difference + # For different sign: add both magnitudes (crosses zero boundary) + sign_bit = np.dtype(int_dtype).itemsize * 8 - 1 + golden_sign = golden_int >> sign_bit + output_sign = output_int >> sign_bit + + same_sign = (golden_sign == output_sign) + + # For same sign: subtract representations + ulp_diff_same = np.abs(golden_int.astype(np.int64) - output_int.astype(np.int64)) + + # For different sign: distance through zero (less common, treat as large difference) + # Use maximum possible ULP for different signs + ulp_diff_cross = np.iinfo(int_dtype).max + + ulp_diff = np.where(same_sign, ulp_diff_same, ulp_diff_cross) + + return np.max(ulp_diff) + + +def check_nan_inf_consistency(golden, output, relaxed=False): + """Check that NaN and Inf positions and values are consistent. + + IEEE 754 rules: + - NaN must appear at similar positions (hardware may differ in NaN type) + - Inf must have same sign at same positions + - Both must agree on which positions are NaN vs Inf vs normal + + Args: + golden: numpy array of golden values + output: numpy array of output values + relaxed: if True, allow NaN count differences (hardware may have different NaN handling) + + Returns: + (ok, error_msg) tuple + """ + # Check NaN positions + golden_nan = np.isnan(golden) + output_nan = np.isnan(output) + + # For relaxed mode, check NaN counts are similar (allow some variance) + if relaxed: + golden_nan_count = np.sum(golden_nan) + output_nan_count = np.sum(output_nan) + # Allow 20% variance in NaN count + if golden_nan_count > 0: + variance = abs(golden_nan_count - output_nan_count) / float(golden_nan_count) + if variance > 0.2: + return False, "NaN count variance > 20% (golden={}, output={})".format(golden_nan_count, output_nan_count) + # Continue with other checks even if NaN positions differ + else: + if not np.array_equal(golden_nan, output_nan): + nan_mismatch = np.where(golden_nan != output_nan) + return False, "NaN position mismatch at {} positions".format(len(nan_mismatch[0])) + + # Check Inf positions + golden_inf = np.isinf(golden) + output_inf = np.isinf(output) + + if not np.array_equal(golden_inf, output_inf): + inf_mismatch = np.where(golden_inf != output_inf) + return False, f"Inf position mismatch at {len(inf_mismatch[0])} positions" + + # Check Inf signs + if np.any(golden_inf): + golden_signs = np.sign(golden[golden_inf]) + output_signs = np.sign(output[golden_inf]) + if not np.array_equal(golden_signs, output_signs): + return False, "Inf sign mismatch" + + return True, None + + +def compare_high_precision_result(golden, output, dtype, ulp_tolerance=1, eps=1e-6, relaxed_nan=False): + """Compare results for HIGH_PRECISION mode. + + High-precision algorithm uses three-candidate search which may select + a different but more accurate rounding than numpy standard division. + + Comparison strategy: + 1. Check NaN/Inf consistency (may allow relaxed NaN checking) + 2. For normal/subnormal values: allow ±ulp_tolerance ULP difference + + Args: + golden: numpy array of reference values (numpy division) + output: numpy array of NPU output values + dtype: numpy dtype + ulp_tolerance: maximum allowed ULP difference (default 1) + eps: fallback tolerance for non-float types + relaxed_nan: if True, allow NaN count variance (default False) + + Returns: + (ok, error_msg) tuple + """ + # 1. Check NaN/Inf consistency + ok, error_msg = check_nan_inf_consistency(golden, output, relaxed=relaxed_nan) + if not ok: + return False, error_msg + + # 2. Filter out NaN/Inf for numerical comparison + golden_nan = np.isnan(golden) + golden_inf = np.isinf(golden) + normal_mask = ~(golden_nan | golden_inf) + + if not np.any(normal_mask): + return True, None # All NaN/Inf, already checked + + golden_normal = golden[normal_mask] + output_normal = output[normal_mask] + + # 3. Use ULP tolerance for float types + if dtype in (np.float32, np.float16): + max_ulp = compute_ulp_difference(golden_normal, output_normal, dtype) + if max_ulp is not None and max_ulp <= ulp_tolerance: + return True, f"ULP tolerance passed (max_ulp={max_ulp})" + + # Fallback to eps-based comparison if ULP check fails + ok = result_cmp(golden_normal, output_normal, eps) + if not ok: + return False, f"Both ULP ({max_ulp}) and eps ({eps}) check failed" + return True, f"Passed with eps tolerance (max_ulp={max_ulp} > {ulp_tolerance})" + + # 4. For integer types, use exact comparison + else: + ok = np.array_equal(golden_normal, output_normal) + if not ok: + mismatch = np.where(golden_normal != output_normal) + return False, f"Mismatch at {len(mismatch[0])} positions" + return True, None + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + test_pattern = case.get("test_pattern", "normal") + precision_mode = case.get("precision_mode", "DEFAULT") + check_inf_nan = case.get("check_inf_nan", False) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + eps = case["eps"] + dtype_name = case["dtype"].__name__ + + # Extract valid region + golden_valid = golden[:vr, :vc] + output_valid = output[:vr, :vc] + + # Integer types: exact comparison + if dtype_name in ("uint32", "int32", "uint16", "int16", "uint8", "int8"): + ok = np.array_equal(golden_valid, output_valid) + if not ok: + mismatch = np.where(golden_valid != output_valid) + print(style_fail(f"[ERROR] {case['name']}: mismatches at {len(mismatch[0])} positions")) + if len(mismatch[0]) > 0 and len(mismatch[0]) <= 10: + for i in range(len(mismatch[0])): + r, c = mismatch[0][i], mismatch[1][i] + print(f" [{r},{c}] golden={golden_valid[r,c]} output={output_valid[r,c]}") + all_passed = False + continue + + # Float types with special handling + else: + # HIGH_PRECISION mode: use ULP tolerance + if precision_mode == "HIGH_PRECISION": + ulp_tolerance = case.get("ulp_tolerance", 1) + # Use relaxed NaN checking for nan_inf and boundary tests + relaxed_nan = test_pattern in ("nan_inf", "boundary") + ok, msg = compare_high_precision_result( + golden_valid, output_valid, case["dtype"], + ulp_tolerance=ulp_tolerance, eps=eps, relaxed_nan=relaxed_nan + ) + if not ok: + print(style_fail("[ERROR] {}: {} (test={})".format(case['name'], msg, test_pattern))) + all_passed = False + continue + elif msg: + print(style_pass("[INFO] {}: {} (test={})".format(case['name'], msg, test_pattern))) + + # check_inf_nan flag or boundary test: check NaN/Inf separately + elif check_inf_nan or test_pattern == "boundary": + # Use relaxed NaN checking for nan_inf and boundary tests + relaxed = test_pattern in ("nan_inf", "boundary") + ok, msg = check_nan_inf_consistency(golden_valid, output_valid, relaxed=relaxed) + if not ok: + print(style_fail("[ERROR] {}: {} (test={})".format(case['name'], msg, test_pattern))) + all_passed = False + continue + + # Compare non-special values + golden_nan = np.isnan(golden_valid) + golden_inf = np.isinf(golden_valid) + normal_mask = ~(golden_nan | golden_inf) + + if np.any(normal_mask): + ok = result_cmp(golden_valid[normal_mask], output_valid[normal_mask], eps) + if not ok: + print(style_fail("[ERROR] {}: numerical mismatch (test={})".format(case['name'], test_pattern))) + all_passed = False + continue + + # Normal test: standard comparison + else: + ok = result_cmp(golden_valid, output_valid, eps) + if not ok: + print(style_fail("[ERROR] {}: comparison failed (test={})".format(case['name'], test_pattern))) + all_passed = False + continue + + print(style_pass("[INFO] {}: passed (dtype={}, precision={}, test={})".format(case['name'], dtype_name, precision_mode, test_pattern))) + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py new file mode 100644 index 000000000..79e5141d5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/gen_data.py @@ -0,0 +1,327 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import sys +import os +from pathlib import Path + +# Add current directory to path for standalone execution +script_dir = Path(__file__).parent +if script_dir not in sys.path: + sys.path.insert(0, str(script_dir)) + +# Add st_common directory +st_common_dir = script_dir.parent +if st_common_dir not in sys.path: + sys.path.insert(0, str(st_common_dir)) + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + + +def generate_precision_sensitive_data(shape, dtype): + """Generate precision-sensitive ratios to test three-candidate search algorithm. + + Focuses on values that cannot be exactly represented in floating point: + - 1/3, 1/7, 7/3 - infinite binary representation + - Values near integer boundaries where z/z±1 compete + """ + rows, cols = shape + input1 = np.zeros(shape, dtype=dtype) + input2 = np.ones(shape, dtype=dtype) + + ratios = [(1, 3), (1, 7), (7, 3), (1, 11), (5, 3), (10, 3)] + + section_size = rows // len(ratios) + for i, (a, b) in enumerate(ratios): + start_row = i * section_size + end_row = min((i + 1) * section_size, rows) + input1[start_row:end_row, :] = dtype(a) + input2[start_row:end_row, :] = dtype(b) + + # Add variations: negative versions, different signs + remaining_rows = rows - len(ratios) * section_size + if remaining_rows > 0: + input1[-remaining_rows:, :] = np.random.choice([-1, 1], size=(remaining_rows, cols)).astype(dtype) + input2[-remaining_rows:, :] = dtype(3) + + return input1, input2 + + +def generate_subnormal_test_data(shape, dtype): + """Generate subnormal (denormal) numbers to test normalization handling. + + NOTE: High-precision division algorithm (Div754) has asymmetric subnormal detection: + - src0 (dividend): EQ comparison - only detects MAX_SUBNORMAL (0x007FFFFF for f32) + - src1 (divisor): LT comparison - detects entire subnormal range + + Test design constraints: + - Section 1: src0 = MAX_SUBNORMAL, src1 = normal (tests src0 EQ detection) + - Section 2: src0 = MAX_SUBNORMAL, src1 = larger subnormal (tests both subnormal) + - Section 3: src0 = normal, src1 = MAX_SUBNORMAL (tests src1 subnormal with normal src0) + - Section 4: normal reference + + Avoid "normal / small_subnormal" which would overflow to Inf. + """ + rows, cols = shape + input1 = np.zeros(shape, dtype=dtype) + input2 = np.ones(shape, dtype=dtype) + + if dtype == np.float32: + tiny = np.finfo(np.float32).tiny + subnormal_max = np.frombuffer(np.array([0x007FFFFF], dtype=np.uint32), dtype=np.float32)[0] + subnormal_min = np.float32(1e-45) + normal_min = tiny * np.float32(2.0) + else: # float16 + tiny = np.finfo(np.float16).tiny + subnormal_max = np.frombuffer(np.array([0x03FF], dtype=np.uint16), dtype=np.float16)[0] + subnormal_min = np.float16(1e-8) + normal_min = tiny * np.float16(2.0) + + quarter = rows // 4 + + # Section 1: src0 = MAX_SUBNORMAL, src1 = normal + # ratio ≈ 1e-38 / 10 ≈ 1e-39 (不 overflow) + input1[:quarter, :] = subnormal_max + input2[:quarter, :] = np.random.uniform(normal_min, 100.0, size=(quarter, cols)).astype(dtype) + + # Section 2: src0 = MAX_SUBNORMAL, src1 = smaller subnormal (ratio ≈ 1-10) + # 确保 src1 在 subnormal 范围内: subnormal_min ~ subnormal_max + input1[quarter:2*quarter, :] = subnormal_max + input2[quarter:2*quarter, :] = np.random.uniform(subnormal_max * 0.1, subnormal_max, + size=(quarter, cols)).astype(dtype) + + # Section 3: src0 = MAX_SUBNORMAL, src1 = very small subnormal (ratio ≈ 10-500) + input1[2*quarter:3*quarter, :] = subnormal_max + input2[2*quarter:3*quarter, :] = np.random.uniform(subnormal_min, subnormal_max * 0.1, + size=(quarter, cols)).astype(dtype) + + # Section 4: normal reference + input1[3*quarter:, :] = np.random.uniform(0.1, 100.0, size=(rows-3*quarter, cols)).astype(dtype) + input2[3*quarter:, :] = np.random.uniform(0.1, 100.0, size=(rows-3*quarter, cols)).astype(dtype) + + return input1, input2 + + +def generate_overflow_test_data(shape, dtype): + """Generate overflow/underflow boundary values to test exponent handling. + + Tests: + - Large/small ratios that overflow to Inf + - Tiny ratios that underflow to 0 or min denormal + - Values at max/min exponent boundaries + """ + rows, cols = shape + input1 = np.zeros(shape, dtype=dtype) + input2 = np.ones(shape, dtype=dtype) + + if dtype == np.float32: + large_val = np.float32(1e30) + tiny_val = np.float32(1e-30) + overflow_trigger = np.float32(1e38) + underflow_trigger = np.float32(1e-45) + max_normal = np.float32(3.4e38) + else: # float16 + large_val = np.float16(60000) # Near f16 max (65504) + tiny_val = np.float16(0.0001) + overflow_trigger = np.float16(65000) + underflow_trigger = np.float16(1e-7) + max_normal = np.float16(65504) + + # Section 1: Overflow scenarios + quarter = rows // 4 + input1[:quarter, :cols//2] = overflow_trigger + input2[:quarter, :cols//2] = tiny_val # overflow_trigger / tiny_val -> Inf + + input1[:quarter, cols//2:] = large_val + input2[:quarter, cols//2:] = np.random.uniform(1e-35 if dtype==np.float32 else 1e-7, + tiny_val, + size=(quarter, cols//2)).astype(dtype) + + # Section 2: Underflow scenarios + input1[quarter:2*quarter, :cols//2] = underflow_trigger + input2[quarter:2*quarter, :cols//2] = large_val # underflow_trigger / large_val -> 0 + + input1[quarter:2*quarter, cols//2:] = tiny_val + input2[quarter:2*quarter, cols//2:] = np.random.uniform(large_val, max_normal, + size=(quarter, cols//2)).astype(dtype) + + # Section 3: Near boundary (may or may not overflow) + input1[2*quarter:3*quarter, :] = np.random.uniform(large_val/10, max_normal, + size=(quarter, cols)).astype(dtype) + input2[2*quarter:3*quarter, :] = np.random.uniform(tiny_val/10, tiny_val, + size=(quarter, cols)).astype(dtype) + + # Section 4: Normal values (control group) + input1[3*quarter:, :] = np.random.uniform(0.1, 100.0, + size=(rows-3*quarter, cols)).astype(dtype) + input2[3*quarter:, :] = np.random.uniform(0.1, 100.0, + size=(rows-3*quarter, cols)).astype(dtype) + + return input1, input2 + + +def generate_nan_inf_test_data(shape, dtype): + """Generate NaN and Inf inputs to test special value propagation. + + Tests IEEE 754 rules: + - 0/0 -> NaN + - Inf/Inf -> NaN + - x/0 -> Inf (or NaN if x=0) + - Inf/x -> Inf + - x/Inf -> 0 + - NaN propagates + """ + rows, cols = shape + input1 = np.zeros(shape, dtype=dtype) + input2 = np.ones(shape, dtype=dtype) + + # Create special values + if dtype == np.float32: + pos_inf = np.float32(np.inf) + neg_inf = np.float32(-np.inf) + nan_val = np.float32(np.nan) + zero_val = np.float32(0.0) + pos_one = np.float32(1.0) + neg_one = np.float32(-1.0) + else: # float16 + pos_inf = np.float16(np.inf) + neg_inf = np.float16(-np.inf) + nan_val = np.float16(np.nan) + zero_val = np.float16(0.0) + pos_one = np.float16(1.0) + neg_one = np.float16(-1.0) + + # Section 1: 0/0 -> NaN, x/0 -> Inf + eighth = rows // 8 + input1[0:eighth, :] = zero_val + input2[0:eighth, :] = zero_val # 0/0 -> NaN + + input1[eighth:2*eighth, :] = pos_one + input2[eighth:2*eighth, :] = zero_val # 1/0 -> Inf + + input1[2*eighth:3*eighth, :] = neg_one + input2[2*eighth:3*eighth, :] = zero_val # -1/0 -> -Inf + + # Section 2: Inf/Inf -> NaN, Inf/x -> Inf, x/Inf -> 0 + input1[3*eighth:4*eighth, :] = pos_inf + input2[3*eighth:4*eighth, :] = pos_inf # Inf/Inf -> NaN + + input1[4*eighth:5*eighth, :] = pos_inf + input2[4*eighth:5*eighth, :] = pos_one # Inf/1 -> Inf + + input1[5*eighth:6*eighth, :] = pos_one + input2[5*eighth:6*eighth, :] = pos_inf # 1/Inf -> 0 + + # Section 3: NaN propagation + input1[6*eighth:7*eighth, :] = nan_val + input2[6*eighth:7*eighth, :] = np.random.uniform(0.1, 10.0, + size=(eighth, cols)).astype(dtype) # NaN/x -> NaN + + input1[7*eighth:rows, :] = np.random.uniform(0.1, 10.0, + size=(rows-7*eighth, cols)).astype(dtype) + input2[7*eighth:rows, :cols//2] = nan_val # x/NaN -> NaN (half of remaining) + input2[7*eighth:rows, cols//2:] = np.random.uniform(0.1, 10.0, + size=(rows-7*eighth, cols//2)).astype(dtype) + + return input1, input2 + + +def generate_boundary_test_data(shape, dtype): + """Generate mixed boundary test data to stress IEEE 754 compliance. + + Combines subnormal and overflow scenarios (no NaN/Inf to avoid hardware limitations). + """ + rows, cols = shape + input1 = np.zeros(shape, dtype=dtype) + input2 = np.ones(shape, dtype=dtype) + + # Adapt thresholds based on dtype + if dtype == np.float32: + subnormal_val = np.float32(1.175e-38) + large_val = np.float32(1e30) + tiny_val = np.float32(1e-10) + elif dtype == np.float16: + subnormal_val = np.float16(6e-5) + large_val = np.float16(60000) + tiny_val = np.float16(0.001) + else: + subnormal_val = np.float32(1e-38) + large_val = np.float32(1e30) + tiny_val = np.float32(1e-10) + + # Section 1: Subnormal numbers (first half) + half = rows // 2 + if dtype == np.float32: + input1[:half, :] = np.random.uniform(1e-40, subnormal_val, + size=(half, cols)).astype(dtype) + else: + input1[:half, :] = np.random.uniform(1e-8, subnormal_val, + size=(half, cols)).astype(dtype) + input2[:half, :] = np.random.uniform(1.0, 10.0, + size=(half, cols)).astype(dtype) + + # Section 2: Overflow boundary (second half) + input1[half:, :cols//2] = large_val + input2[half:, :cols//2] = tiny_val + + input1[half:, cols//2:] = np.random.uniform(large_val/10, large_val, + size=(half, cols//2)).astype(dtype) + input2[half:, cols//2:] = np.random.uniform(tiny_val/10, tiny_val, + size=(half, cols//2)).astype(dtype) + + return input1, input2 + + +def generate_normal_data(shape, dtype): + """Generate simple random values for normal testing.""" + if dtype in (np.int32, np.int16, np.int8, np.uint8, np.uint16, np.uint32): + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + else: + input1 = np.random.uniform(0.1, 100.0, size=shape).astype(dtype) + input2 = np.random.uniform(0.1, 100.0, size=shape).astype(dtype) + return input1, input2 + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + test_pattern = case.get("test_pattern", "normal") + + # Generate test data based on pattern + # NOTE: nan_inf test removed due to hardware vdiv NaN-from-division limitations + data_generators = { + "normal": generate_normal_data, + "precision_sensitive": generate_precision_sensitive_data, + "subnormal": generate_subnormal_test_data, + "overflow": generate_overflow_test_data, + "boundary": generate_boundary_test_data, + } + + generator = data_generators.get(test_pattern, generate_normal_data) + input1, input2 = generator(shape, dtype) + + # Compute golden reference using numpy (IEEE 754 compliant) + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + + # Suppress overflow/divide warnings for boundary tests (expected behavior) + with np.errstate(over='ignore', divide='ignore', invalid='ignore'): + golden[:vr, :vc] = (input1[:vr, :vc] / input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + precision_mode = case.get("precision_mode", "DEFAULT") + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} test={test_pattern} precision={precision_mode}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp new file mode 100644 index 000000000..d4bbdb39a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/launch.cpp @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + + +// Case: f32_16x64 +extern "C" __global__ AICORE void TDIV_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_16x64(float *a, float *b, float *c, void *stream) { + TDIV_f32_16x64<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_32x32 +extern "C" __global__ AICORE void TDIV_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_32x32(float *a, float *b, float *c, void *stream) { + TDIV_f32_32x32<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_64x64 +extern "C" __global__ AICORE void TDIV_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_64x64(float *a, float *b, float *c, void *stream) { + TDIV_f32_64x64<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f16_16x256 +extern "C" __global__ AICORE void TDIV_f16_16x256(__gm__ void *a, __gm__ void *b, __gm__ void *c); + +void LaunchTDIV_f16_16x256(void *a, void *b, void *c, void *stream) { + TDIV_f16_16x256<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_16x64_hp_precision +extern "C" __global__ AICORE void TDIV_f32_16x64_hp_precision(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_16x64_hp_precision(float *a, float *b, float *c, void *stream) { + TDIV_f32_16x64_hp_precision<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f16_16x64_hp_precision +extern "C" __global__ AICORE void TDIV_f16_16x64_hp_precision(__gm__ void *a, __gm__ void *b, __gm__ void *c); + +void LaunchTDIV_f16_16x64_hp_precision(void *a, void *b, void *c, void *stream) { + TDIV_f16_16x64_hp_precision<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_16x64_hp_subnormal +extern "C" __global__ AICORE void TDIV_f32_16x64_hp_subnormal(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_16x64_hp_subnormal(float *a, float *b, float *c, void *stream) { + TDIV_f32_16x64_hp_subnormal<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f16_16x64_hp_subnormal +extern "C" __global__ AICORE void TDIV_f16_16x64_hp_subnormal(__gm__ void *a, __gm__ void *b, __gm__ void *c); + +void LaunchTDIV_f16_16x64_hp_subnormal(void *a, void *b, void *c, void *stream) { + TDIV_f16_16x64_hp_subnormal<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_16x64_hp_overflow +extern "C" __global__ AICORE void TDIV_f32_16x64_hp_overflow(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_16x64_hp_overflow(float *a, float *b, float *c, void *stream) { + TDIV_f32_16x64_hp_overflow<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f16_16x64_hp_overflow +extern "C" __global__ AICORE void TDIV_f16_16x64_hp_overflow(__gm__ void *a, __gm__ void *b, __gm__ void *c); + +void LaunchTDIV_f16_16x64_hp_overflow(void *a, void *b, void *c, void *stream) { + TDIV_f16_16x64_hp_overflow<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_32x32_hp +extern "C" __global__ AICORE void TDIV_f32_32x32_hp(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_32x32_hp(float *a, float *b, float *c, void *stream) { + TDIV_f32_32x32_hp<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_64x64_hp +extern "C" __global__ AICORE void TDIV_f32_64x64_hp(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_64x64_hp(float *a, float *b, float *c, void *stream) { + TDIV_f32_64x64_hp<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f16_16x256_hp +extern "C" __global__ AICORE void TDIV_f16_16x256_hp(__gm__ void *a, __gm__ void *b, __gm__ void *c); + +void LaunchTDIV_f16_16x256_hp(void *a, void *b, void *c, void *stream) { + TDIV_f16_16x256_hp<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_16x64_hp_partial +extern "C" __global__ AICORE void TDIV_f32_16x64_hp_partial(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_16x64_hp_partial(float *a, float *b, float *c, void *stream) { + TDIV_f32_16x64_hp_partial<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f16_16x64_hp_partial +extern "C" __global__ AICORE void TDIV_f16_16x64_hp_partial(__gm__ void *a, __gm__ void *b, __gm__ void *c); + +void LaunchTDIV_f16_16x64_hp_partial(void *a, void *b, void *c, void *stream) { + TDIV_f16_16x64_hp_partial<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f32_2x16_hp +extern "C" __global__ AICORE void TDIV_f32_2x16_hp(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTDIV_f32_2x16_hp(float *a, float *b, float *c, void *stream) { + TDIV_f32_2x16_hp<<<1, nullptr, stream>>>(a, b, c); +} + +// Case: f16_2x32_hp +extern "C" __global__ AICORE void TDIV_f16_2x32_hp(__gm__ void *a, __gm__ void *b, __gm__ void *c); + +void LaunchTDIV_f16_2x32_hp(void *a, void *b, void *c, void *stream) { + TDIV_f16_2x32_hp<<<1, nullptr, stream>>>(a, b, c); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp new file mode 100644 index 000000000..c4f1a55d4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/main.cpp @@ -0,0 +1,176 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tdiv ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.cpp. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTDIV_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f32_32x32(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f32_64x64(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f16_16x256(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_16x64_hp_precision(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f16_16x64_hp_precision(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_16x64_hp_subnormal(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f16_16x64_hp_subnormal(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_16x64_hp_overflow(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f16_16x64_hp_overflow(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_32x32_hp(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f32_64x64_hp(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f16_16x256_hp(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_16x64_hp_partial(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f16_16x64_hp_partial(void *a, void *b, void *c, void *stream); +void LaunchTDIV_f32_2x16_hp(float *a, float *b, float *c, void *stream); +void LaunchTDIV_f16_2x32_hp(void *a, void *b, void *c, void *stream); + +// Generic launch function type for void* pointers +using LaunchFn = void (*)(void *a, void *b, void *c, void *stream); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", (LaunchFn)LaunchTDIV_f32_16x64, 16, 64, 16, 64, 4}, + {"f32_32x32", (LaunchFn)LaunchTDIV_f32_32x32, 32, 32, 32, 32, 4}, + {"f32_64x64", (LaunchFn)LaunchTDIV_f32_64x64, 64, 64, 64, 64, 4}, + {"f16_16x256", (LaunchFn)LaunchTDIV_f16_16x256, 16, 256, 16, 256, 2}, + {"f32_16x64_hp_precision", (LaunchFn)LaunchTDIV_f32_16x64_hp_precision, 16, 64, 16, 64, 4}, + {"f16_16x64_hp_precision", (LaunchFn)LaunchTDIV_f16_16x64_hp_precision, 16, 64, 16, 64, 2}, + {"f32_16x64_hp_subnormal", (LaunchFn)LaunchTDIV_f32_16x64_hp_subnormal, 16, 64, 16, 64, 4}, + {"f16_16x64_hp_subnormal", (LaunchFn)LaunchTDIV_f16_16x64_hp_subnormal, 16, 64, 16, 64, 2}, + {"f32_16x64_hp_overflow", (LaunchFn)LaunchTDIV_f32_16x64_hp_overflow, 16, 64, 16, 64, 4}, + {"f16_16x64_hp_overflow", (LaunchFn)LaunchTDIV_f16_16x64_hp_overflow, 16, 64, 16, 64, 2}, + {"f32_32x32_hp", (LaunchFn)LaunchTDIV_f32_32x32_hp, 32, 32, 32, 32, 4}, + {"f32_64x64_hp", (LaunchFn)LaunchTDIV_f32_64x64_hp, 64, 64, 64, 64, 4}, + {"f16_16x256_hp", (LaunchFn)LaunchTDIV_f16_16x256_hp, 16, 256, 16, 256, 2}, + {"f32_16x64_hp_partial", (LaunchFn)LaunchTDIV_f32_16x64_hp_partial, 16, 64, 16, 31, 4}, + {"f16_16x64_hp_partial", (LaunchFn)LaunchTDIV_f16_16x64_hp_partial, 16, 64, 16, 63, 2}, + {"f32_2x16_hp", (LaunchFn)LaunchTDIV_f32_2x16_hp, 2, 16, 2, 16, 4}, + {"f16_2x32_hp", (LaunchFn)LaunchTDIV_f16_2x32_hp, 2, 32, 2, 32, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tdiv [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto new file mode 100644 index 000000000..3c4ad7ae5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdiv/tdiv.pto @@ -0,0 +1,971 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tdiv: tload(a) + tload(b) + tdiv(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TDIV_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TDIV_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case: f32_64x64 + func.func @TDIV_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f16_16x256 + func.func @TDIV_f16_16x256(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + return + } + + // Case: f32_16x64_hp_precision + func.func @TDIV_f32_16x64_hp_precision(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case: f16_16x64_hp_precision + func.func @TDIV_f16_16x64_hp_precision(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case: f32_16x64_hp_subnormal + func.func @TDIV_f32_16x64_hp_subnormal(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case: f16_16x64_hp_subnormal + func.func @TDIV_f16_16x64_hp_subnormal(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case: f32_16x64_hp_overflow + func.func @TDIV_f32_16x64_hp_overflow(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case: f16_16x64_hp_overflow + func.func @TDIV_f16_16x64_hp_overflow(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case: f32_32x32_hp + func.func @TDIV_f32_32x32_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case: f32_64x64_hp + func.func @TDIV_f32_64x64_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f16_16x256_hp + func.func @TDIV_f16_16x256_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + return + } + + // Case: f32_16x64_hp_partial + func.func @TDIV_f32_16x64_hp_partial(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case: f16_16x64_hp_partial + func.func @TDIV_f16_16x64_hp_partial(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case: f32_2x16_hp + func.func @TDIV_f32_2x16_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xf32> -> !pto.partition_tensor_view<1x1x1x2x16xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xf32> -> !pto.partition_tensor_view<1x1x1x2x16xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xf32> -> !pto.partition_tensor_view<1x1x1x2x16xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x16xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x16xf32>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x16xf32>) + return + } + + // Case: f16_2x32_hp + func.func @TDIV_f16_2x32_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + outs(%b : !pto.tile_buf) + + pto.tdiv ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + {precision_mode = #pto} + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/CMakeLists.txt new file mode 100644 index 000000000..cfd816f61 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tdivs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/cases.py new file mode 100644 index 000000000..c264d7805 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/cases.py @@ -0,0 +1,245 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tdivs ST test cases. + +vdiv only supports f16/f32 in TileLang DSL v1. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + - direction: "src_scalar" (src / scalar) or "scalar_src" (scalar / src) + - precision_mode: optional, "DEFAULT" or "HIGH_PRECISION". + - test_pattern: optional, "normal", "precision_sensitive", "subnormal", "overflow" + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # ============================================================ + # Normal cases - basic functionality (DEFAULT precision mode) + # ============================================================ + # src / scalar direction + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + "direction": "src_scalar", + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + "direction": "src_scalar", + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + "direction": "src_scalar", + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + "direction": "src_scalar", + }, + # scalar / src direction + { + "name": "f32_32x64_scalar_src", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + "direction": "scalar_src", + }, + { + "name": "f16_63x64_scalar_src", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + "direction": "scalar_src", + }, + { + "name": "f32_7x448_scalar_src", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + "direction": "scalar_src", + }, + { + "name": "f32_256x16_scalar_src", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + "direction": "scalar_src", + }, + + # ============================================================ + # HIGH_PRECISION mode - src / scalar direction + # ============================================================ + # Precision-sensitive ratios + { + "name": "f32_32x64_hp", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "direction": "src_scalar", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, + }, + { + "name": "f16_63x64_hp", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "direction": "src_scalar", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, + }, + + # Subnormal numbers + { + "name": "f32_16x64_hp_subnormal", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "direction": "src_scalar", + "test_pattern": "subnormal", + "ulp_tolerance": 2, + }, + { + "name": "f16_16x64_hp_subnormal", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "direction": "src_scalar", + "test_pattern": "subnormal", + "ulp_tolerance": 2, + }, + + # Overflow/Underflow boundaries + { + "name": "f32_16x64_hp_overflow", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "direction": "src_scalar", + "test_pattern": "overflow", + }, + { + "name": "f16_16x64_hp_overflow", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "direction": "src_scalar", + "test_pattern": "overflow", + }, + + # ============================================================ + # HIGH_PRECISION mode - scalar / src direction + # ============================================================ + { + "name": "f32_32x64_hp_scalar_src", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "direction": "scalar_src", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, + }, + { + "name": "f16_63x64_hp_scalar_src", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "direction": "scalar_src", + "test_pattern": "precision_sensitive", + "ulp_tolerance": 1, + }, + + # Subnormal - scalar / src (scalar is normal, src contains subnormals) + { + "name": "f32_16x64_hp_subnormal_scalar_src", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "direction": "scalar_src", + "test_pattern": "subnormal", + "ulp_tolerance": 2, + }, + { + "name": "f16_16x64_hp_subnormal_scalar_src", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "direction": "scalar_src", + "test_pattern": "subnormal", + "ulp_tolerance": 2, + }, + + # Overflow - scalar / src (division by small src values) + { + "name": "f32_16x64_hp_overflow_scalar_src", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "precision_mode": "HIGH_PRECISION", + "direction": "scalar_src", + "test_pattern": "overflow", + }, + { + "name": "f16_16x64_hp_overflow_scalar_src", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + "direction": "scalar_src", + "test_pattern": "overflow", + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/gen_data.py new file mode 100644 index 000000000..630491906 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/gen_data.py @@ -0,0 +1,247 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import sys +import os +from pathlib import Path + +# Add current directory to path for standalone execution +script_dir = Path(__file__).parent +if script_dir not in sys.path: + sys.path.insert(0, str(script_dir)) + +# Add st_common directory +st_common_dir = script_dir.parent +if st_common_dir not in sys.path: + sys.path.insert(0, str(st_common_dir)) + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +# Default scalar value for division (matches the scalar passed in launch.cpp) +DEFAULT_SCALAR = 3.0 + + +def generate_precision_sensitive_scalar(shape, dtype, direction): + """Generate precision-sensitive test data for scalar division. + + Uses scalar values that create precision-sensitive ratios when divided + with tile data (e.g., 1/3, 1/7 patterns). + """ + rows, cols = shape + + # For src / scalar: tile contains precision-sensitive values + # For scalar / src: scalar is precision-sensitive, src contains small integers + if direction == "src_scalar": + # Tile contains values like 1, 7, 5, 10 etc divided by scalar 3 + # Results: 1/3, 7/3, 5/3, 10/3 - precision-sensitive + input1 = np.zeros(shape, dtype=dtype) + values = [1, 7, 5, 10, 1, 3, 2, 11] + section_size = rows // len(values) + for i, v in enumerate(values): + start_row = i * section_size + end_row = min((i + 1) * section_size, rows) + input1[start_row:end_row, :] = dtype(v) + scalar = dtype(DEFAULT_SCALAR) + else: # scalar_src + # Scalar is 1, tile contains 3, 7, etc -> 1/3, 1/7 precision-sensitive + input1 = np.full(shape, dtype(3), dtype=dtype) # Avoid zeros + # Fill with divisor values that create precision-sensitive ratios + values = [3, 7, 11, 3, 5, 7, 11, 3] + section_size = rows // len(values) + for i, v in enumerate(values): + start_row = i * section_size + end_row = min((i + 1) * section_size, rows) + input1[start_row:end_row, :] = dtype(v) + scalar = dtype(1.0) + + return input1, scalar + + +def generate_subnormal_test_data(shape, dtype, direction): + """Generate subnormal (denormal) numbers for scalar division tests. + + For src / scalar: + - src contains subnormal values, scalar is normal + - Tests subnormal dividend handling + + For scalar / src: + - scalar is normal, src contains subnormal values + - Tests subnormal divisor handling (can produce large results) + """ + rows, cols = shape + + if dtype == np.float32: + subnormal_max = np.frombuffer(np.array([0x007FFFFF], dtype=np.uint32), dtype=np.float32)[0] + subnormal_min = np.float32(1e-45) + normal_min = np.float32(1e-38) * np.float32(2.0) # smallest normal + else: # float16 + subnormal_max = np.frombuffer(np.array([0x03FF], dtype=np.uint16), dtype=np.float16)[0] + subnormal_min = np.float16(1e-8) + normal_min = np.float16(6e-5) * np.float16(2.0) + + if direction == "src_scalar": + # src contains subnormal values, scalar is normal (e.g., 10) + input1 = np.zeros(shape, dtype=dtype) + quarter = rows // 4 + + # Section 1: MAX_SUBNORMAL / normal -> tiny normal result + input1[:quarter, :] = subnormal_max + + # Section 2: Mid-range subnormal / normal + input1[quarter:2*quarter, :] = np.random.uniform( + subnormal_min, subnormal_max, size=(quarter, cols)).astype(dtype) + + # Section 3: Smallest subnormal / normal + input1[2*quarter:3*quarter, :] = subnormal_min + + # Section 4: Normal reference + input1[3*quarter:, :] = np.random.uniform(0.1, 100.0, size=(rows-3*quarter, cols)).astype(dtype) + + scalar = dtype(10.0) + else: # scalar_src + # scalar is normal (e.g., 1e-20 for f32), src contains subnormal + # This tests: normal / subnormal -> large result (potential overflow) + input1 = np.zeros(shape, dtype=dtype) + quarter = rows // 4 + + # Section 1: normal / MAX_SUBNORMAL -> large but not overflow + input1[:quarter, :] = subnormal_max + + # Section 2: normal / mid subnormal -> larger + input1[quarter:2*quarter, :] = np.random.uniform( + subnormal_max * 0.1, subnormal_max, size=(quarter, cols)).astype(dtype) + + # Section 3: normal / tiny subnormal -> very large (near overflow) + input1[2*quarter:3*quarter, :] = np.random.uniform( + subnormal_min, subnormal_max * 0.1, size=(quarter, cols)).astype(dtype) + + # Section 4: Normal reference + input1[3*quarter:, :] = np.random.uniform(0.1, 100.0, size=(rows-3*quarter, cols)).astype(dtype) + + # Use a small normal scalar that won't overflow when divided by smallest subnormal + if dtype == np.float32: + scalar = np.float32(1e-20) # Safe: 1e-20 / 1e-45 = 1e25, within f32 range + else: + scalar = np.float16(1e-5) # Safe: 1e-5 / 1e-8 = 1000, within f16 range + + return input1, scalar + + +def generate_overflow_test_data(shape, dtype, direction): + """Generate overflow/underflow boundary values for scalar division tests. + + For src / scalar: + - Large src / tiny scalar -> overflow + - Tiny src / large scalar -> underflow + + For scalar / src: + - Large scalar / tiny src -> overflow + - Tiny scalar / large src -> underflow + """ + rows, cols = shape + + if dtype == np.float32: + large_val = np.float32(1e30) + tiny_val = np.float32(1e-30) + overflow_trigger = np.float32(1e38) + underflow_trigger = np.float32(1e-45) + else: # float16 + large_val = np.float16(60000) + tiny_val = np.float16(0.0001) + overflow_trigger = np.float16(65000) + underflow_trigger = np.float16(1e-7) + + if direction == "src_scalar": + input1 = np.zeros(shape, dtype=dtype) + quarter = rows // 4 + + # Section 1: Overflow - large / tiny + input1[:quarter, :] = overflow_trigger + + # Section 2: Near overflow boundary + input1[quarter:2*quarter, :] = np.random.uniform(large_val, overflow_trigger, + size=(quarter, cols)).astype(dtype) + + # Section 3: Underflow - tiny / large + input1[2*quarter:3*quarter, :] = underflow_trigger + + # Section 4: Normal reference + input1[3*quarter:, :] = np.random.uniform(0.1, 100.0, size=(rows-3*quarter, cols)).astype(dtype) + + scalar = dtype(tiny_val) # Tiny scalar triggers overflow + + else: # scalar_src + input1 = np.zeros(shape, dtype=dtype) + quarter = rows // 4 + + # Section 1: Overflow - scalar / tiny src + input1[:quarter, :] = tiny_val # Tiny divisor + + # Section 2: Near overflow boundary + input1[quarter:2*quarter, :] = np.random.uniform( + tiny_val/10, tiny_val, size=(quarter, cols)).astype(dtype) + + # Section 3: Underflow - scalar / large src + input1[2*quarter:3*quarter, :] = large_val + + # Section 4: Normal reference + input1[3*quarter:, :] = np.random.uniform(0.1, 100.0, size=(rows-3*quarter, cols)).astype(dtype) + + # Large scalar triggers overflow when divided by tiny src + scalar = dtype(overflow_trigger) + + return input1, scalar + + +def generate_normal_data(shape, dtype, direction): + """Generate simple random values for normal testing.""" + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + scalar = dtype(DEFAULT_SCALAR) + return input1, scalar + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + direction = case.get("direction", "src_scalar") + test_pattern = case.get("test_pattern", "normal") + + # Generate test data based on pattern and direction + data_generators = { + "normal": generate_normal_data, + "precision_sensitive": generate_precision_sensitive_scalar, + "subnormal": generate_subnormal_test_data, + "overflow": generate_overflow_test_data, + } + + generator = data_generators.get(test_pattern, generate_normal_data) + input1, scalar_val = generator(shape, dtype, direction) + + # Compute golden reference using numpy (IEEE 754 compliant) + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + + # Suppress overflow/divide warnings for boundary tests (expected behavior) + with np.errstate(over='ignore', divide='ignore', invalid='ignore'): + if direction == "src_scalar": + golden[:vr, :vc] = (input1[:vr, :vc] / scalar_val).astype(dtype, copy=False) + else: # scalar_src + golden[:vr, :vc] = (scalar_val / input1[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + precision_mode = case.get("precision_mode", "DEFAULT") + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} direction={direction} test={test_pattern} precision={precision_mode} scalar={scalar_val}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/launch.cpp new file mode 100644 index 000000000..3b6cae07c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/launch.cpp @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +static constexpr float TDIVS_SCALAR_F32 = 3.0f; + +// Helper to convert IEEE 754 hex bits to float (runtime initialization) +inline float bits_to_float(uint32_t bits) { + float result; + memcpy(&result, &bits, sizeof(float)); + return result; +} + +// ========== src / scalar direction ========== + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TDIVS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_32x64(float *src, float *dst, void *stream) { + TDIVS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TDIVS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: f32 7x448 +extern "C" __global__ AICORE void TDIVS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_7x448(float *src, float *dst, void *stream) { + TDIVS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 3: f32 256x16 +extern "C" __global__ AICORE void TDIVS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_256x16(float *src, float *dst, void *stream) { + TDIVS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// ========== scalar / src direction ========== + +// Case 4: f32 32x64 scalar/src +extern "C" __global__ AICORE void TDIVS_f32_32x64_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_32x64_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_32x64_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 5: f16 63x64 scalar/src +extern "C" __global__ AICORE void TDIVS_f16_63x64_scalar_src(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_63x64_scalar_src(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_63x64_scalar_src<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 6: f32 7x448 scalar/src +extern "C" __global__ AICORE void TDIVS_f32_7x448_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_7x448_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_7x448_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// Case 7: f32 256x16 scalar/src +extern "C" __global__ AICORE void TDIVS_f32_256x16_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_256x16_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_256x16_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TDIVS_SCALAR_F32); +} + +// ========== HIGH_PRECISION mode - src / scalar direction ========== + +// Case 8: f32 32x64 HP (precision_sensitive) - scalar=3.0f +extern "C" __global__ AICORE void TDIVS_f32_32x64_hp(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_32x64_hp(float *src, float *dst, void *stream) { + TDIVS_f32_32x64_hp<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, 3.0f); +} + +// Case 9: f16 63x64 HP (precision_sensitive) - scalar=3.0 in f16 (0x4200) +extern "C" __global__ AICORE void TDIVS_f16_63x64_hp(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_63x64_hp(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_63x64_hp<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 10: f32 16x64 HP subnormal - scalar=10.0f +extern "C" __global__ AICORE void TDIVS_f32_16x64_hp_subnormal(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_16x64_hp_subnormal(float *src, float *dst, void *stream) { + TDIVS_f32_16x64_hp_subnormal<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, 10.0f); +} + +// Case 11: f16 16x64 HP subnormal - scalar=10.0 in f16 (0x4900) +extern "C" __global__ AICORE void TDIVS_f16_16x64_hp_subnormal(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_16x64_hp_subnormal(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_16x64_hp_subnormal<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4900); +} + +// Case 12: f32 16x64 HP overflow - scalar=np.float32(1e-30) -> hex 0x0DA24260 +extern "C" __global__ AICORE void TDIVS_f32_16x64_hp_overflow(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_16x64_hp_overflow(float *src, float *dst, void *stream) { + TDIVS_f32_16x64_hp_overflow<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, bits_to_float(0x0DA24260U)); +} + +// Case 13: f16 16x64 HP overflow - scalar=np.float16(0.0001) -> hex 0x068E +extern "C" __global__ AICORE void TDIVS_f16_16x64_hp_overflow(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_16x64_hp_overflow(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_16x64_hp_overflow<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x068E); +} + +// ========== HIGH_PRECISION mode - scalar / src direction ========== + +// Case 14: f32 32x64 HP scalar/src (precision_sensitive) - scalar=1.0f +extern "C" __global__ AICORE void TDIVS_f32_32x64_hp_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_32x64_hp_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_32x64_hp_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, 1.0f); +} + +// Case 15: f16 63x64 HP scalar/src (precision_sensitive) - scalar=1.0 in f16 (0x3C00) +extern "C" __global__ AICORE void TDIVS_f16_63x64_hp_scalar_src(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_63x64_hp_scalar_src(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_63x64_hp_scalar_src<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x3C00); +} + +// Case 16: f32 16x64 HP subnormal scalar/src - scalar=np.float32(1e-20) -> hex 0x1E3CE508 +extern "C" __global__ AICORE void TDIVS_f32_16x64_hp_subnormal_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_16x64_hp_subnormal_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_16x64_hp_subnormal_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, bits_to_float(0x1E3CE508U)); +} + +// Case 17: f16 16x64 HP subnormal scalar/src - scalar=np.float16(1e-5) -> hex 0x00A8 +extern "C" __global__ AICORE void TDIVS_f16_16x64_hp_subnormal_scalar_src(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_16x64_hp_subnormal_scalar_src(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_16x64_hp_subnormal_scalar_src<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x00A8); +} + +// Case 18: f32 16x64 HP overflow scalar/src - scalar=np.float32(1e38) -> hex 0x7E967699 +extern "C" __global__ AICORE void TDIVS_f32_16x64_hp_overflow_scalar_src(__gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTDIVS_f32_16x64_hp_overflow_scalar_src(float *src, float *dst, void *stream) { + TDIVS_f32_16x64_hp_overflow_scalar_src<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, bits_to_float(0x7E967699U)); +} + +// Case 19: f16 16x64 HP overflow scalar/src - scalar=np.float16(65000) -> hex 0x7BEF +extern "C" __global__ AICORE void TDIVS_f16_16x64_hp_overflow_scalar_src(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); +void LaunchTDIVS_f16_16x64_hp_overflow_scalar_src(unsigned short *src, unsigned short *dst, void *stream) { + TDIVS_f16_16x64_hp_overflow_scalar_src<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x7BEF); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/main.cpp new file mode 100644 index 000000000..413cdc0f0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/main.cpp @@ -0,0 +1,170 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tdivs ST — case-table driven. +// tdivs: dst = src / scalar (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTDIVS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTDIVS_f32_256x16(float *src, float *dst, void *stream); +void LaunchTDIVS_f32_32x64_scalar_src(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_63x64_scalar_src(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_7x448_scalar_src(float *src, float *dst, void *stream); +void LaunchTDIVS_f32_256x16_scalar_src(float *src, float *dst, void *stream); +// HIGH_PRECISION mode kernels +void LaunchTDIVS_f32_32x64_hp(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_63x64_hp(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_16x64_hp_subnormal(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_16x64_hp_subnormal(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_16x64_hp_overflow(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_16x64_hp_overflow(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_32x64_hp_scalar_src(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_63x64_hp_scalar_src(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_16x64_hp_subnormal_scalar_src(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_16x64_hp_subnormal_scalar_src(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTDIVS_f32_16x64_hp_overflow_scalar_src(float *src, float *dst, void *stream); +void LaunchTDIVS_f16_16x64_hp_overflow_scalar_src(uint16_t *src, uint16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTDIVS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTDIVS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTDIVS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTDIVS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, + {"f32_32x64_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_32x64_scalar_src, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f16_63x64_scalar_src, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_7x448_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_7x448_scalar_src, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_256x16_scalar_src, 256, 16, 256, 16, sizeof(float)}, + // HIGH_PRECISION mode - src / scalar direction + {"f32_32x64_hp", (void (*)(void*,void*,void*))LaunchTDIVS_f32_32x64_hp, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64_hp", (void (*)(void*,void*,void*))LaunchTDIVS_f16_63x64_hp, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_16x64_hp_subnormal", (void (*)(void*,void*,void*))LaunchTDIVS_f32_16x64_hp_subnormal, 16, 64, 16, 64, sizeof(float)}, + {"f16_16x64_hp_subnormal", (void (*)(void*,void*,void*))LaunchTDIVS_f16_16x64_hp_subnormal, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f32_16x64_hp_overflow", (void (*)(void*,void*,void*))LaunchTDIVS_f32_16x64_hp_overflow, 16, 64, 16, 64, sizeof(float)}, + {"f16_16x64_hp_overflow", (void (*)(void*,void*,void*))LaunchTDIVS_f16_16x64_hp_overflow, 16, 64, 16, 64, sizeof(uint16_t)}, + // HIGH_PRECISION mode - scalar / src direction + {"f32_32x64_hp_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_32x64_hp_scalar_src, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64_hp_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f16_63x64_hp_scalar_src, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_16x64_hp_subnormal_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_16x64_hp_subnormal_scalar_src, 16, 64, 16, 64, sizeof(float)}, + {"f16_16x64_hp_subnormal_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f16_16x64_hp_subnormal_scalar_src, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f32_16x64_hp_overflow_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f32_16x64_hp_overflow_scalar_src, 16, 64, 16, 64, sizeof(float)}, + {"f16_16x64_hp_overflow_scalar_src", (void (*)(void*,void*,void*))LaunchTDIVS_f16_16x64_hp_overflow_scalar_src, 16, 64, 16, 64, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tdivs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tdivs/tdivs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/tdivs.pto new file mode 100644 index 000000000..0e70396b5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tdivs/tdivs.pto @@ -0,0 +1,399 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tdivs: tload(src) + tdivs(src, scalar)->dst + tstore(dst). +// vdiv only supports f16/f32 in TileLang DSL v1. +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 + func.func @TDIVS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 + func.func @TDIVS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f16) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: f32 7x448 + func.func @TDIVS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 3: f32 256x16 + func.func @TDIVS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + + // ========== scalar / src direction ========== + + // Case 4: f32 32x64 scalar/src + func.func @TDIVS_f32_32x64_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 5: f16 63x64 scalar/src + func.func @TDIVS_f16_63x64_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f16, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 6: f32 7x448 scalar/src + func.func @TDIVS_f32_7x448_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c7, %c448], strides = [%c3136, %c3136, %c3136, %c448, %c1] : !pto.tensor_view<1x1x1x7x448xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c7, %c448] : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 7: f32 256x16 scalar/src + func.func @TDIVS_f32_256x16_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c256, %c16], strides = [%c4096, %c4096, %c4096, %c16, %c1] : !pto.tensor_view<1x1x1x256x16xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c256, %c16] : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + + // ========== HIGH_PRECISION mode - src / scalar direction ========== + + // Case 8: f32 32x64 HP (precision_sensitive) + func.func @TDIVS_f32_32x64_hp(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 9: f16 63x64 HP (precision_sensitive) + func.func @TDIVS_f16_63x64_hp(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f16) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 10: f32 16x64 HP subnormal + func.func @TDIVS_f32_16x64_hp_subnormal(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 11: f16 16x64 HP subnormal + func.func @TDIVS_f16_16x64_hp_subnormal(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f16) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 12: f32 16x64 HP overflow + func.func @TDIVS_f32_16x64_hp_overflow(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f32) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 13: f16 16x64 HP overflow + func.func @TDIVS_f16_16x64_hp_overflow(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%src, %scalar : !pto.tile_buf, f16) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // ========== HIGH_PRECISION mode - scalar / src direction ========== + + // Case 14: f32 32x64 HP scalar/src (precision_sensitive) + func.func @TDIVS_f32_32x64_hp_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 15: f16 63x64 HP scalar/src (precision_sensitive) + func.func @TDIVS_f16_63x64_hp_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c63, %c64], strides = [%c4032, %c4032, %c4032, %c64, %c1] : !pto.tensor_view<1x1x1x63x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c63, %c64] : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f16, !pto.tile_buf) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 16: f32 16x64 HP subnormal scalar/src + func.func @TDIVS_f32_16x64_hp_subnormal_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 17: f16 16x64 HP subnormal scalar/src + func.func @TDIVS_f16_16x64_hp_subnormal_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f16, !pto.tile_buf) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 18: f32 16x64 HP overflow scalar/src + func.func @TDIVS_f32_16x64_hp_overflow_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f32, !pto.tile_buf) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 19: f16 16x64 HP overflow scalar/src + func.func @TDIVS_f16_16x64_hp_overflow_scalar_src(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xf16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) outs(%src : !pto.tile_buf) + pto.tdivs ins(%scalar, %src : f16, !pto.tile_buf) outs(%dst : !pto.tile_buf) {precision_mode = #pto} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/texp/CMakeLists.txt new file mode 100644 index 000000000..6ce5def10 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(texp) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/texp/cases.py new file mode 100644 index 000000000..d303bf358 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/cases.py @@ -0,0 +1,77 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for texp ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + - high_precision: bool — when True, restricts input range to test ExpPrecisionImpl. + Uses subnormal threshold (0x007FFFFF for f32, 0x03FF for f16). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "high_precision": False, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + "high_precision": False, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "high_precision": False, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + "high_precision": False, + }, + { + "name": "f32_64x64_hp1", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-7, + "high_precision": True, + }, + { + "name": "f16_64x64_hp2", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-7, + "high_precision": True, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/texp/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/texp/gen_data.py new file mode 100644 index 000000000..13103c495 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/gen_data.py @@ -0,0 +1,42 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import struct +import math +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + high_precision = case["high_precision"] + + if high_precision: + hex_threshold = '007FFFFF' + bound_val = struct.unpack('!f', bytes.fromhex(hex_threshold))[0] + max_val = math.log(bound_val) + min_val = max_val * 2 + input = np.random.uniform(min_val, max_val, size=shape).astype(dtype) + else: + input = np.random.randn(*shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.exp(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} high_precision={high_precision}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texp/launch.cpp new file mode 100644 index 000000000..59d955313 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TEXP_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTEXP_f32_16x64(void *a, void *b, void *stream) { + TEXP_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TEXP_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTEXP_f32_32x32(void *a, void *b, void *stream) { + TEXP_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TEXP_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTEXP_f16_16x64(void *a, void *b, void *stream) { + TEXP_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TEXP_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTEXP_f16_32x32(void *a, void *b, void *stream) { + TEXP_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: f32 64x64 hp1 +extern "C" __global__ AICORE void TEXP_f32_64x64_hp1(__gm__ float *a, __gm__ float *b); + +void LaunchTEXP_f32_64x64_hp1(void *a, void *b, void *stream) { + TEXP_f32_64x64_hp1<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 5: f16 64x64 hp2 +extern "C" __global__ AICORE void TEXP_f16_64x64_hp2(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTEXP_f16_64x64_hp2(void *a, void *b, void *stream) { + TEXP_f16_64x64_hp2<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texp/main.cpp new file mode 100644 index 000000000..7483e2551 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang texp ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTEXP_f32_16x64(void *a, void *b, void *stream); +void LaunchTEXP_f32_32x32(void *a, void *b, void *stream); +void LaunchTEXP_f16_16x64(void *a, void *b, void *stream); +void LaunchTEXP_f16_32x32(void *a, void *b, void *stream); +void LaunchTEXP_f32_64x64_hp1(void *a, void *b, void *stream); +void LaunchTEXP_f16_64x64_hp2(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTEXP_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTEXP_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTEXP_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTEXP_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"f32_64x64_hp1", LaunchTEXP_f32_64x64_hp1, 64, 64, 64, 64, sizeof(float)}, + {"f16_64x64_hp2", LaunchTEXP_f16_64x64_hp2, 64, 64, 64, 64, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./texp [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texp/texp.pto b/test/tilelang_st/npu/a5/src/st/testcase/texp/texp.pto new file mode 100644 index 000000000..982a6dec5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texp/texp.pto @@ -0,0 +1,264 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.texp: tload(a) + texp(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TEXP_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TEXP_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TEXP_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TEXP_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: f32 64x64 hp1 (4096 elements) + func.func @TEXP_f32_64x64_hp1(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f16 64x64 hp2 (4096 elements) + func.func @TEXP_f16_64x64_hp2(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%a : !pto.tile_buf) + + pto.texp ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/texpands/CMakeLists.txt new file mode 100644 index 000000000..3b48410cc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(texpands) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/texpands/cases.py new file mode 100644 index 000000000..3beb7daef --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/cases.py @@ -0,0 +1,124 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for texpands ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - scalar: the scalar value to broadcast to the tile. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # ========== float32 cases ========== + # Full valid shape cases + { + "name": "f32_16x64_scalar5", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "scalar": 5.0, + "eps": 1e-6, + }, + { + "name": "f32_32x32_scalar3", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "scalar": 3.0, + "eps": 1e-6, + }, + { + "name": "f32_64x64_scalar2", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 2.0, + "eps": 1e-6, + }, + # Partial valid shape cases + { + "name": "f32_16x64_partial", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (12, 48), + "scalar": 7.0, + "eps": 1e-6, + }, + { + "name": "f32_64x64_valid_60x60", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (60, 60), + "scalar": 42.0, + "eps": 1e-6, + }, + + # ========== int32 cases ========== + { + "name": "i32_64x64_scalar100", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 100, + "eps": 0, # exact match for integers + }, + { + "name": "i32_64x64_valid_60x60", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (60, 60), + "scalar": 99, + "eps": 0, + }, + + # ========== half (fp16) cases ========== + { + "name": "f16_64x64_scalar1_5", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 1.5, + "eps": 1e-3, # fp16 has lower precision + }, + { + "name": "f16_2x4096_valid_1x3600", + "dtype": np.float16, + "shape": (2, 4096), + "valid_shape": (1, 3600), + "scalar": 2.5, + "eps": 1e-3, + }, + + # ========== int16 cases ========== + { + "name": "i16_64x64_scalar50", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "scalar": 50, + "eps": 0, + }, + { + "name": "i16_20x512_valid_16x200", + "dtype": np.int16, + "shape": (20, 512), + "valid_shape": (16, 200), + "scalar": 25, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/texpands/compare.py new file mode 100644 index 000000000..db0cdf826 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/compare.py @@ -0,0 +1,78 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for texpands test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + eps = case["eps"] + + vr, vc = valid_shape + + # Load golden and output + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden[:vr, :vc], output[:vr, :vc]): + diff = golden[:vr, :vc] - output[:vr, :vc] + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden[:vr, :vc].astype(np.float64, copy=False) + o = output[:vr, :vc].astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/texpands/gen_data.py new file mode 100644 index 000000000..b2dd3cac2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/gen_data.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for texpands test cases.""" + +import os +import numpy as np + +from cases import CASES + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + scalar = case["scalar"] + + # Convert scalar to the correct dtype + scalar_val = dtype(scalar) + + # Generate golden: fill valid_shape region with scalar value + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = scalar_val + + save_case_data(case["name"], {"golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} scalar={scalar} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texpands/launch.cpp new file mode 100644 index 000000000..35ffa4f83 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/launch.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== float32 kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_f32_16x64_scalar5(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_32x32_scalar3(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_64x64_scalar2(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_16x64_partial(__gm__ float *dst); +extern "C" __global__ AICORE void TEXPANDS_f32_64x64_valid_60x60(__gm__ float *dst); + +void LaunchTEXPANDS_f32_16x64_scalar5(float *dst, void *stream) { + TEXPANDS_f32_16x64_scalar5<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_32x32_scalar3(float *dst, void *stream) { + TEXPANDS_f32_32x32_scalar3<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_64x64_scalar2(float *dst, void *stream) { + TEXPANDS_f32_64x64_scalar2<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_16x64_partial(float *dst, void *stream) { + TEXPANDS_f32_16x64_partial<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +void LaunchTEXPANDS_f32_64x64_valid_60x60(float *dst, void *stream) { + TEXPANDS_f32_64x64_valid_60x60<<<1, nullptr, stream>>>((__gm__ float *)dst); +} + +// ========== int32 kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_i32_64x64_scalar100(__gm__ int32_t *dst); +extern "C" __global__ AICORE void TEXPANDS_i32_64x64_valid_60x60(__gm__ int32_t *dst); + +void LaunchTEXPANDS_i32_64x64_scalar100(int32_t *dst, void *stream) { + TEXPANDS_i32_64x64_scalar100<<<1, nullptr, stream>>>((__gm__ int32_t *)dst); +} + +void LaunchTEXPANDS_i32_64x64_valid_60x60(int32_t *dst, void *stream) { + TEXPANDS_i32_64x64_valid_60x60<<<1, nullptr, stream>>>((__gm__ int32_t *)dst); +} + +// ========== half (fp16) kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_f16_64x64_scalar1_5(__gm__ uint16_t *dst); +extern "C" __global__ AICORE void TEXPANDS_f16_2x4096_valid_1x3600(__gm__ uint16_t *dst); + +void LaunchTEXPANDS_f16_64x64_scalar1_5(uint16_t *dst, void *stream) { + TEXPANDS_f16_64x64_scalar1_5<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst); +} + +void LaunchTEXPANDS_f16_2x4096_valid_1x3600(uint16_t *dst, void *stream) { + TEXPANDS_f16_2x4096_valid_1x3600<<<1, nullptr, stream>>>((__gm__ uint16_t *)dst); +} + +// ========== int16 kernels ========== + +extern "C" __global__ AICORE void TEXPANDS_i16_64x64_scalar50(__gm__ int16_t *dst); +extern "C" __global__ AICORE void TEXPANDS_i16_20x512_valid_16x200(__gm__ int16_t *dst); + +void LaunchTEXPANDS_i16_64x64_scalar50(int16_t *dst, void *stream) { + TEXPANDS_i16_64x64_scalar50<<<1, nullptr, stream>>>((__gm__ int16_t *)dst); +} + +void LaunchTEXPANDS_i16_20x512_valid_16x200(int16_t *dst, void *stream) { + TEXPANDS_i16_20x512_valid_16x200<<<1, nullptr, stream>>>((__gm__ int16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/texpands/main.cpp new file mode 100644 index 000000000..20179e763 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/main.cpp @@ -0,0 +1,171 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang texpands ST — case-table driven. +// Each case launches a different kernel variant, writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTEXPANDS_f32_16x64_scalar5(float *dst, void *stream); +void LaunchTEXPANDS_f32_32x32_scalar3(float *dst, void *stream); +void LaunchTEXPANDS_f32_64x64_scalar2(float *dst, void *stream); +void LaunchTEXPANDS_f32_16x64_partial(float *dst, void *stream); +void LaunchTEXPANDS_f32_64x64_valid_60x60(float *dst, void *stream); +void LaunchTEXPANDS_i32_64x64_scalar100(int32_t *dst, void *stream); +void LaunchTEXPANDS_i32_64x64_valid_60x60(int32_t *dst, void *stream); +void LaunchTEXPANDS_f16_64x64_scalar1_5(uint16_t *dst, void *stream); +void LaunchTEXPANDS_f16_2x4096_valid_1x3600(uint16_t *dst, void *stream); +void LaunchTEXPANDS_i16_64x64_scalar50(int16_t *dst, void *stream); +void LaunchTEXPANDS_i16_20x512_valid_16x200(int16_t *dst, void *stream); + +enum class DataType { F32, I32, F16, I16 }; + +struct TestCase { + const char *name; + DataType dtype; + void (*launch)(void *, void *); // Generic launch function pointer + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +// Helper to wrap type-specific launch functions +template +void wrapLaunch(void *dst, void *stream, void (*fn)(T *, void *)) { + fn((T *)dst, stream); +} + +static const TestCase kCases[] = { + // ========== float32 cases ========== + {"f32_16x64_scalar5", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_16x64_scalar5); }, + 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32_scalar3", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_32x32_scalar3); }, + 32, 32, 32, 32, sizeof(float)}, + {"f32_64x64_scalar2", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_64x64_scalar2); }, + 64, 64, 64, 64, sizeof(float)}, + {"f32_16x64_partial", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_16x64_partial); }, + 16, 64, 12, 48, sizeof(float)}, + {"f32_64x64_valid_60x60", DataType::F32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f32_64x64_valid_60x60); }, + 64, 64, 60, 60, sizeof(float)}, + + // ========== int32 cases ========== + {"i32_64x64_scalar100", DataType::I32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i32_64x64_scalar100); }, + 64, 64, 64, 64, sizeof(int32_t)}, + {"i32_64x64_valid_60x60", DataType::I32, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i32_64x64_valid_60x60); }, + 64, 64, 60, 60, sizeof(int32_t)}, + + // ========== half (fp16) cases ========== + {"f16_64x64_scalar1_5", DataType::F16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f16_64x64_scalar1_5); }, + 64, 64, 64, 64, sizeof(uint16_t)}, + {"f16_2x4096_valid_1x3600", DataType::F16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_f16_2x4096_valid_1x3600); }, + 2, 4096, 1, 3600, sizeof(uint16_t)}, + + // ========== int16 cases ========== + {"i16_64x64_scalar50", DataType::I16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i16_64x64_scalar50); }, + 64, 64, 64, 64, sizeof(int16_t)}, + {"i16_20x512_valid_16x200", DataType::I16, + [](void *dst, void *stream) { wrapLaunch(dst, stream, LaunchTEXPANDS_i16_20x512_valid_16x200); }, + 20, 512, 16, 200, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *dstHost = nullptr; + void *dstDevice = nullptr; + + aclrtMallocHost(&dstHost, fileSize); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + // Launch kernel (scalar is hardcoded in .pto) + tc.launch(dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./texpands [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/texpands/texpands.pto b/test/tilelang_st/npu/a5/src/st/testcase/texpands/texpands.pto new file mode 100644 index 000000000..9595a3faf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/texpands/texpands.pto @@ -0,0 +1,354 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.texpands: broadcast a scalar to a tile. +// Multiple cases with different shapes, data types, and scalar values. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // ========== float32 cases ========== + + // Case: f32 16x64, scalar=5.0 (full valid shape) + func.func @TEXPANDS_f32_16x64_scalar5(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %scalar = arith.constant 5.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case: f32 32x32, scalar=3.0 (full valid shape) + func.func @TEXPANDS_f32_32x32_scalar3(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %scalar = arith.constant 3.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case: f32 64x64, scalar=2.0 (full valid shape) + func.func @TEXPANDS_f32_64x64_scalar2(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 2.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f32 16x64, scalar=7.0 (partial valid shape: 12x48) + func.func @TEXPANDS_f32_16x64_partial(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %scalar = arith.constant 7.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c12, %c48] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x12x48xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x12x48xf32>) + return + } + + // Case: f32 64x64, valid 60x60, scalar=42.0 + func.func @TEXPANDS_f32_64x64_valid_60x60(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 42.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + return + } + + // ========== int32 cases ========== + + // Case: i32 64x64, scalar=100 (full valid shape) + func.func @TEXPANDS_i32_64x64_scalar100(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 100 : i32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case: i32 64x64, valid 60x60, scalar=99 + func.func @TEXPANDS_i32_64x64_valid_60x60(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 99 : i32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + return + } + + // ========== half (fp16) cases ========== + + // Case: f16 64x64, scalar=1.5 (full valid shape) + func.func @TEXPANDS_f16_64x64_scalar1_5(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 1.5 : f16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case: f16 2x4096, valid 1x3600, scalar=2.5 (wide column shape) + func.func @TEXPANDS_f16_2x4096_valid_1x3600(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index // 2*4096 + %scalar = arith.constant 2.5 : f16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c4096], + strides = [%c8192, %c8192, %c8192, %c4096, %c1] + : !pto.tensor_view<1x1x1x2x4096xf16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x2x4096xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : f16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + return + } + + // ========== int16 cases ========== + + // Case: i16 64x64, scalar=50 (full valid shape) + func.func @TEXPANDS_i16_64x64_scalar50(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + %scalar = arith.constant 50 : i16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case: i16 20x512, valid 16x200, scalar=25 + func.func @TEXPANDS_i16_20x512_valid_16x200(%dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c20 = arith.constant 20 : index + %c200 = arith.constant 200 : index + %c512 = arith.constant 512 : index + %c3200 = arith.constant 3200 : index // 16*200 + %c10240 = arith.constant 10240 : index // 20*512 + %scalar = arith.constant 25 : i16 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c20, %c512], + strides = [%c10240, %c10240, %c10240, %c512, %c1] + : !pto.tensor_view<1x1x1x20x512xi16> + + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x20x512xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.texpands ins(%scalar : i16) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/CMakeLists.txt new file mode 100644 index 000000000..0bffcc7fd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfillpad) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/cases.py new file mode 100644 index 000000000..4507ef6be --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/cases.py @@ -0,0 +1,193 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfillpad ST test cases. + +Matches C++ reference test cases exactly (Cases 1-13). + +PadValue semantics: + - Max: +inf for float, MAX for integers + - Min: -inf for float, MIN for integers + - Null: no fill (keep original value) + - Custom(-1.0f): -1.0f for float, -1 for integers + +Each case defines: + - name: case identifier (must match main.cpp kCases[] and launch.cpp) + - dtype: numpy dtype + - shape: (rows, cols) — dst tile physical dimensions + - valid_shape: (valid_rows, valid_cols) — dst valid region (output size) + - src_shape: (rows, cols) — src tile physical dimensions (optional, default=dst) + - src_valid_shape: (valid_rows, valid_cols) — src valid region (optional, default=dst_valid) + - load_padval: PadValue for TLOAD (fill invalid columns in src tile) + - fill_padval: PadValue for TFILLPAD (fill expansion region in dst) + - eps: tolerance for numpy.allclose +""" + +import numpy as np + +# PadValue enum values matching C++ definition +PADVAL_MAX = "Max" # +inf for float, MAX for integers +PADVAL_MIN = "Min" # -inf for float, MIN for integers +PADVAL_NULL = "Null" # no fill (keep original value, treated as 0 in golden) +PADVAL_ZERO = "Zero" # zero fill +PADVAL_NEG1 = "Neg1" # -1.0f for float, -1 for integers (Custom) + +CASES = [ + # ========== Case 1: float, 128x127 -> 128x128, PadMax ========== + # C++: runTFILLPAD + + { + "name": "f32_128x128_pad_128x127", + "dtype": np.float32, + "shape": (128, 128), # dst tile physical + "valid_shape": (128, 128), # dst valid (output size) + "src_shape": (128, 127), # src tile physical (127 cols, < dst 128) + "src_valid_shape": (128, 127), # src valid = full src + "load_padval": PADVAL_MAX, # TLOAD: fill col 127 with +inf + "fill_padval": PADVAL_MAX, # TFILLPAD: no expansion needed + "eps": 1e-6, + }, + + # ========== Case 2: float, 128x127 -> 128x160, PadMax ========== + # C++: runTFILLPAD + + { + "name": "f32_128x160_pad_128x127", + "dtype": np.float32, + "shape": (128, 160), # dst tile physical + "valid_shape": (128, 160), # dst valid (output size) + "src_shape": (128, 127), # src tile physical + "src_valid_shape": (128, 127), # src valid + "load_padval": PADVAL_MAX, # TLOAD: fill col 127 with +inf + "fill_padval": PADVAL_MAX, # TFILLPAD: fill cols 128-159 with +inf + "eps": 1e-6, + }, + + # ========== Case 3: float, 128x127 -> 128x160, LoadPad=Min, FillPad=Max ========== + # C++: runTFILLPAD + + { + "name": "f32_128x160_pad_128x127_v2", + "dtype": np.float32, + "shape": (128, 160), # dst tile physical + "valid_shape": (128, 160), # dst valid (output size) + "src_shape": (128, 127), # src tile physical + "src_valid_shape": (128, 127), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill col 127 with -inf + "fill_padval": PADVAL_MAX, # TFILLPAD: fill cols 128-159 with +inf + "eps": 1e-6, + }, + + # ========== Case 4: float, 260x7 -> 260x16, PadMin/Max ========== + # C++: runTFILLPAD + + { + "name": "f32_260x16_pad_260x7", + "dtype": np.float32, + "shape": (260, 16), # dst tile physical + "valid_shape": (260, 16), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-15 with -inf (32B aligned tile) + "fill_padval": PADVAL_MAX, # TFILLPAD: no expansion needed + "eps": 1e-6, + }, + + # ========== Case 6: uint16, 260x7 -> 260x32, PadMin/Max ========== + # C++: runTFILLPAD + + { + "name": "u16_260x32_pad_260x7", + "dtype": np.uint16, + "shape": (260, 32), # dst tile physical + "valid_shape": (260, 32), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-31 with MIN (uint16 0) + "fill_padval": PADVAL_MAX, # TFILLPAD: fill cols 8-31 with MAX (uint16 65535) + "eps": 0, + }, + + # ========== Case 7: int8, 260x7 -> 260x64, PadMin/Max ========== + # C++: runTFILLPAD + + { + "name": "s8_260x64_pad_260x7", + "dtype": np.int8, + "shape": (260, 64), # dst tile physical + "valid_shape": (260, 64), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-63 with MIN (int8 -128) + "fill_padval": PADVAL_MAX, # TFILLPAD: no expansion needed + "eps": 0, + }, + + # ========== Case 10: int16, 260x7 -> 260x32, PadMin/Min ========== + # C++: runTFILLPAD + + { + "name": "s16_260x32_pad_260x7", + "dtype": np.int16, + "shape": (260, 32), # dst tile physical + "valid_shape": (260, 32), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-31 with MIN (int16 -32768) + "fill_padval": PADVAL_MIN, # TFILLPAD: no expansion needed + "eps": 0, + }, + + # ========== Case 11: int32, 260x7 -> 260x32, PadMin/Min ========== + # C++: runTFILLPAD + + { + "name": "s32_260x32_pad_260x7", + "dtype": np.int32, + "shape": (260, 32), # dst tile physical + "valid_shape": (260, 32), # dst valid (output size) + "src_shape": (260, 7), # src tile physical + "src_valid_shape": (260, 7), # src valid + "load_padval": PADVAL_MIN, # TLOAD: fill cols 8-31 with MIN (int32 -2147483648) + "fill_padval": PADVAL_MIN, # TFILLPAD: no expansion needed + "eps": 0, + }, + + # ========== Case 12: float, 128x64 -> 128x128, LoadPad=Null, FillPad=Neg1 ========== + # C++: runTFILLPAD + + { + "name": "f32_128x128_pad_128x64_neg1", + "dtype": np.float32, + "shape": (128, 128), # dst tile physical + "valid_shape": (128, 128), # dst valid = full dst (output size) + "src_shape": (128, 64), # src tile physical (64 cols) + "src_valid_shape": (128, 64), # src valid = full src + "load_padval": PADVAL_NULL, # TLOAD: no fill (src cols 64 aligned to 32B) + "fill_padval": PADVAL_NEG1, # TFILLPAD: fill cols 64-127 with -1.0f + "eps": 1e-6, + }, + + # ========== Case 13: float, 128x127 -> 128x160, LoadPad=Neg1, FillPad=Neg1 ========== + # C++: runTFILLPAD + + { + "name": "f32_128x160_pad_128x127_neg1", + "dtype": np.float32, + "shape": (128, 160), # dst tile physical + "valid_shape": (128, 160), # dst valid = full dst (output size) - CHANGED! + "src_shape": (128, 127), # src tile physical (127 cols) + "src_valid_shape": (128, 127), # src valid = full src + "load_padval": PADVAL_NEG1, # TLOAD: fill col 127 with -1.0f (127 not 32B aligned) + "fill_padval": PADVAL_NEG1, # TFILLPAD: fill cols 128-159 with -1.0f + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/compare.py new file mode 100644 index 000000000..1a023b000 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/compare.py @@ -0,0 +1,81 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for tfillpad test cases. + +For tfillpad: + - Input: full tile shape (rows x cols) + - Output: only valid region (valid_rows x valid_cols) + - Golden: valid region only +""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + valid_shape = case["valid_shape"] + eps = case["eps"] + + # Load golden and output (both stored with valid_shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(valid_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(valid_shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden, output): + diff = golden - output + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/gen_data.py new file mode 100644 index 000000000..80b1a15a1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/gen_data.py @@ -0,0 +1,117 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for tfillpad test cases. + +TFILLPAD semantics: + 1. Copy src.valid_shape data to dst + 2. Fill cols from src.valid_cols to dst.cols with FillPadVal + 3. Fill rows from src.rows to dst.rows with FillPadVal + +Note: LoadPadVal is used by TLOAD only, TFILLPAD uses FillPadVal for expansion. +""" + +import os +import numpy as np +import struct + +from cases import CASES, PADVAL_MAX, PADVAL_MIN, PADVAL_NULL, PADVAL_ZERO, PADVAL_NEG1 + + +# FLT_MAX and -FLT_MAX (matching DSL PadValue.MAX/MIN) +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + +_FLT_MAX = _float32_from_bits(0x7F7FFFFF) # ~3.4028235e+38 +_FLT_MIN = _float32_from_bits(0xFF7FFFFF) # ~-3.4028235e+38 + + +def get_pad_value(dtype, padval_name): + """Get the actual pad value for a dtype based on PadValue enum. + + Matches DSL PadValue.materialize_scalar behavior: + - MAX: FLT_MAX for float (not inf), max for integers + - MIN: -FLT_MAX for float (not -inf), min for integers + - NEG1: -1.0 for float, -1 for integers + - NULL/ZERO: 0 + """ + if padval_name == PADVAL_MAX: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MAX) + else: + return np.iinfo(dtype).max + elif padval_name == PADVAL_MIN: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MIN) + else: + return np.iinfo(dtype).min + elif padval_name == PADVAL_NEG1: + if np.issubdtype(dtype, np.floating): + return np.float32(-1.0) + else: + return dtype(-1) + else: # PADVAL_NULL or PADVAL_ZERO + return dtype(0) + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_shape = case["shape"] + dst_valid = case["valid_shape"] + src_shape = case.get("src_shape", dst_shape) + src_valid = case.get("src_valid_shape", dst_valid) + fill_padval = case.get("fill_padval", PADVAL_ZERO) + + # Input: generated with src_shape (matching C++ input size) + src_vr, src_vc = src_valid + input_data = np.zeros(src_shape, dtype=dtype) + input_data[:src_vr, :src_vc] = np.random.randint(1, 10, size=(src_vr, src_vc)).astype(dtype) + + # Golden: generated with dst_valid (output size) + dst_vr, dst_vc = dst_valid + golden = np.zeros(dst_valid, dtype=dtype) + + # Step 1: Copy src valid data to dst + copy_vr = min(src_vr, dst_vr) + copy_vc = min(src_vc, dst_vc) + golden[:copy_vr, :copy_vc] = input_data[:copy_vr, :copy_vc] + + # Step 2: TFILLPAD fills cols from src_valid_cols to dst_cols with FillPadVal + # (NOT LoadPadVal! TFILLPAD uses FillPadVal for expansion) + if dst_vc > src_vc: + fill_val = get_pad_value(dtype, fill_padval) + golden[:dst_vr, src_vc:dst_vc] = fill_val + + # Step 3: TFILLPAD fills rows from src_rows to dst_rows with FillPadVal + if dst_shape[0] > src_shape[0]: + fill_val = get_pad_value(dtype, fill_padval) + expand_rows_start = src_shape[0] + expand_rows_end = dst_vr + if expand_rows_end > expand_rows_start: + golden[expand_rows_start:expand_rows_end, :dst_vc] = fill_val + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} input={src_shape} golden={dst_valid} " + f"fill_pad={fill_padval} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/launch.cpp new file mode 100644 index 000000000..9f1583dc0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/launch.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== Case 1: float, 128x128, valid=128x127 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x128_pad_128x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x128_pad_128x127(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x128_pad_128x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 2: float, 128x160, valid=128x127 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x160_pad_128x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x160_pad_128x127(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x160_pad_128x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 3: float, 128x160, valid=128x127 (different PadVal) ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x160_pad_128x127_v2(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x160_pad_128x127_v2(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x160_pad_128x127_v2<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 4: float, 260x16, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_260x16_pad_260x7(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_260x16_pad_260x7(float *src, float *dst, void *stream) { + TFILLPAD_f32_260x16_pad_260x7<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 6: uint16, 260x32, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_u16_260x32_pad_260x7(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTFILLPAD_u16_260x32_pad_260x7(uint16_t *src, uint16_t *dst, void *stream) { + TFILLPAD_u16_260x32_pad_260x7<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// ========== Case 7: int8, 260x64, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_s8_260x64_pad_260x7(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTFILLPAD_s8_260x64_pad_260x7(int8_t *src, int8_t *dst, void *stream) { + TFILLPAD_s8_260x64_pad_260x7<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} + +// ========== Case 10: int16, 260x32, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_s16_260x32_pad_260x7(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTFILLPAD_s16_260x32_pad_260x7(int16_t *src, int16_t *dst, void *stream) { + TFILLPAD_s16_260x32_pad_260x7<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +// ========== Case 11: int32, 260x32, valid=260x7 ========== + +extern "C" __global__ AICORE void TFILLPAD_s32_260x32_pad_260x7(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTFILLPAD_s32_260x32_pad_260x7(int32_t *src, int32_t *dst, void *stream) { + TFILLPAD_s32_260x32_pad_260x7<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// ========== Case 12: float, src=128x64, dst=128x128, PadCustomNeg1 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x128_pad_128x64_neg1(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x128_pad_128x64_neg1(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x128_pad_128x64_neg1<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ========== Case 13: float, src=128x127, dst=128x160, PadCustomNeg1 ========== + +extern "C" __global__ AICORE void TFILLPAD_f32_128x160_pad_128x127_neg1(__gm__ float *src, __gm__ float *dst); + +void LaunchTFILLPAD_f32_128x160_pad_128x127_neg1(float *src, float *dst, void *stream) { + TFILLPAD_f32_128x160_pad_128x127_neg1<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/main.cpp new file mode 100644 index 000000000..a3b0036d0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/main.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfillpad ST (non-inplace mode). +// Matches C++ reference test cases: Cases 1, 2, 3, 4, 6, 7, 10, 11, 12, 13 +// Output size: dst valid region (dst tile physical shape for full output) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTFILLPAD_f32_128x128_pad_128x127(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_128x160_pad_128x127(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_128x160_pad_128x127_v2(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_260x16_pad_260x7(float *src, float *dst, void *stream); +void LaunchTFILLPAD_u16_260x32_pad_260x7(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTFILLPAD_s8_260x64_pad_260x7(int8_t *src, int8_t *dst, void *stream); +void LaunchTFILLPAD_s16_260x32_pad_260x7(int16_t *src, int16_t *dst, void *stream); +void LaunchTFILLPAD_s32_260x32_pad_260x7(int32_t *src, int32_t *dst, void *stream); +void LaunchTFILLPAD_f32_128x128_pad_128x64_neg1(float *src, float *dst, void *stream); +void LaunchTFILLPAD_f32_128x160_pad_128x127_neg1(float *src, float *dst, void *stream); + +enum class DataType { F32, U16, S8, S16, S32 }; + +struct TestCase { + const char *name; + DataType dtype; + void (*launch)(void *, void *, void *); + size_t rows; // dst tile rows (physical) + size_t cols; // dst tile cols (physical) + size_t validRows; // dst valid rows (output rows) + size_t validCols; // dst valid cols (output cols) - CHANGED: now = dst physical cols for full output + size_t srcRows; // src tensor rows (0 means same as rows) + size_t srcCols; // src tensor cols (0 means same as cols) + size_t elemSize; +}; + +template +void wrapLaunch(void *src, void *dst, void *stream, void (*fn)(T *, T *, void *)) { + fn((T *)src, (T *)dst, stream); +} + +static const TestCase kCases[] = { + // Case 1: float, src=128x127, dst=128x128, LoadPad=Max, FillPad=Max + // Output: 128x128 (full dst tile) + {"f32_128x128_pad_128x127", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x128_pad_128x127); }, + 128, 128, 128, 128, 128, 127, sizeof(float)}, // CHANGED: validCols=128, srcCols=127 + + // Case 2: float, src=128x127, dst=128x160, LoadPad=Max, FillPad=Max + // Output: 128x160 (full dst tile) + {"f32_128x160_pad_128x127", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x160_pad_128x127); }, + 128, 160, 128, 160, 128, 127, sizeof(float)}, // CHANGED: validCols=160, srcCols=127 + + // Case 3: float, src=128x127, dst=128x160, LoadPad=Min, FillPad=Max + // Output: 128x160 (full dst tile) + {"f32_128x160_pad_128x127_v2", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x160_pad_128x127_v2); }, + 128, 160, 128, 160, 128, 127, sizeof(float)}, // CHANGED: validCols=160, srcCols=127 + + // Case 4: float, src=260x7, dst=260x16, LoadPad=Min, FillPad=Max + // Output: 260x16 (full dst tile) + {"f32_260x16_pad_260x7", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_260x16_pad_260x7); }, + 260, 16, 260, 16, 260, 7, sizeof(float)}, // CHANGED: validCols=16, srcCols=7 + + // Case 6: uint16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Max + // Output: 260x32 (full dst tile) + {"u16_260x32_pad_260x7", DataType::U16, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_u16_260x32_pad_260x7); }, + 260, 32, 260, 32, 260, 7, sizeof(uint16_t)}, + + // Case 7: int8, src=260x7, dst=260x64, LoadPad=Min, FillPad=Max + // Output: 260x64 (full dst tile) + {"s8_260x64_pad_260x7", DataType::S8, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_s8_260x64_pad_260x7); }, + 260, 64, 260, 64, 260, 7, sizeof(int8_t)}, // CHANGED: validCols=64, srcCols=7 + + // Case 10: int16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min + // Output: 260x32 (full dst tile) + {"s16_260x32_pad_260x7", DataType::S16, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_s16_260x32_pad_260x7); }, + 260, 32, 260, 32, 260, 7, sizeof(int16_t)}, // CHANGED: validCols=32, srcCols=7 + + // Case 11: int32, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min + // Output: 260x32 (full dst tile) + {"s32_260x32_pad_260x7", DataType::S32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_s32_260x32_pad_260x7); }, + 260, 32, 260, 32, 260, 7, sizeof(int32_t)}, // CHANGED: validCols=32, srcCols=7 + + // Case 12: float, src=128x64, dst=128x128, LoadPad=Null, FillPad=Custom(-1.0f) + // Output: 128x128 (full dst tile) + {"f32_128x128_pad_128x64_neg1", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x128_pad_128x64_neg1); }, + 128, 128, 128, 128, 128, 64, sizeof(float)}, // correct: validCols=128, srcCols=64 + + // Case 13: float, src=128x127, dst=128x160, LoadPad=Custom(-1.0f), FillPad=Custom(-1.0f) + // Output: 128x160 (full dst tile) - CHANGED from 127 to 160 + {"f32_128x160_pad_128x127_neg1", DataType::F32, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_f32_128x160_pad_128x127_neg1); }, + 128, 160, 128, 160, 128, 127, sizeof(float)}, // CHANGED: validCols=160 +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcRows = (tc.srcRows > 0) ? tc.srcRows : tc.rows; + size_t srcCols = (tc.srcCols > 0) ? tc.srcCols : tc.cols; + size_t inputElemCount = srcRows * srcCols; + size_t outputElemCount = tc.validRows * tc.validCols; + size_t inputFileSize = inputElemCount * tc.elemSize; + size_t outputFileSize = outputElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, output=%zux%zu) ===\n", + tc.name, srcRows, srcCols, tc.rows, tc.cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, inputFileSize); + aclrtMallocHost(&dstHost, outputFileSize); + + aclrtMalloc(&srcDevice, inputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, outputFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), inputFileSize, srcHost, inputFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, inputFileSize, srcHost, inputFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, outputFileSize, dstDevice, outputFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, outputFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/tfillpad.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/tfillpad.pto new file mode 100644 index 000000000..8b4a3c670 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad/tfillpad.pto @@ -0,0 +1,515 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfillpad (non-inplace mode). +// Matches C++ reference test cases: Cases 1, 2, 3, 4, 6, 7, 10, 11, 12, 13 +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// +// PadValue encoding: 0=Null, 1=Zero, 2=Max, 3=Min +// Cases 12/13 use Custom(-1.0f) which cannot be encoded in PTO IR, +// template uses shape-based detection for these cases. +// +// C++ template params: shape3=src_rows, shape4=src_cols, kTRows_=dst_rows, kTCols_=dst_cols + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // ========== Case 1: float, src=128x127, dst=128x128, LoadPad=Max, FillPad=Max ========== + + func.func @TFILLPAD_f32_128x128_pad_128x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c16384 = arith.constant 16384 : index // 128*128 (dst size) + + // Src tensor_view: 128x127 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + // Dst tensor_view: 128x128 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + + // Src tile: LoadPadVal=Max (pad=2), src physical=128x128, v_col=127 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=128x128, v_col=128 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + return + } + + // ========== Case 2: float, src=128x127, dst=128x160, LoadPad=Max, FillPad=Max ========== + + func.func @TFILLPAD_f32_128x160_pad_128x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c160 = arith.constant 160 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c20480 = arith.constant 20480 : index // 128*160 (dst size) + + // Src tensor_view: 128x127 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + // Dst tensor_view: 128x160 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c160], + strides = [%c20480, %c20480, %c20480, %c160, %c1] + : !pto.tensor_view<1x1x1x128x160xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c160] + : !pto.tensor_view<1x1x1x128x160xf32> -> !pto.partition_tensor_view<1x1x1x128x160xf32> + + // Src tile: LoadPadVal=Max (pad=2), src physical=128x160, v_col=127 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=128x160, v_col=160 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x160xf32>) + return + } + + // ========== Case 3: float, src=128x127, dst=128x160, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_f32_128x160_pad_128x127_v2(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c160 = arith.constant 160 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c20480 = arith.constant 20480 : index // 128*160 (dst size) + + // Src tensor_view: 128x127 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + // Dst tensor_view: 128x160 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c160], + strides = [%c20480, %c20480, %c20480, %c160, %c1] + : !pto.tensor_view<1x1x1x128x160xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c160] + : !pto.tensor_view<1x1x1x128x160xf32> -> !pto.partition_tensor_view<1x1x1x128x160xf32> + + // Src tile: LoadPadVal=Min (pad=3), src physical=128x160, v_col=127 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=128x160, v_col=160 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x160xf32>) + return + } + + // ========== Case 4: float, src=260x7, dst=260x16, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_f32_260x16_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c16 = arith.constant 16 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c4160 = arith.constant 4160 : index // 260*16 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xf32> + // Dst tensor_view: 260x16 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c16], + strides = [%c4160, %c4160, %c4160, %c16, %c1] + : !pto.tensor_view<1x1x1x260x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xf32> -> !pto.partition_tensor_view<1x1x1x260x7xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c16] + : !pto.tensor_view<1x1x1x260x16xf32> -> !pto.partition_tensor_view<1x1x1x260x16xf32> + + // Src tile: LoadPadVal=Min (pad=3), src physical=260x16, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x16, v_col=16 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x16xf32>) + return + } + + // ========== Case 6: uint16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_u16_260x32_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xui16> + // Dst tensor_view: 260x32 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xui16> -> !pto.partition_tensor_view<1x1x1x260x7xui16> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xui16> -> !pto.partition_tensor_view<1x1x1x260x32xui16> + + // Src tile: LoadPadVal=Min (pad=3), src physical=260x32, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x32, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xui16>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x32xui16>) + return + } + + // ========== Case 7: int8, src=260x7, dst=260x64, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_s8_260x64_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c64 = arith.constant 64 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c16640 = arith.constant 16640 : index // 260*64 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xi8> + // Dst tensor_view: 260x64 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c64], + strides = [%c16640, %c16640, %c16640, %c64, %c1] + : !pto.tensor_view<1x1x1x260x64xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xi8> -> !pto.partition_tensor_view<1x1x1x260x7xi8> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c64] + : !pto.tensor_view<1x1x1x260x64xi8> -> !pto.partition_tensor_view<1x1x1x260x64xi8> + + // Src tile: LoadPadVal=Min (pad=3), src physical=260x64, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x64, v_col=64 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xi8>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x64xi8>) + return + } + + // ========== Case 10: int16, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min ========== + + func.func @TFILLPAD_s16_260x32_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xi16> + // Dst tensor_view: 260x32 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xi16> -> !pto.partition_tensor_view<1x1x1x260x7xi16> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xi16> -> !pto.partition_tensor_view<1x1x1x260x32xi16> + + // Src tile: LoadPadVal=Min (pad=3), src physical=260x32, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Min (pad=3), dst physical=260x32, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xi16>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x32xi16>) + return + } + + // ========== Case 11: int32, src=260x7, dst=260x32, LoadPad=Min, FillPad=Min ========== + + func.func @TFILLPAD_s32_260x32_pad_260x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c260 = arith.constant 260 : index + %c1820 = arith.constant 1820 : index // 260*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 260x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260, %c7], + strides = [%c1820, %c1820, %c1820, %c7, %c1] + : !pto.tensor_view<1x1x1x260x7xi32> + // Dst tensor_view: 260x32 + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c7] + : !pto.tensor_view<1x1x1x260x7xi32> -> !pto.partition_tensor_view<1x1x1x260x7xi32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xi32> -> !pto.partition_tensor_view<1x1x1x260x32xi32> + + // Src tile: LoadPadVal=Min (pad=3), src physical=260x32, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Min (pad=3), dst physical=260x32, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x7xi32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x260x32xi32>) + return + } + + // ========== Case 12: float, src=128x64, dst=128x128, LoadPad=Null, FillPad=Custom(-1.0f) ========== + // PTO IR cannot encode Custom PadValue, template uses shape-based detection: + // src.valid_cols < dst.valid_cols => fill expansion region with -1.0f + + func.func @TFILLPAD_f32_128x128_pad_128x64_neg1(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index // 128*64 (src size) + %c16384 = arith.constant 16384 : index // 128*128 (dst size) + + // Src tensor_view: 128x64 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + + // Dst output tensor_view: 128x128 (full dst valid region) + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + + // Src tile: LoadPadVal=Null (pad=0), src physical=128x128, v_col=64 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Custom(-1.0f) - detected by template via src.v_col < dst.v_col + // Use pad=1 (Zero) as placeholder (PTO IR cannot encode Custom), template detects expansion and uses -1.0f + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + return + } + + // ========== Case 13: float, src=128x127, dst=128x160, LoadPad=Custom(-1.0f), FillPad=Custom(-1.0f) ========== + // PTO IR cannot encode Custom PadValue, template uses shape-based detection + + func.func @TFILLPAD_f32_128x160_pad_128x127_neg1(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c160 = arith.constant 160 : index + %c16256 = arith.constant 16256 : index // 128*127 (src size) + %c20480 = arith.constant 20480 : index // 128*160 (dst size) + + // Src tensor_view: 128x127 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c127], + strides = [%c16256, %c16256, %c16256, %c127, %c1] + : !pto.tensor_view<1x1x1x128x127xf32> + + // Dst output tensor_view: 128x160 (full dst output) + %dst_out_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c160], + strides = [%c20480, %c20480, %c20480, %c160, %c1] + : !pto.tensor_view<1x1x1x128x160xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c127] + : !pto.tensor_view<1x1x1x128x127xf32> -> !pto.partition_tensor_view<1x1x1x128x127xf32> + %dst_out_part = pto.partition_view %dst_out_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c160] + : !pto.tensor_view<1x1x1x128x160xf32> -> !pto.partition_tensor_view<1x1x1x128x160xf32> + + // Src tile: LoadPadVal=Custom(-1.0f), src physical=128x160, v_col=127 + // Use pad=0, template will detect and fill src padding region with -1.0f + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Custom(-1.0f), dst physical=128x160, v_col=160 + // Use pad=1 (Zero) as placeholder (PTO IR cannot encode Custom), template detects expansion and uses -1.0f + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x127xf32>) + outs(%src : !pto.tile_buf) + + pto.tfillpad ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_out_part : !pto.partition_tensor_view<1x1x1x128x160xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/CMakeLists.txt new file mode 100644 index 000000000..eaf9f0308 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfillpad_expand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/cases.py new file mode 100644 index 000000000..351d7716a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/cases.py @@ -0,0 +1,74 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfillpad_expand ST test cases. + +Matches C++ reference test cases: Cases 8, 9 + +C++ expand mode parameters: + - shape3: src physical rows + - shape4: src physical cols + - kTRows_: dst physical rows + - kTCols_: dst physical cols + - expand=true: TFILLPAD_EXPAND copies src valid data, fills expansion with FillPadVal + +Case 8: runTFILLPAD +Case 9: runTFILLPAD + +Each case defines: + - name: case identifier + - dtype: numpy dtype + - shape: (rows, cols) — src tile physical dimensions (input size) + - valid_shape: (valid_rows, valid_cols) — src valid region + - dst_shape: (rows, cols) — dst tile physical dimensions + - dst_valid_shape: (valid_rows, valid_cols) — dst valid region (output size) + - load_padval: PadValue for TLOAD (fill invalid columns in src tile) + - fill_padval: PadValue for TFILLPAD_EXPAND (fill expansion region in dst) + - eps: tolerance for numpy.allclose +""" + +import numpy as np + +# PadValue enum values matching C++ definition +PADVAL_MAX = "Max" # FLT_MAX for float, MAX for integers +PADVAL_MIN = "Min" # -FLT_MAX for float, MIN for integers +PADVAL_NULL = "Null" # no fill +PADVAL_ZERO = "Zero" # zero fill +PADVAL_NEG1 = "Neg1" # -1.0f for float, -1 for integers (Custom) + +CASES = [ + # ========== Case 1: uint16, src=259x7, dst=260x32, expand, LoadPad=Min, FillPad=Max ========== + + { + "name": "u16_260x32_src_259x7", + "dtype": np.uint16, + "shape": (259, 7), # src physical (C++ shape3=259, shape4=7) + "valid_shape": (259, 7), # src valid region (actual data) + "dst_shape": (260, 32), # dst physical + "dst_valid_shape": (260, 32), # dst valid (output size) + "load_padval": PADVAL_MIN, # TLOAD: fill cols 7-31 with MIN (uint16 MIN=0) + "fill_padval": PADVAL_MAX, # TFILLPAD_EXPAND: fill expansion region with MAX (uint16 MAX=65535) + "eps": 0, + }, + + # ========== Case 2: int8, src=259x7, dst=260x64, expand, LoadPad=Min, FillPad=Max ========== + + { + "name": "s8_260x64_src_259x7", + "dtype": np.int8, + "shape": (259, 7), # src physical (C++ shape3=259, shape4=7) + "valid_shape": (259, 7), # src valid region (actual data) + "dst_shape": (260, 64), # dst physical + "dst_valid_shape": (260, 64), # dst valid (output size) + "load_padval": PADVAL_MIN, # TLOAD: fill cols 7-63 with MIN (int8 MIN=-128) + "fill_padval": PADVAL_MAX, # TFILLPAD_EXPAND: fill expansion region with MAX (127) + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/compare.py new file mode 100644 index 000000000..fdd4a1d13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/compare.py @@ -0,0 +1,75 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for tfillpad_expand test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + dst_shape = case["dst_shape"] + eps = case["eps"] + + # Load golden and output (both stored with dst_shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden, output): + diff = golden - output + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/gen_data.py new file mode 100644 index 000000000..c7b55c8a5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/gen_data.py @@ -0,0 +1,114 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for tfillpad_expand test cases. + +TFILLPAD_EXPAND semantics: + 1. Copy src.valid_shape data to dst + 2. Fill cols from src.valid_cols to dst.valid_cols with FillPadVal + 3. Fill rows from src.rows to dst.rows with FillPadVal + +Note: LoadPadVal is used by TLOAD only, TFILLPAD_EXPAND uses FillPadVal for expansion. +""" + +import os +import numpy as np +import struct + +from cases import CASES, PADVAL_MAX, PADVAL_MIN, PADVAL_NEG1, PADVAL_ZERO + + +# FLT_MAX and -FLT_MAX (matching DSL PadValue.MAX/MIN) +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + +_FLT_MAX = _float32_from_bits(0x7F7FFFFF) # ~3.4028235e+38 +_FLT_MIN = _float32_from_bits(0xFF7FFFFF) # ~-3.4028235e+38 + + +def get_pad_value(dtype, padval_name): + """Get the actual pad value for a dtype based on PadValue enum. + + Matches DSL PadValue.materialize_scalar behavior: + - MAX: FLT_MAX for float (not inf), max for integers + - MIN: -FLT_MAX for float (not -inf), min for integers + - NEG1: -1.0 for float, -1 for integers + - NULL/ZERO: 0 + """ + if padval_name == PADVAL_MAX: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MAX) + else: + return np.iinfo(dtype).max + elif padval_name == PADVAL_MIN: + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MIN) + else: + return np.iinfo(dtype).min + elif padval_name == PADVAL_NEG1: + if np.issubdtype(dtype, np.floating): + return np.float32(-1.0) + else: + return dtype(-1) + else: # PADVAL_NULL or PADVAL_ZERO + return dtype(0) + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src_shape = case["shape"] # src physical (input size, matching tensor_view) + src_valid = case["valid_shape"] # src valid region (actual data in input) + dst_shape = case["dst_shape"] # dst physical + dst_valid = case["dst_valid_shape"] # dst valid (output size) + fill_padval = case.get("fill_padval", PADVAL_ZERO) + + src_vr, src_vc = src_valid + dst_vr, dst_vc = dst_valid + + # Generate input: random values in src valid region, zeros elsewhere + # Input size = src_shape (matching tensor_view and C++ input) + input_data = np.zeros(src_shape, dtype=dtype) + input_data[:src_vr, :src_vc] = np.random.randint(1, 10, size=(src_vr, src_vc)).astype(dtype) + + # Generate golden: dst valid region (output size) + golden = np.zeros(dst_valid, dtype=dtype) + + # Step 1: Copy src valid data to dst + copy_vr = min(src_vr, dst_vr) + copy_vc = min(src_vc, dst_vc) + golden[:copy_vr, :copy_vc] = input_data[:copy_vr, :copy_vc] + + # Step 2: Fill column expansion region (cols from src_vc to dst_vc) + if dst_vc > src_vc: + fill_val = get_pad_value(dtype, fill_padval) + golden[:dst_vr, src_vc:dst_vc] = fill_val + + # Step 3: Fill row expansion region (rows from src_vr to dst_vr) + if dst_vr > src_vr: + fill_val = get_pad_value(dtype, fill_padval) + golden[src_vr:dst_vr, :dst_vc] = fill_val + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src={src_shape} valid={src_valid} -> dst={dst_shape} " + f"fill_pad={fill_padval} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/launch.cpp new file mode 100644 index 000000000..c2f6a6da0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/launch.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== uint16 kernel (C++ case 8) ========== + +extern "C" __global__ AICORE void TFILLPAD_EXPAND_u16_260x32_src_259x7(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTFILLPAD_EXPAND_u16_260x32_src_259x7(uint16_t *src, uint16_t *dst, void *stream) { + TFILLPAD_EXPAND_u16_260x32_src_259x7<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// ========== int8 kernel (C++ case 9) ========== + +extern "C" __global__ AICORE void TFILLPAD_EXPAND_s8_260x64_src_259x7(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTFILLPAD_EXPAND_s8_260x64_src_259x7(int8_t *src, int8_t *dst, void *stream) { + TFILLPAD_EXPAND_s8_260x64_src_259x7<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/main.cpp new file mode 100644 index 000000000..72a657248 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/main.cpp @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfillpad_expand ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTFILLPAD_EXPAND_u16_260x32_src_259x7(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTFILLPAD_EXPAND_s8_260x64_src_259x7(int8_t *src, int8_t *dst, void *stream); + +enum class DataType { U16, S8 }; + +struct TestCase { + const char *name; + DataType dtype; + void (*launch)(void *, void *, void *); // Generic launch function pointer + size_t srcRows; + size_t srcCols; + size_t srcValidRows; + size_t srcValidCols; + size_t dstRows; + size_t dstCols; + size_t dstValidRows; + size_t dstValidCols; + size_t elemSize; +}; + +// Helper to wrap type-specific launch functions +template +void wrapLaunch(void *src, void *dst, void *stream, void (*fn)(T *, T *, void *)) { + fn((T *)src, (T *)dst, stream); +} + +static const TestCase kCases[] = { + // ========== uint16 case (C++ case 8) ========== + {"u16_260x32_src_259x7", DataType::U16, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_EXPAND_u16_260x32_src_259x7); }, + 260, 32, 259, 7, 260, 32, 260, 32, sizeof(uint16_t)}, + + // ========== int8 case (C++ case 9) ========== + {"s8_260x64_src_259x7", DataType::S8, + [](void *src, void *dst, void *stream) { wrapLaunch(src, dst, stream, LaunchTFILLPAD_EXPAND_s8_260x64_src_259x7); }, + 260, 64, 259, 7, 260, 64, 260, 64, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcElemCount = tc.srcRows * tc.srcCols; + size_t dstElemCount = tc.dstRows * tc.dstCols; + size_t srcFileSize = srcElemCount * tc.elemSize; + size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu valid=%zux%zu -> dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.srcValidRows, tc.srcValidCols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + size_t inputFileSize = srcFileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), inputFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/tfillpad_expand.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/tfillpad_expand.pto new file mode 100644 index 000000000..6e68ba00b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_expand/tfillpad_expand.pto @@ -0,0 +1,121 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfillpad_expand: copy src to dst and fill padding. +// Matches C++ test cases: case 8, 9 +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// +// PadValue encoding: 0=Null, 1=Zero, 2=Max, 3=Min +// Case 8: uint16, LoadPad=Min(pad=3), FillPad=Max(pad=2) +// Case 9: int8, LoadPad=Min(pad=3), FillPad=Max(pad=2) +// +// C++ template params: shape3=src_rows, shape4=src_cols, kTRows_=dst_rows, kTCols_=dst_cols + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // ========== Case 8: uint16, src=259x7, dst=260x32, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_EXPAND_u16_260x32_src_259x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c32 = arith.constant 32 : index + %c259 = arith.constant 259 : index + %c260 = arith.constant 260 : index + %c1813 = arith.constant 1813 : index // 259*7 (src size) + %c8320 = arith.constant 8320 : index // 260*32 (dst size) + + // Src tensor_view: 259x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c259, %c7], + strides = [%c1813, %c1813, %c1813, %c7, %c1] + : !pto.tensor_view<1x1x1x259x7xui16> + + // Dst tensor_view: 260x32 + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c32], + strides = [%c8320, %c8320, %c8320, %c32, %c1] + : !pto.tensor_view<1x1x1x260x32xui16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c259, %c7] + : !pto.tensor_view<1x1x1x259x7xui16> -> !pto.partition_tensor_view<1x1x1x259x7xui16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c32] + : !pto.tensor_view<1x1x1x260x32xui16> -> !pto.partition_tensor_view<1x1x1x260x32xui16> + + // Src tile: LoadPadVal=Min (pad=3), src physical=260x32, v_row=259, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x32, v_row=260, v_col=32 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x259x7xui16>) + outs(%src : !pto.tile_buf) + + pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x32xui16>) + return + } + + // ========== Case 9: int8, src=259x7, dst=260x64, LoadPad=Min, FillPad=Max ========== + + func.func @TFILLPAD_EXPAND_s8_260x64_src_259x7(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c64 = arith.constant 64 : index + %c259 = arith.constant 259 : index + %c260 = arith.constant 260 : index + %c1813 = arith.constant 1813 : index // 259*7 (src size) + %c16640 = arith.constant 16640 : index // 260*64 (dst size) + + // Src tensor_view: 259x7 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c259, %c7], + strides = [%c1813, %c1813, %c1813, %c7, %c1] + : !pto.tensor_view<1x1x1x259x7xi8> + + // Dst tensor_view: 260x64 + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260, %c64], + strides = [%c16640, %c16640, %c16640, %c64, %c1] + : !pto.tensor_view<1x1x1x260x64xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c259, %c7] + : !pto.tensor_view<1x1x1x259x7xi8> -> !pto.partition_tensor_view<1x1x1x259x7xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c64] + : !pto.tensor_view<1x1x1x260x64xi8> -> !pto.partition_tensor_view<1x1x1x260x64xi8> + + // Src tile: LoadPadVal=Min (pad=3), src physical=260x64, v_row=259, v_col=7 + %src = pto.alloc_tile + : !pto.tile_buf + // Dst tile: FillPadVal=Max (pad=2), dst physical=260x64, v_row=260, v_col=64 (full output) + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x259x7xi8>) + outs(%src : !pto.tile_buf) + + pto.tfillpad_expand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x64xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/CMakeLists.txt new file mode 100644 index 000000000..9d0f6b924 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfillpad_inplace PTO_LEVEL level3) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/cases.py new file mode 100644 index 000000000..c1b1dae17 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/cases.py @@ -0,0 +1,38 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfillpad_inplace ST test cases. + +Matches C++ reference test case: Case 5 + +Each case defines: + - name: case identifier + - dtype: numpy dtype + - shape: (rows, cols) — tile dimensions (physical buffer size) + - valid_shape: (valid_rows, valid_cols) — valid region (smaller than shape) + - eps: tolerance for numpy.allclose +""" + +import numpy as np + +CASES = [ + # ========== Case: float, src_valid == dst_valid (no expansion) ========== + + { + "name": "f32_260x16_noexpand", + "dtype": np.float32, + "src_shape": (260, 16), # src physical + "src_valid": (260, 16), # src valid = dst valid (no expansion) + "dst_shape": (260, 16), # dst physical + "dst_valid": (260, 16), # dst valid = full output + "fill_padval": "Max", # FillPadVal (not used since no expansion) + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/compare.py new file mode 100644 index 000000000..a58a46a13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/compare.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare output against golden for tfillpad_inplace test cases. + +For tfillpad_inplace: + - Input: full tile shape (rows x cols) + - Output: full tile shape (rows x cols) after inplace fill + - Golden: full tile shape +""" + +import os +import sys +import numpy as np + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + dst_shape = case["dst_shape"] + eps = case["eps"] + + # Load golden and output (both stored with dst_shape) + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + # For integer types, eps=0 means exact match + # For float types, use np.allclose with eps + if eps == 0: + # Integer comparison - exact match + if not np.array_equal(golden, output): + diff = golden - output + idx = int(np.argmax(np.abs(diff))) + print(f"[ERROR] {case['name']}: Mismatch at idx={idx} (golden={golden.flat[idx]}, output={output.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + else: + # Float comparison - use allclose + # Convert to float64 for comparison (fp16 precision issues) + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + + if g.shape != o.shape: + print(f"[ERROR] {case['name']}: Shape mismatch: golden {g.shape} vs output {o.shape}") + all_passed = False + continue + + if not np.allclose(g, o, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + print(f"[ERROR] {case['name']}: Mismatch: max diff={float(abs_diff.flat[idx])} " + f"at idx={idx} (golden={g.flat[idx]}, output={o.flat[idx]})") + all_passed = False + else: + print(f"[INFO] {case['name']}: compare passed") + + if not all_passed: + sys.exit(2) + print("[INFO] all cases passed") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/gen_data.py new file mode 100644 index 000000000..2345a38a4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/gen_data.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate golden data for tfillpad_inplace test cases. + +For tfillpad_inplace: + - Only one tile, valid_shape smaller than tile shape + - Input: full tile shape (rows x cols), random values in valid region, zeros in padding + - Golden: full tile shape with valid region copied and padding filled with MAX (PadValue.Max) +""" + +import os +import numpy as np +import struct + +from cases import CASES + +# FLT_MAX for float (matching DSL PadValue.MAX) +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + +_FLT_MAX = _float32_from_bits(0x7F7FFFFF) # ~3.4028235e+38 + + +def get_pad_value(dtype, padval_name): + """Get the actual pad value for a dtype based on PadValue enum.""" + if padval_name == "Max": + if np.issubdtype(dtype, np.floating): + return np.float32(_FLT_MAX) + else: + return np.iinfo(dtype).max + elif padval_name == "Min": + if np.issubdtype(dtype, np.floating): + return np.float32(-_FLT_MAX) + else: + return np.iinfo(dtype).min + elif padval_name == "Zero": + return dtype(0) + else: + return dtype(0) + + +def setup_case_rng(case): + """Set a per-case deterministic random seed.""" + np.random.seed(hash(case["name"]) & 0xFFFFFFFF) + + +def save_case_data(case_name, data_dict): + """Create case directory and write {name}.bin for each entry.""" + os.makedirs(case_name, exist_ok=True) + for name, arr in data_dict.items(): + arr.tofile(os.path.join(case_name, f"{name}.bin")) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src_shape = case["src_shape"] + src_valid = case["src_valid"] + dst_shape = case["dst_shape"] + dst_valid = case["dst_valid"] + fill_padval = case.get("fill_padval", "Max") + + src_vr, src_vc = src_valid + dst_r, dst_c = dst_shape + dst_vr, dst_vc = dst_valid + + # Input: src valid region data (random values) + input_data = np.random.uniform(1.0, 10.0, size=(src_vr, src_vc)).astype(dtype) + + # Golden: dst full region + # Copy src.valid region to dst[:src_vr, :src_vc] + # Fill cols src_vc to dst_vc with FillPadVal + # Fill rows src_vr to dst_vr with FillPadVal (row expansion, if any) + golden = np.zeros(dst_shape, dtype=dtype) + golden[:src_vr, :src_vc] = input_data + + # Fill column padding (cols src_vc to dst_vc) + if dst_vc > src_vc: + fill_val = get_pad_value(dtype, fill_padval) + golden[:dst_vr, src_vc:dst_vc] = fill_val + + # Fill row padding (rows src_vr to dst_vr) + if dst_vr > src_vr: + fill_val = get_pad_value(dtype, fill_padval) + golden[src_vr:dst_vr, :dst_vc] = fill_val + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} " + f"src_valid={src_valid} dst_shape={dst_shape} " + f"fill_pad={fill_padval} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/launch.cpp new file mode 100644 index 000000000..39f42fc3a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/launch.cpp @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ========== Case: float, 260x16, no expansion (inplace single buffer) ========== + +extern "C" __global__ AICORE void TFILLPAD_INPLACE_f32_260x16_noexpand(__gm__ float *buf); + +void LaunchTFILLPAD_INPLACE_f32_260x16_noexpand(float *buf, float *dummy, void *stream) { + // Inplace kernel: single buffer, src == dst physically + // dummy parameter ignored, only buf is used + TFILLPAD_INPLACE_f32_260x16_noexpand<<<1, nullptr, stream>>>((__gm__ float *)buf); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/main.cpp new file mode 100644 index 000000000..34f94bf40 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/main.cpp @@ -0,0 +1,129 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfillpad_inplace ST. +// Matches C++ reference test case: Case 5 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrapper (defined in launch.cpp) +// Inplace kernel takes single buffer pointer +void LaunchTFILLPAD_INPLACE_f32_260x16_noexpand(float *buf, float *dummy, void *stream); + +enum class DataType { F32 }; + +struct TestCase { + const char *name; + DataType dtype; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // Case: float, 260x16, no expansion (inplace: single buffer) + {"f32_260x16_noexpand", DataType::F32, + 260, 16, 260, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t elemCount = tc.rows * tc.cols; + size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (%zux%zu, inplace) ===\n", + tc.name, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + // Single buffer for inplace operation + void *bufHost = nullptr; + void *bufDevice = nullptr; + + aclrtMallocHost(&bufHost, fileSize); + aclrtMalloc(&bufDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + // Load input data into the single buffer + if (!ReadFile((caseDir + "/input.bin").c_str(), fileSize, bufHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + // Copy input to device buffer + aclrtMemcpy(bufDevice, fileSize, bufHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + // Run inplace kernel (src == dst = bufDevice) + // Note: launch wrapper takes two args but inplace kernel uses same physical address + LaunchTFILLPAD_INPLACE_f32_260x16_noexpand((float *)bufDevice, (float *)bufDevice, stream); + + aclrtSynchronizeStream(stream); + // Copy result back (same buffer contains output) + aclrtMemcpy(bufHost, fileSize, bufDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), bufHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (bufDevice != nullptr) + aclrtFree(bufDevice); + if (bufHost != nullptr) + aclrtFreeHost(bufHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/tfillpad_inplace.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/tfillpad_inplace.pto new file mode 100644 index 000000000..f4d606df3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfillpad_inplace/tfillpad_inplace.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfillpad (inplace mode). +// Matches C++ reference test case: Case 5 +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// +// PadValue encoding: 0=Null, 1=Zero, 2=Max, 3=Min +// Case 5: float, 260x16, valid=260x7, FillPad=Max (pad=2) +// +// Note: PTOAS tstore requires dst size to match src valid_shape. +// For outputting full buffer after inplace fill, we use two tiles: +// - src tile: holds input data (valid=260x7) +// - dst tile: receives filled data (valid=260x16 for output) + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // ========== No expansion: float, 260x16 physical, src_valid == dst_valid ========== + + func.func @TFILLPAD_INPLACE_f32_260x16_noexpand(%tile_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c260 = arith.constant 260 : index + %c4160 = arith.constant 4160 : index // 260*16 (full tile size) + + // Input tensor_view: 260x16 + %src_view = pto.make_tensor_view %tile_ptr, + shape = [%c1, %c1, %c1, %c260, %c16], + strides = [%c4160, %c4160, %c4160, %c16, %c1] + : !pto.tensor_view<1x1x1x260x16xf32> + + // Output tensor_view: 260x16 (same as input) + %dst_view = pto.make_tensor_view %tile_ptr, + shape = [%c1, %c1, %c1, %c260, %c16], + strides = [%c4160, %c4160, %c4160, %c16, %c1] + : !pto.tensor_view<1x1x1x260x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c16] + : !pto.tensor_view<1x1x1x260x16xf32> -> !pto.partition_tensor_view<1x1x1x260x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260, %c16] + : !pto.tensor_view<1x1x1x260x16xf32> -> !pto.partition_tensor_view<1x1x1x260x16xf32> + + // Single tile buffer in UB space at address 0 + // src_valid = dst_valid = 260x16, so no expansion needed + %tile_buf = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + // Load full tile (260x16) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x16xf32>) + outs(%tile_buf : !pto.tile_buf) + + // tfillpad_inplace: src_valid == dst_valid, no expansion + pto.tfillpad_inplace ins(%tile_buf : !pto.tile_buf) + outs(%tile_buf : !pto.tile_buf) + + // Store full tile + pto.tstore ins(%tile_buf : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x16xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt new file mode 100644 index 000000000..d0810e4ed --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfmod) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py new file mode 100644 index 000000000..91ebe01d8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfmod ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py new file mode 100644 index 000000000..b82ea0d15 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.fmod(input1[:vr, :vc], input2[:vr, :vc]) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp new file mode 100644 index 000000000..85899bfb6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TFMOD_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTFMOD_f32_16x64(float *a, float *b, float *c, void *stream) { + TFMOD_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TFMOD_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTFMOD_f32_32x32(float *a, float *b, float *c, void *stream) { + TFMOD_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp new file mode 100644 index 000000000..3bf7f97af --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tadd ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTFMOD_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTFMOD_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTFMOD_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTFMOD_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tfmod [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto new file mode 100644 index 000000000..646a27bdf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmod/tfmod.pto @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tadd: tload(a) + tload(b) + tadd(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TFMOD_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TFMOD_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tfmod ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmods/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/CMakeLists.txt new file mode 100644 index 000000000..0d47eae66 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tfmods) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmods/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/cases.py new file mode 100644 index 000000000..e28343478 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/cases.py @@ -0,0 +1,56 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tfmods ST test cases. + +tfmods: floating-point modulo, dst = src - trunc(src/scalar) * scalar +Only f32 and f16 types. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmods/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/compare.py new file mode 100644 index 000000000..18835ae9f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/compare.py @@ -0,0 +1,56 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmods/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/gen_data.py new file mode 100644 index 000000000..8e72da579 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/gen_data.py @@ -0,0 +1,43 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for floating-point modulo (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = np.fmod(input1[:vr, :vc], scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmods/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/launch.cpp new file mode 100644 index 000000000..5bf003319 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for floating-point modulo (must match gen_data.py SCALAR) +static constexpr float TFMODS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TFMODS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTFMODS_f32_32x64(float *src, float *dst, void *stream) { + TFMODS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TFMODS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TFMODS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTFMODS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TFMODS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: f32 7x448 +extern "C" __global__ AICORE void TFMODS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTFMODS_f32_7x448(float *src, float *dst, void *stream) { + TFMODS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TFMODS_SCALAR_F32); +} + +// Case 3: f32 256x16 +extern "C" __global__ AICORE void TFMODS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTFMODS_f32_256x16(float *src, float *dst, void *stream) { + TFMODS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TFMODS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmods/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/main.cpp new file mode 100644 index 000000000..dff832985 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tfmods ST — case-table driven. +// tfmods: dst = src - trunc(src/scalar) * scalar (floating-point modulo). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTFMODS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTFMODS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTFMODS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTFMODS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTFMODS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTFMODS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTFMODS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTFMODS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tfmods [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tfmods/tfmods.pto b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/tfmods.pto new file mode 100644 index 000000000..ee923cec9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tfmods/tfmods.pto @@ -0,0 +1,177 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tfmods: tload(src) + tfmods(src, scalar)->dst + tstore(dst). +// Floating-point modulo: dst = src - trunc(src/scalar) * scalar. +// Only f32 and f16 types. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 (2048 elements) + func.func @TFMODS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tfmods ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TFMODS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tfmods ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: f32 7x448 (3136 elements) + func.func @TFMODS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tfmods ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 3: f32 256x16 (4096 elements) + func.func @TFMODS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tfmods ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tload/CMakeLists.txt new file mode 100644 index 000000000..a4ef685a5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tload) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py new file mode 100644 index 000000000..a8a77ba80 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/cases.py @@ -0,0 +1,123 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np + +CASES = [ + { + "name": "nd_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "dn_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "nz_f32_128x128", + "dtype": np.float32, + "shape": (128, 128), + "valid_shape": (128, 128), + "eps": 1e-6, + }, + { + "name": "nd_pad_zero_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 63), + "eps": 1e-6, + "golden_fill": 0.0, + }, + { + "name": "dn_pad_max_f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (15, 64), + "eps": 1e-6, + "golden_fill": np.finfo(np.float32).max, + }, + { + "name": "nz_pad_min_f32_128x128", + "dtype": np.float32, + "shape": (128, 128), + "valid_shape": (64, 128), + "eps": 1e-6, + "golden_fill": np.finfo(np.float32).min, + }, +] + + +def build_expected_output(case, input_arr): + shape = case["shape"] + vr, vc = case["valid_shape"] + dtype = case["dtype"] + + if "golden_fill" in case: + golden = np.full(shape, case["golden_fill"], dtype=dtype) + else: + golden = np.empty(shape, dtype=dtype) + + if case["name"].startswith("dn_pad_"): + flat_in = np.asarray(input_arr, dtype=dtype).reshape(-1) + flat_golden = golden.reshape(-1) + physical_rows = shape[0] + for col in range(vc): + start = physical_rows * col + flat_golden[start : start + vr] = flat_in[start : start + vr] + return golden + + if case["name"].startswith("nz_pad_"): + flat_in = np.asarray(input_arr, dtype=dtype).reshape(-1) + flat_golden = golden.reshape(-1) + block_rows = 8 + block_size = block_rows * shape[1] + num_blocks = shape[0] // block_rows + valid_rows_per_block = vr // num_blocks + for block in range(num_blocks): + base = block * block_size + valid_elems = valid_rows_per_block * shape[1] + flat_golden[base : base + valid_elems] = flat_in[base : base + valid_elems] + return golden + + if "golden_fill" in case: + golden[:vr, :vc] = input_arr[:vr, :vc] + return golden + + return np.asarray(input_arr, dtype=dtype).copy() + + +def select_compared_region(case, arr): + vr, vc = case["valid_shape"] + + if case["name"].startswith("dn_pad_"): + flat = np.asarray(arr).reshape(-1) + physical_rows = case["shape"][0] + pieces = [flat[physical_rows * col : physical_rows * col + vr] for col in range(vc)] + return np.concatenate(pieces) if pieces else flat[:0] + + if case["name"].startswith("nz_pad_"): + flat = np.asarray(arr).reshape(-1) + shape = case["shape"] + block_rows = 8 + block_size = block_rows * shape[1] + num_blocks = shape[0] // block_rows + valid_rows_per_block = vr // num_blocks + pieces = [] + for block in range(num_blocks): + base = block * block_size + valid_elems = valid_rows_per_block * shape[1] + pieces.append(flat[base : base + valid_elems]) + return np.concatenate(pieces) if pieces else flat[:0] + + return np.asarray(arr)[:vr, :vc] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py new file mode 100644 index 000000000..6adc9c9fe --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys +import numpy as np + +from cases import CASES, select_compared_region +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp( + select_compared_region(case, golden), + select_compared_region(case, output), + case["eps"], + ) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py new file mode 100644 index 000000000..449291f26 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/gen_data.py @@ -0,0 +1,30 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import numpy as np +from cases import CASES, build_expected_output +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + input_arr = np.random.randint(1, 17, size=shape).astype(dtype) + golden = build_expected_output(case, input_arr) + + save_case_data(case["name"], {"input": input_arr, "golden": golden}) + print( + f"[INFO] gen_data: {case['name']} shape={shape} " + f"valid_shape={(vr, vc)} dtype={dtype.__name__}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp new file mode 100644 index 000000000..70453d6b9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TLOAD_ND_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_DN_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_NZ_f32_128x128(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_ND_PAD_ZERO_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_DN_PAD_MAX_f32_16x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TLOAD_NZ_PAD_MIN_f32_128x128(__gm__ float *src, __gm__ float *dst); + +void LaunchTLOAD_ND_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_ND_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_DN_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_DN_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_NZ_f32_128x128(float *src, float *dst, void *stream) { + TLOAD_NZ_f32_128x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_ND_PAD_ZERO_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_ND_PAD_ZERO_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_DN_PAD_MAX_f32_16x64(float *src, float *dst, void *stream) { + TLOAD_DN_PAD_MAX_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +void LaunchTLOAD_NZ_PAD_MIN_f32_128x128(float *src, float *dst, void *stream) { + TLOAD_NZ_PAD_MIN_f32_128x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp new file mode 100644 index 000000000..7c0b66d14 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tload/tstore ST. +// Each case performs a GM -> Tile -> GM round trip and compare.py checks that +// output.bin matches input.bin exactly for the requested layout. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTLOAD_ND_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_DN_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_NZ_f32_128x128(float *src, float *dst, void *stream); +void LaunchTLOAD_ND_PAD_ZERO_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_DN_PAD_MAX_f32_16x64(float *src, float *dst, void *stream); +void LaunchTLOAD_NZ_PAD_MIN_f32_128x128(float *src, float *dst, void *stream); + +using LaunchFn = void (*)(float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; + size_t cols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"nd_f32_16x64", LaunchTLOAD_ND_f32_16x64, 16, 64, sizeof(float)}, + {"dn_f32_16x64", LaunchTLOAD_DN_f32_16x64, 16, 64, sizeof(float)}, + {"nz_f32_128x128", LaunchTLOAD_NZ_f32_128x128, 128, 128, sizeof(float)}, + {"nd_pad_zero_f32_16x64", LaunchTLOAD_ND_PAD_ZERO_f32_16x64, 16, 64, sizeof(float)}, + {"dn_pad_max_f32_16x64", LaunchTLOAD_DN_PAD_MAX_f32_16x64, 16, 64, sizeof(float)}, + {"nz_pad_min_f32_128x128", LaunchTLOAD_NZ_PAD_MIN_f32_128x128, 128, 128, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (%zux%zu) ===\n", tc.name, tc.rows, tc.cols); + + std::string caseDir = std::string("./") + tc.name; + size_t inputFileSize = fileSize; + + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), inputFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(srcDevice, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + bool matchedCase = (caseFilter == nullptr); + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + matchedCase = true; + int ret = RunCase(kCases[i], stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (!matchedCase) { + std::fprintf(stderr, "[ERROR] unknown case filter: %s\n", caseFilter); + std::fprintf(stderr, "[ERROR] supported cases:"); + for (size_t i = 0; i < kNumCases; ++i) { + std::fprintf(stderr, " %s", kCases[i].name); + } + std::fprintf(stderr, "\n"); + rc = 1; + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto b/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto new file mode 100644 index 000000000..8b1aeef63 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tload/tload.pto @@ -0,0 +1,228 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tload + pto.tstore round-trip coverage. +// Each kernel only performs GM -> Tile -> GM, so the testcase validates the +// DMA layout path directly for ND, DN, and NZ vector tiles. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @TLOAD_ND_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + func.func @TLOAD_DN_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x16x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x16x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + func.func @TLOAD_NZ_f32_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c16, %c1, %c128, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x128x1x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c16, %c1, %c128, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x128x1x8xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c128, %c1, %c8] + : !pto.tensor_view<16x1x128x1x8xf32> -> !pto.partition_tensor_view<16x1x128x1x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c128, %c1, %c8] + : !pto.tensor_view<16x1x128x1x8xf32> -> !pto.partition_tensor_view<16x1x128x1x8xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<16x1x128x1x8xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<16x1x128x1x8xf32>) + return + } + + func.func @TLOAD_ND_PAD_ZERO_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c63 = arith.constant 63 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c63], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x63xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c63], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x63xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x63xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c63] + : !pto.tensor_view<1x1x1x16x63xf32> -> !pto.partition_tensor_view<1x1x1x16x63xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x63xf32>) + return + } + + func.func @TLOAD_DN_PAD_MAX_f32_16x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x15x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c64], + strides = [%c1024, %c1024, %c1024, %c1, %c16] + : !pto.tensor_view<1x1x1x15x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c64] + : !pto.tensor_view<1x1x1x15x64xf32> -> !pto.partition_tensor_view<1x1x1x15x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c64] + : !pto.tensor_view<1x1x1x15x64xf32> -> !pto.partition_tensor_view<1x1x1x15x64xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x64xf32>) + return + } + + func.func @TLOAD_NZ_PAD_MIN_f32_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c16, %c1, %c64, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x64x1x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c16, %c1, %c64, %c1, %c8], + strides = [%c1024, %c1024, %c8, %c8, %c1] + : !pto.tensor_view<16x1x64x1x8xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c64, %c1, %c8] + : !pto.tensor_view<16x1x64x1x8xf32> -> !pto.partition_tensor_view<16x1x64x1x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c16, %c1, %c64, %c1, %c8] + : !pto.tensor_view<16x1x64x1x8xf32> -> !pto.partition_tensor_view<16x1x64x1x8xf32> + + %tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<16x1x64x1x8xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<16x1x64x1x8xf32>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tlog/CMakeLists.txt new file mode 100644 index 000000000..f17ca9cf8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tlog) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tlog/cases.py new file mode 100644 index 000000000..1162c5d34 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/cases.py @@ -0,0 +1,87 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tlog ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "f32_16x64_hp", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-5, + "precision_mode": "HIGH_PRECISION", + }, + { + "name": "f32_32x32_hp", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-5, + "precision_mode": "HIGH_PRECISION", + }, + { + "name": "f16_16x64_hp", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + }, + { + "name": "f16_32x32_hp", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + "precision_mode": "HIGH_PRECISION", + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tlog/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tlog/gen_data.py new file mode 100644 index 000000000..459d8fb12 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Generate positive random values for log (log requires positive inputs) + input = np.random.uniform(0.1, 10.0, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.log(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlog/launch.cpp new file mode 100644 index 000000000..5fbe4521a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/launch.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TLOG_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTLOG_f32_16x64(void *a, void *b, void *stream) { + TLOG_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TLOG_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTLOG_f32_32x32(void *a, void *b, void *stream) { + TLOG_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TLOG_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTLOG_f16_16x64(void *a, void *b, void *stream) { + TLOG_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TLOG_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTLOG_f16_32x32(void *a, void *b, void *stream) { + TLOG_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 5: f32 16x64 high precision +extern "C" __global__ AICORE void TLOG_f32_16x64_hp(__gm__ float *a, __gm__ float *b); + +void LaunchTLOG_f32_16x64_hp(void *a, void *b, void *stream) { + TLOG_f32_16x64_hp<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 6: f32 32x32 high precision +extern "C" __global__ AICORE void TLOG_f32_32x32_hp(__gm__ float *a, __gm__ float *b); + +void LaunchTLOG_f32_32x32_hp(void *a, void *b, void *stream) { + TLOG_f32_32x32_hp<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 7: f16 16x64 high precision +extern "C" __global__ AICORE void TLOG_f16_16x64_hp(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTLOG_f16_16x64_hp(void *a, void *b, void *stream) { + TLOG_f16_16x64_hp<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 8: f16 32x32 high precision +extern "C" __global__ AICORE void TLOG_f16_32x32_hp(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTLOG_f16_32x32_hp(void *a, void *b, void *stream) { + TLOG_f16_32x32_hp<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlog/main.cpp new file mode 100644 index 000000000..b09c0a948 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tlog ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTLOG_f32_16x64(void *a, void *b, void *stream); +void LaunchTLOG_f32_32x32(void *a, void *b, void *stream); +void LaunchTLOG_f16_16x64(void *a, void *b, void *stream); +void LaunchTLOG_f16_32x32(void *a, void *b, void *stream); +void LaunchTLOG_f32_16x64_hp(void *a, void *b, void *stream); +void LaunchTLOG_f32_32x32_hp(void *a, void *b, void *stream); +void LaunchTLOG_f16_16x64_hp(void *a, void *b, void *stream); +void LaunchTLOG_f16_32x32_hp(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTLOG_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTLOG_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTLOG_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTLOG_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"f32_16x64_hp", LaunchTLOG_f32_16x64_hp, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32_hp", LaunchTLOG_f32_32x32_hp, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64_hp", LaunchTLOG_f16_16x64_hp, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32_hp", LaunchTLOG_f16_32x32_hp, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tlog [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlog/tlog.pto b/test/tilelang_st/npu/a5/src/st/testcase/tlog/tlog.pto new file mode 100644 index 000000000..b2929fc29 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlog/tlog.pto @@ -0,0 +1,350 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tlog: tload(a) + tlog(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TLOG_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TLOG_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TLOG_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TLOG_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 5: f32 16x64 high precision (1024 elements) + func.func @TLOG_f32_16x64_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 6: f32 32x32 high precision (1024 elements) + func.func @TLOG_f32_32x32_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 7: f16 16x64 high precision (1024 elements) + func.func @TLOG_f16_16x64_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 8: f16 32x32 high precision (1024 elements) + func.func @TLOG_f16_32x32_hp(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tlog ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/CMakeLists.txt new file mode 100644 index 000000000..7c79f07d5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tlrelu) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/cases.py new file mode 100644 index 000000000..a2b897e5d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/cases.py @@ -0,0 +1,65 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tlrelu ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — src tile dimensions (UB allocation). + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - dst_shape: (rows, cols) — dst tile physical dimensions (UB allocation, may have padding). + - dst_valid_shape: (valid_rows, valid_cols) — dst effective region (same as valid_shape). + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64_dst128", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "dst_shape": (32, 128), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + { + "name": "f16_63x64_dst128", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "dst_shape": (63, 128), + "dst_valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_7x448_dst512", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "dst_shape": (7, 512), + "dst_valid_shape": (7, 448), + "eps": 1e-3, + }, + { + "name": "f32_256x16_dst32", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "dst_shape": (256, 32), + "dst_valid_shape": (256, 16), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/compare.py new file mode 100644 index 000000000..6af6f6d5c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/gen_data.py new file mode 100644 index 000000000..22b2b3314 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/gen_data.py @@ -0,0 +1,45 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import struct +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + dst_shape = case["dst_shape"] + valid_shape = case["valid_shape"] + + rows, cols = shape + dst_rows, dst_cols = dst_shape + vr, vc = valid_shape + + input_arr = np.random.uniform(low=-8, high=8, size=(rows, cols)).astype(dtype) + slope = np.random.uniform(low=-8, high=8, size=(1, 1)).astype(np.float32) + golden = np.zeros((dst_rows, dst_cols), dtype=dtype) + + for i in range(vr): + for j in range(vc): + if input_arr[i, j] > 0: + golden[i, j] = input_arr[i, j] + else: + golden[i, j] = dtype(input_arr[i, j] * slope[0, 0]) + + slope_arr = np.array([slope[0, 0]], dtype=np.float32) + + save_case_data(case["name"], {"input": input_arr, "slope": slope_arr, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} dst_shape={dst_shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/launch.cpp new file mode 100644 index 000000000..356bd5115 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 32x64 -> dst 32x128 (valid 32x64) +extern "C" __global__ AICORE void TLRELU_f32_32x64_dst128(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_32x64_dst128(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_32x64_dst128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} + +// Case 1: f16 63x64 -> dst 63x128 (valid 63x64) +extern "C" __global__ AICORE void TLRELU_f16_63x64_dst128(__gm__ uint16_t *src, __gm__ uint16_t *dst, float slope); + +void LaunchTLRELU_f16_63x64_dst128(uint16_t *src, uint16_t *dst, float slope, void *stream) { + TLRELU_f16_63x64_dst128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, slope); +} + +// Case 2: f32 7x448 -> dst 7x512 (valid 7x448) +extern "C" __global__ AICORE void TLRELU_f32_7x448_dst512(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_7x448_dst512(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_7x448_dst512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} + +// Case 3: f32 256x16 -> dst 256x32 (valid 256x16) +extern "C" __global__ AICORE void TLRELU_f32_256x16_dst32(__gm__ float *src, __gm__ float *dst, float slope); + +void LaunchTLRELU_f32_256x16_dst32(float *src, float *dst, float slope, void *stream) { + TLRELU_f32_256x16_dst32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, slope); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/main.cpp new file mode 100644 index 000000000..3b75edf69 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/main.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tlrelu ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTLRELU_f32_32x64_dst128(float *src, float *dst, float slope, void *stream); +void LaunchTLRELU_f16_63x64_dst128(uint16_t *src, uint16_t *dst, float slope, void *stream); +void LaunchTLRELU_f32_7x448_dst512(float *src, float *dst, float slope, void *stream); +void LaunchTLRELU_f32_256x16_dst32(float *src, float *dst, float slope, void *stream); + +using LaunchFn = void (*)(void *, void *, float, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; // src tile rows + size_t srcCols; // src tile cols + size_t dstRows; // dst tile rows (may have padding) + size_t dstCols; // dst tile cols (may have padding) + size_t validRows; // effective computation rows (<= srcRows, dstRows) + size_t validCols; // effective computation cols (<= srcCols, dstCols) + size_t elemSize; // bytes per element + bool isFp16; // true for float16 case +}; + +static const TestCase kCases[] = { + {"f32_32x64_dst128", (LaunchFn)LaunchTLRELU_f32_32x64_dst128, 32, 64, 32, 128, 32, 64, sizeof(float), false}, + {"f16_63x64_dst128", (LaunchFn)LaunchTLRELU_f16_63x64_dst128, 63, 64, 63, 128, 63, 64, sizeof(uint16_t), true}, + {"f32_7x448_dst512", (LaunchFn)LaunchTLRELU_f32_7x448_dst512, 7, 448, 7, 512, 7, 448, sizeof(float), false}, + {"f32_256x16_dst32", (LaunchFn)LaunchTLRELU_f32_256x16_dst32, 256, 16, 256, 32, 256, 16, sizeof(float), false}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.elemSize; + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + size_t actualSize = 0; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + float slope = 0.0f; + + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile(caseDir + "/input.bin", actualSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + // Read slope (4 bytes float) + if (rc == 0) { + std::ifstream slopeFile(caseDir + "/slope.bin", std::ios::binary); + if (!slopeFile) { + std::fprintf(stderr, "[ERROR] failed to open %s/slope.bin\n", caseDir.c_str()); + rc = 1; + } else { + slopeFile.read(reinterpret_cast(&slope), sizeof(float)); + slopeFile.close(); + } + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, slope, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tlrelu [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/tlrelu.pto b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/tlrelu.pto new file mode 100644 index 000000000..bc4490397 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tlrelu/tlrelu.pto @@ -0,0 +1,214 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tlrelu: tload(src) + tlrelu(src, slope)->dst + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 src 32x64 -> dst 32x128 (valid 32x64) + func.func @TLRELU_f32_32x64_dst128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + %c4096 = arith.constant 4096 : index + + // Src GM view: 1x1x1x32x64 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + // Dst GM view: shape=valid_shape (32x64), strides based on dst allocation (32x128) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + // Dst partition: sizes = valid_shape (32x64) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + // Src UB tile: 32x64, valid 32x64 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 32x64, valid 32x64 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 src 63x64 -> dst 63x128 (valid 63x64) + func.func @TLRELU_f16_63x64_dst128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c4032 = arith.constant 4032 : index + %c8064 = arith.constant 8064 : index + + // Src GM view: 1x1x1x63x64 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + // Dst GM view: shape=valid_shape (63x64), strides based on dst allocation (63x128) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c8064, %c8064, %c8064, %c128, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + // Dst partition: sizes = valid_shape (63x64) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + // Src UB tile: 63x64, valid 63x64 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 63x64, valid 63x64 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: f32 src 7x448 -> dst 7x512 (valid 7x448) + func.func @TLRELU_f32_7x448_dst512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c512 = arith.constant 512 : index + %c3136 = arith.constant 3136 : index + %c3584 = arith.constant 3584 : index + + // Src GM view: 1x1x1x7x448 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + // Dst GM view: shape=valid_shape (7x448), strides based on dst allocation (7x512) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3584, %c3584, %c3584, %c512, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + // Dst partition: sizes = valid_shape (7x448) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + // Src UB tile: 7x448, valid 7x448 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 7x448, valid 7x448 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 3: f32 src 256x16 -> dst 256x32 (valid 256x16) + func.func @TLRELU_f32_256x16_dst32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %slope: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + // Src GM view: 1x1x1x256x16 + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + // Dst GM view: shape=valid_shape (256x16), strides based on dst allocation (256x32) + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c8192, %c8192, %c8192, %c32, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + // Dst partition: sizes = valid_shape (256x16) + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + // Src UB tile: 256x16, valid 256x16 + %src_tile = pto.alloc_tile + : !pto.tile_buf + // Dst UB tile: 256x16, valid 256x16 + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tlrelu ins(%src_tile, %slope : !pto.tile_buf, f32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/CMakeLists.txt new file mode 100644 index 000000000..c109a82c7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_cube_st(tmatmul PTO_LEVEL level3) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/cases.py new file mode 100644 index 000000000..cd58fc96a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/cases.py @@ -0,0 +1,25 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmatmul ST test cases.""" + +import numpy as np + + +CASES = [ + { + "name": "f16_16x16x16", + "dtype": np.float16, + "shape_a": (16, 16), + "shape_b": (16, 16), + "shape_c": (16, 16), + "eps": 1e-2, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/compare.py new file mode 100644 index 000000000..0074a8142 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/compare.py @@ -0,0 +1,45 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape_c = case["shape_c"] + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=np.float32).reshape(shape_c) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=np.float32).reshape(shape_c) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/gen_data.py new file mode 100644 index 000000000..6835cda62 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +from cases import CASES +from st_common import setup_case_rng, save_case_data + + +for case in CASES: + setup_case_rng(case) + + shape_a = case["shape_a"] + shape_b = case["shape_b"] + dtype = case["dtype"] + + lhs = np.random.uniform(-1.0, 1.0, size=shape_a).astype(dtype) + rhs = np.random.uniform(-1.0, 1.0, size=shape_b).astype(dtype) + golden = np.matmul(lhs.astype(np.float32), rhs.astype(np.float32)).astype(np.float32) + + save_case_data(case["name"], {"input1": lhs, "input2": rhs, "golden": golden}) + print( + f"[INFO] gen_data: {case['name']} " + f"lhs={shape_a} rhs={shape_b} out={case['shape_c']} dtype={dtype.__name__}" + ) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/launch.cpp new file mode 100644 index 000000000..ac4b3c48a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/launch.cpp @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TMATMUL_f16_16x16x16(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ float *c); + +void LaunchTMATMUL_f16_16x16x16(uint16_t *a, uint16_t *b, float *c, void *stream) { + TMATMUL_f16_16x16x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ float *)c); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/main.cpp new file mode 100644 index 000000000..2b1b50b0b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/main.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTMATMUL_f16_16x16x16(uint16_t *a, uint16_t *b, float *c, void *stream); + +using LaunchFn = void (*)(uint16_t *, uint16_t *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t lhsRows; + size_t lhsCols; + size_t rhsRows; + size_t rhsCols; + size_t outRows; + size_t outCols; +}; + +static const TestCase kCases[] = { + {"f16_16x16x16", LaunchTMATMUL_f16_16x16x16, 16, 16, 16, 16, 16, 16}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t lhsElems = tc.lhsRows * tc.lhsCols; + const size_t rhsElems = tc.rhsRows * tc.rhsCols; + const size_t outElems = tc.outRows * tc.outCols; + const size_t lhsBytes = lhsElems * sizeof(uint16_t); + const size_t rhsBytes = rhsElems * sizeof(uint16_t); + const size_t outBytes = outElems * sizeof(float); + size_t lhsFileSize = lhsBytes; + size_t rhsFileSize = rhsBytes; + + std::printf( + "[INFO] === case: %s (lhs=%zux%zu, rhs=%zux%zu, out=%zux%zu) ===\n", + tc.name, + tc.lhsRows, + tc.lhsCols, + tc.rhsRows, + tc.rhsCols, + tc.outRows, + tc.outCols + ); + + std::string caseDir = std::string("./") + tc.name; + + void *lhsHost = nullptr; + void *rhsHost = nullptr; + void *outHost = nullptr; + void *lhsDevice = nullptr; + void *rhsDevice = nullptr; + void *outDevice = nullptr; + + aclrtMallocHost(&lhsHost, lhsBytes); + aclrtMallocHost(&rhsHost, rhsBytes); + aclrtMallocHost(&outHost, outBytes); + + aclrtMalloc(&lhsDevice, lhsBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), lhsFileSize, lhsHost, lhsBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), rhsFileSize, rhsHost, rhsBytes)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(lhsDevice, lhsBytes, lhsHost, lhsBytes, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch( + static_cast(lhsDevice), + static_cast(rhsDevice), + static_cast(outDevice), + stream + ); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), outHost, outBytes)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (lhsDevice != nullptr) + aclrtFree(lhsDevice); + if (rhsDevice != nullptr) + aclrtFree(rhsDevice); + if (outDevice != nullptr) + aclrtFree(outDevice); + if (lhsHost != nullptr) + aclrtFreeHost(lhsHost); + if (rhsHost != nullptr) + aclrtFreeHost(rhsHost); + if (outHost != nullptr) + aclrtFreeHost(outHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/tmatmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/tmatmul.pto new file mode 100644 index 000000000..9e0512697 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmatmul/tmatmul.pto @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernel for cube matmul. +// Keep pto.tmatmul on the TileOp expansion path while bridging the boundary +// ops through pto.tile_buf_addr on the level3/manual-address path. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @TMATMUL_f16_16x16x16(%a_gm: !pto.ptr, %b_gm: !pto.ptr, %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %l1_a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l1_b_tile = pto.alloc_tile addr = %c512_i64 + : !pto.tile_buf + %l0a_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0b_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %l0c_tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + + %l1_a = pto.tile_buf_addr %l1_a_tile + : !pto.tile_buf + -> !pto.ptr + %l1_b = pto.tile_buf_addr %l1_b_tile + : !pto.tile_buf + -> !pto.ptr + %l0a = pto.tile_buf_addr %l0a_tile + : !pto.tile_buf + -> !pto.ptr + %l0b = pto.tile_buf_addr %l0b_tile + : !pto.tile_buf + -> !pto.ptr + %l0c = pto.tile_buf_addr %l0c_tile + : !pto.tile_buf + -> !pto.ptr + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.tmatmul ins(%l0a_tile, %l0b_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%l0c_tile : !pto.tile_buf) + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmax/CMakeLists.txt new file mode 100644 index 000000000..8012132e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py new file mode 100644 index 000000000..69ba77ac4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmax ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py new file mode 100644 index 000000000..0d1487e44 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.maximum(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp new file mode 100644 index 000000000..3d47d685c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TMAX_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMAX_f32_16x64(float *a, float *b, float *c, void *stream) { + TMAX_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TMAX_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMAX_f32_32x32(float *a, float *b, float *c, void *stream) { + TMAX_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp new file mode 100644 index 000000000..3dd9859a5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmax ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMAX_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTMAX_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTMAX_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTMAX_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto new file mode 100644 index 000000000..4a462fc2d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmax/tmax.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmax: tload(a) + tload(b) + tmax(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TMAX_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TMAX_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/CMakeLists.txt new file mode 100644 index 000000000..a540c4c13 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmaxs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/cases.py new file mode 100644 index 000000000..d3bab221b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/cases.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmaxs ST test cases.""" + +import numpy as np + +CASES = [ + {"name": "f32_32x64", "dtype": np.float32, "shape": (32, 64), "valid_shape": (32, 64), "eps": 1e-6}, + {"name": "f16_63x64", "dtype": np.float16, "shape": (63, 64), "valid_shape": (63, 64), "eps": 1e-3}, + {"name": "i32_31x128", "dtype": np.int32, "shape": (31, 128), "valid_shape": (31, 128), "eps": 0}, + {"name": "i16_15x192", "dtype": np.int16, "shape": (15, 192), "valid_shape": (15, 192), "eps": 0}, + {"name": "f32_7x448", "dtype": np.float32, "shape": (7, 448), "valid_shape": (7, 448), "eps": 1e-6}, + {"name": "f32_256x16", "dtype": np.float32, "shape": (256, 16), "valid_shape": (256, 16), "eps": 1e-6}, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/gen_data.py new file mode 100644 index 000000000..10520c68b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value used for element-wise maximum (matches the scalar passed in launch.cpp) +SCALAR = 5.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = np.maximum(input1[:vr, :vc], scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/launch.cpp new file mode 100644 index 000000000..793db13f1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value used for element-wise maximum (must match gen_data.py SCALAR) +static constexpr float TMAXS_SCALAR_F32 = 5.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TMAXS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMAXS_f32_32x64(float *src, float *dst, void *stream) { + TMAXS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMAXS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TMAXS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTMAXS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TMAXS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4500); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TMAXS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTMAXS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TMAXS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)5); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TMAXS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTMAXS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TMAXS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)5); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TMAXS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMAXS_f32_7x448(float *src, float *dst, void *stream) { + TMAXS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMAXS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TMAXS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMAXS_f32_256x16(float *src, float *dst, void *stream) { + TMAXS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMAXS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/main.cpp new file mode 100644 index 000000000..7104ff7ad --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmaxs ST — case-table driven. +// tmaxs: dst = max(src, scalar) (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMAXS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTMAXS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMAXS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTMAXS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTMAXS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTMAXS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTMAXS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTMAXS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTMAXS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTMAXS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTMAXS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTMAXS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmaxs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/tmaxs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/tmaxs.pto new file mode 100644 index 000000000..8adf6cb4c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmaxs/tmaxs.pto @@ -0,0 +1,256 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmaxs: tload(src) + tmaxs(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 (2048 elements) + func.func @TMAXS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TMAXS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TMAXS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TMAXS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TMAXS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TMAXS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tmaxs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt new file mode 100644 index 000000000..f811b5f04 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py new file mode 100644 index 000000000..15bbb58ea --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/cases.py @@ -0,0 +1,82 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmin ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "i32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "f16_64x64", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-3, + }, + { + "name": "f32_64x64_v60x60", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 1e-6, + }, + { + "name": "i32_64x64_v60x60", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0, + }, + { + "name": "f16_2x4096_v1x3600", + "dtype": np.float16, + "shape": (2, 4096), + "valid_shape": (1, 3600), + "eps": 1e-3, + }, + { + "name": "i16_20x512_v16x200", + "dtype": np.int16, + "shape": (20, 512), + "valid_shape": (16, 200), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py new file mode 100644 index 000000000..0c72ecbc9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.minimum(input1[:vr, :vc], input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp new file mode 100644 index 000000000..95247e512 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/launch.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 64x64 +extern "C" __global__ AICORE void TMIN_f32_64x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMIN_f32_64x64(void *a, void *b, void *c, void *stream) { + TMIN_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: i32 64x64 +extern "C" __global__ AICORE void TMIN_i32_64x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMIN_i32_64x64(void *a, void *b, void *c, void *stream) { + TMIN_i32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 2: i16 64x64 +extern "C" __global__ AICORE void TMIN_i16_64x64(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTMIN_i16_64x64(void *a, void *b, void *c, void *stream) { + TMIN_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 3: f16 64x64 +extern "C" __global__ AICORE void TMIN_f16_64x64(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMIN_f16_64x64(void *a, void *b, void *c, void *stream) { + TMIN_f16_64x64<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 4: f32 64x64 v60x60 +extern "C" __global__ AICORE void TMIN_f32_64x64_v60x60(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMIN_f32_64x64_v60x60(void *a, void *b, void *c, void *stream) { + TMIN_f32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 5: i32 64x64 v60x60 +extern "C" __global__ AICORE void TMIN_i32_64x64_v60x60(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTMIN_i32_64x64_v60x60(void *a, void *b, void *c, void *stream) { + TMIN_i32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 6: f16 2x4096 v1x3600 +extern "C" __global__ AICORE void TMIN_f16_2x4096_v1x3600(__gm__ half *a, __gm__ half *b, __gm__ half *c); + +void LaunchTMIN_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream) { + TMIN_f16_2x4096_v1x3600<<<1, nullptr, stream>>>((__gm__ half *)a, (__gm__ half *)b, (__gm__ half *)c); +} + +// Case 7: i16 20x512 v16x200 +extern "C" __global__ AICORE void TMIN_i16_20x512_v16x200(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTMIN_i16_20x512_v16x200(void *a, void *b, void *c, void *stream) { + TMIN_i16_20x512_v16x200<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp new file mode 100644 index 000000000..214a35a22 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/main.cpp @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tand ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMIN_f32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i32_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_f16_64x64(void *a, void *b, void *c, void *stream); +void LaunchTMIN_f32_64x64_v60x60(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i32_64x64_v60x60(void *a, void *b, void *c, void *stream); +void LaunchTMIN_f16_2x4096_v1x3600(void *a, void *b, void *c, void *stream); +void LaunchTMIN_i16_20x512_v16x200(void *a, void *b, void *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64", LaunchTMIN_f32_64x64, 64, 64, 64, 64, sizeof(float)}, + {"i32_64x64", LaunchTMIN_i32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, + {"i16_64x64", LaunchTMIN_i16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"f16_64x64", LaunchTMIN_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, + {"f32_64x64_v60x60", LaunchTMIN_f32_64x64_v60x60, 64, 64, 60, 60, sizeof(float)}, + {"i32_64x64_v60x60", LaunchTMIN_i32_64x64_v60x60, 64, 64, 60, 60, sizeof(int32_t)}, + {"f16_2x4096_v1x3600", LaunchTMIN_f16_2x4096_v1x3600, 2, 4096, 1, 3600, sizeof(uint16_t)}, + {"i16_20x512_v16x200", LaunchTMIN_i16_20x512_v16x200, 20, 512, 16, 200, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + (void)deviceId; + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, fileSize); + aclrtMallocHost(&src1Host, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmin/tmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmin/tmin.pto new file mode 100644 index 000000000..e7ba6d9b7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmin/tmin.pto @@ -0,0 +1,450 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmin: tload(a) + tload(b) + tmin(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 64x64 (4096 elements) + func.func @TMIN_f32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 1: i32 64x64 (4096 elements) + func.func @TMIN_i32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 2: i16 64x64 (4096 elements) + func.func @TMIN_i16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 3: f16 64x64 (4096 elements) + func.func @TMIN_f16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case 4: f32 64x64 tile with 60x60 valid region (padding with MAX for tmin) + func.func @TMIN_f32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + return + } + + // Case 5: i32 64x64 tile with 60x60 valid region (padding with MAX for tmin) + func.func @TMIN_i32_64x64_v60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c64, %c1] + : !pto.tensor_view<1x1x1x60x60xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xi32> -> !pto.partition_tensor_view<1x1x1x60x60xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x60x60xi32>) + return + } + + // Case 6: f16 2x4096 tile with 1x3600 valid region (padding with MAX for tmin) + func.func @TMIN_f16_2x4096_v1x3600(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3600 = arith.constant 3600 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c1, %c3600], + strides = [%c3600, %c3600, %c3600, %c4096, %c1] + : !pto.tensor_view<1x1x1x1x3600xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c3600] + : !pto.tensor_view<1x1x1x1x3600xf16> -> !pto.partition_tensor_view<1x1x1x1x3600xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x1x3600xf16>) + return + } + + // Case 7: i16 20x512 tile with 16x200 valid region (padding with MAX for tmin) + func.func @TMIN_i16_20x512_v16x200(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c200 = arith.constant 200 : index + %c512 = arith.constant 512 : index + %c3200 = arith.constant 3200 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c200], + strides = [%c3200, %c3200, %c3200, %c512, %c1] + : !pto.tensor_view<1x1x1x16x200xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c200] + : !pto.tensor_view<1x1x1x16x200xi16> -> !pto.partition_tensor_view<1x1x1x16x200xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + outs(%b : !pto.tile_buf) + + pto.tmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x200xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmins/CMakeLists.txt new file mode 100644 index 000000000..038d4e327 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmins) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmins/cases.py new file mode 100644 index 000000000..4526c4182 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/cases.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmins ST test cases.""" + +import numpy as np + +CASES = [ + {"name": "f32_32x64", "dtype": np.float32, "shape": (32, 64), "valid_shape": (32, 64), "eps": 1e-6}, + {"name": "f16_63x64", "dtype": np.float16, "shape": (63, 64), "valid_shape": (63, 64), "eps": 1e-3}, + {"name": "i32_31x128", "dtype": np.int32, "shape": (31, 128), "valid_shape": (31, 128), "eps": 0}, + {"name": "i16_15x192", "dtype": np.int16, "shape": (15, 192), "valid_shape": (15, 192), "eps": 0}, + {"name": "f32_7x448", "dtype": np.float32, "shape": (7, 448), "valid_shape": (7, 448), "eps": 1e-6}, + {"name": "f32_256x16", "dtype": np.float32, "shape": (256, 16), "valid_shape": (256, 16), "eps": 1e-6}, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmins/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmins/gen_data.py new file mode 100644 index 000000000..84da39655 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value used for element-wise minimum (matches the scalar passed in launch.cpp) +SCALAR = 5.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = np.minimum(input1[:vr, :vc], scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmins/launch.cpp new file mode 100644 index 000000000..65d44ffc4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value used for element-wise minimum (must match gen_data.py SCALAR) +static constexpr float TMINS_SCALAR_F32 = 5.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TMINS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMINS_f32_32x64(float *src, float *dst, void *stream) { + TMINS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMINS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TMINS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTMINS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TMINS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4500); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TMINS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTMINS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TMINS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)5); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TMINS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTMINS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TMINS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)5); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TMINS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMINS_f32_7x448(float *src, float *dst, void *stream) { + TMINS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMINS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TMINS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMINS_f32_256x16(float *src, float *dst, void *stream) { + TMINS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMINS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmins/main.cpp new file mode 100644 index 000000000..9fd09e48e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmins ST — case-table driven. +// tmins: dst = min(src, scalar) (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMINS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTMINS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMINS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTMINS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTMINS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTMINS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTMINS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTMINS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTMINS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTMINS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTMINS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTMINS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmins [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmins/tmins.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmins/tmins.pto new file mode 100644 index 000000000..bb0f6b850 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmins/tmins.pto @@ -0,0 +1,256 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmins: tload(src) + tmins(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 (2048 elements) + func.func @TMINS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TMINS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TMINS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TMINS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TMINS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TMINS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tmins ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmov/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmov/CMakeLists.txt new file mode 100644 index 000000000..019646b0c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmov/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmov) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmov/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmov/cases.py new file mode 100644 index 000000000..3c6f4916c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmov/cases.py @@ -0,0 +1,98 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmov ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +Based on pto-isa tmov_vect test cases: + - float, half, uint8 types + - shapes: 64x64, 32x32, 128x128, 128x32, 128x64 +""" + +import numpy as np + +CASES = [ + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f32_128x128", + "dtype": np.float32, + "shape": (128, 128), + "valid_shape": (128, 128), + "eps": 1e-6, + }, + { + "name": "f32_128x32", + "dtype": np.float32, + "shape": (128, 32), + "valid_shape": (128, 32), + "eps": 1e-6, + }, + { + "name": "f32_128x64", + "dtype": np.float32, + "shape": (128, 64), + "valid_shape": (128, 64), + "eps": 1e-6, + }, + { + "name": "f16_64x64", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "f16_128x128", + "dtype": np.float16, + "shape": (128, 128), + "valid_shape": (128, 128), + "eps": 1e-3, + }, + { + "name": "u8_64x64", + "dtype": np.uint8, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "u8_128x128", + "dtype": np.uint8, + "shape": (128, 128), + "valid_shape": (128, 128), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmov/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmov/compare.py new file mode 100644 index 000000000..8598a9b30 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmov/compare.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare device output against golden for tmov ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmov/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmov/gen_data.py new file mode 100644 index 000000000..e145ef579 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmov/gen_data.py @@ -0,0 +1,56 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for tmov ST test cases. + +For tmov (Vec-to-Vec data movement): + - input: source tile data + - golden: exact copy of source tile (valid_shape region) +""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Generate random input data + if dtype == np.uint8: + input_data = np.random.randint(0, 256, size=shape).astype(dtype) + else: + input_data = np.random.rand(*shape).astype(dtype) + + # Golden is exact copy of input (valid_shape region) + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = input_data[:vr, :vc].copy() + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmov/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmov/launch.cpp new file mode 100644 index 000000000..a88b8dbfc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmov/launch.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 64x64 +extern "C" __global__ AICORE void TMOV_f32_64x64(__gm__ float *src, __gm__ float *dst); + +void LaunchTMOV_f32_64x64(float *src, float *dst, void *stream) { + TMOV_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TMOV_f32_32x32(__gm__ float *src, __gm__ float *dst); + +void LaunchTMOV_f32_32x32(float *src, float *dst, void *stream) { + TMOV_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case 2: f32 128x128 +extern "C" __global__ AICORE void TMOV_f32_128x128(__gm__ float *src, __gm__ float *dst); + +void LaunchTMOV_f32_128x128(float *src, float *dst, void *stream) { + TMOV_f32_128x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case 3: f32 128x32 +extern "C" __global__ AICORE void TMOV_f32_128x32(__gm__ float *src, __gm__ float *dst); + +void LaunchTMOV_f32_128x32(float *src, float *dst, void *stream) { + TMOV_f32_128x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case 4: f32 128x64 +extern "C" __global__ AICORE void TMOV_f32_128x64(__gm__ float *src, __gm__ float *dst); + +void LaunchTMOV_f32_128x64(float *src, float *dst, void *stream) { + TMOV_f32_128x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case 5: f16 64x64 +extern "C" __global__ AICORE void TMOV_f16_64x64(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTMOV_f16_64x64(uint16_t *src, uint16_t *dst, void *stream) { + TMOV_f16_64x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// Case 6: f16 32x32 +extern "C" __global__ AICORE void TMOV_f16_32x32(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTMOV_f16_32x32(uint16_t *src, uint16_t *dst, void *stream) { + TMOV_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// Case 7: f16 128x128 +extern "C" __global__ AICORE void TMOV_f16_128x128(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTMOV_f16_128x128(uint16_t *src, uint16_t *dst, void *stream) { + TMOV_f16_128x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// Case 8: u8 64x64 +extern "C" __global__ AICORE void TMOV_u8_64x64(__gm__ uint8_t *src, __gm__ uint8_t *dst); + +void LaunchTMOV_u8_64x64(uint8_t *src, uint8_t *dst, void *stream) { + TMOV_u8_64x64<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint8_t *)dst); +} + +// Case 9: u8 128x128 +extern "C" __global__ AICORE void TMOV_u8_128x128(__gm__ uint8_t *src, __gm__ uint8_t *dst); + +void LaunchTMOV_u8_128x128(uint8_t *src, uint8_t *dst, void *stream) { + TMOV_u8_128x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)src, (__gm__ uint8_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmov/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmov/main.cpp new file mode 100644 index 000000000..bd1bfcac1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmov/main.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmov ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMOV_f32_64x64(float *src, float *dst, void *stream); +void LaunchTMOV_f32_32x32(float *src, float *dst, void *stream); +void LaunchTMOV_f32_128x128(float *src, float *dst, void *stream); +void LaunchTMOV_f32_128x32(float *src, float *dst, void *stream); +void LaunchTMOV_f32_128x64(float *src, float *dst, void *stream); +void LaunchTMOV_f16_64x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMOV_f16_32x32(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMOV_f16_128x128(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMOV_u8_64x64(uint8_t *src, uint8_t *dst, void *stream); +void LaunchTMOV_u8_128x128(uint8_t *src, uint8_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *src, void *dst, void *stream); + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64", (void(*)(void*,void*,void*))LaunchTMOV_f32_64x64, 64, 64, 64, 64, sizeof(float)}, + {"f32_32x32", (void(*)(void*,void*,void*))LaunchTMOV_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f32_128x128", (void(*)(void*,void*,void*))LaunchTMOV_f32_128x128, 128, 128, 128, 128, sizeof(float)}, + {"f32_128x32", (void(*)(void*,void*,void*))LaunchTMOV_f32_128x32, 128, 32, 128, 32, sizeof(float)}, + {"f32_128x64", (void(*)(void*,void*,void*))LaunchTMOV_f32_128x64, 128, 64, 128, 64, sizeof(float)}, + {"f16_64x64", (void(*)(void*,void*,void*))LaunchTMOV_f16_64x64, 64, 64, 64, 64, sizeof(uint16_t)}, + {"f16_32x32", (void(*)(void*,void*,void*))LaunchTMOV_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"f16_128x128", (void(*)(void*,void*,void*))LaunchTMOV_f16_128x128, 128, 128, 128, 128, sizeof(uint16_t)}, + {"u8_64x64", (void(*)(void*,void*,void*))LaunchTMOV_u8_64x64, 64, 64, 64, 64, sizeof(uint8_t)}, + {"u8_128x128", (void(*)(void*,void*,void*))LaunchTMOV_u8_128x128, 128, 128, 128, 128, sizeof(uint8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmov [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmov/tmov.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmov/tmov.pto new file mode 100644 index 000000000..8c91e55d3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmov/tmov.pto @@ -0,0 +1,407 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmov: tload(src) + tmov(src,dst) + tstore(dst). +// Multiple cases with different shapes and types in a single module. +// Based on pto-isa tmov_vect test cases. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 64x64 (4096 elements) + func.func @TMOV_f32_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TMOV_f32_32x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f32 128x128 (16384 elements) + func.func @TMOV_f32_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c16384 = arith.constant 16384 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf32> -> !pto.partition_tensor_view<1x1x1x128x128xf32> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x128xf32>) + return + } + + // Case 3: f32 128x32 (4096 elements) + func.func @TMOV_f32_128x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c32], + strides = [%c4096, %c4096, %c4096, %c32, %c1] + : !pto.tensor_view<1x1x1x128x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c32], + strides = [%c4096, %c4096, %c4096, %c32, %c1] + : !pto.tensor_view<1x1x1x128x32xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c32] + : !pto.tensor_view<1x1x1x128x32xf32> -> !pto.partition_tensor_view<1x1x1x128x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c32] + : !pto.tensor_view<1x1x1x128x32xf32> -> !pto.partition_tensor_view<1x1x1x128x32xf32> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x32xf32>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x32xf32>) + return + } + + // Case 4: f32 128x64 (8192 elements) + func.func @TMOV_f32_128x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + return + } + + // Case 5: f16 64x64 (4096 elements) + func.func @TMOV_f16_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case 6: f16 32x32 (1024 elements) + func.func @TMOV_f16_32x32(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 7: f16 128x128 (16384 elements) + func.func @TMOV_f16_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c16384 = arith.constant 16384 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf16> -> !pto.partition_tensor_view<1x1x1x128x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xf16> -> !pto.partition_tensor_view<1x1x1x128x128xf16> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x128xf16>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x128xf16>) + return + } + + // Case 8: u8 64x64 (4096 elements) + func.func @TMOV_u8_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xui8> -> !pto.partition_tensor_view<1x1x1x64x64xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xui8> -> !pto.partition_tensor_view<1x1x1x64x64xui8> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xui8>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xui8>) + return + } + + // Case 9: u8 128x128 (16384 elements) + func.func @TMOV_u8_128x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c16384 = arith.constant 16384 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xui8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c128], + strides = [%c16384, %c16384, %c16384, %c128, %c1] + : !pto.tensor_view<1x1x1x128x128xui8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xui8> -> !pto.partition_tensor_view<1x1x1x128x128xui8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c128] + : !pto.tensor_view<1x1x1x128x128xui8> -> !pto.partition_tensor_view<1x1x1x128x128xui8> + + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x128xui8>) + outs(%src : !pto.tile_buf) + + pto.tmov ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x128xui8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/CMakeLists.txt new file mode 100644 index 000000000..caf51ddc9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmrgsort) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/cases.py new file mode 100644 index 000000000..fd2308b69 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/cases.py @@ -0,0 +1,371 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmrgsort ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32, np.float16). + - format: "single" for Format1 (1-list internal block sorting), + "multi" for Format2-4 (multi-list merge sort). + - src_shape: (rows, cols) - allocated source tile dimensions. + For Format1: single input list. + For multi-list: list of shapes for each input. + - dst_shape: (rows, cols) - allocated destination tile dimensions. + - valid_shape: (valid_rows, valid_cols) - effective computation region. + - block_len: For Format1: block length in elements (must divide src_cols by 4). + - list_num: For multi-list: number of input lists (2, 3, or 4). + - src_cols: For multi-list: list of valid cols for each input list. + - topk: For multi-list: top-k output count. + - exhausted: For multi-list: whether to enable exhausted suspension. + - eps: tolerance for numpy.allclose (atol and rtol). + +tmrgsort semantics: + - Format1 (single list): Sorts 4 internal blocks of src using vmrgsort4. + Each block is sorted independently, then merged. + Output: interleaved (sorted_value, original_index) pairs. + - Format2-4 (multi-list): Merges 2-4 sorted input lists into one sorted output. + Each input list must already be sorted (in descending order). + Output: top-k sorted elements from merged lists. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # Format1: single list (internal block sorting) + # Transplanted from pto-isa case_single1: TMrgsortSingle + # Shape uses FLOAT ELEMENT count (matching pto-isa kGCols convention) + # src_cols=256 float elements = 128 (value,index) structures + # block_len=64 float elements = 32 structures/block, 4 blocks total + { + "name": "f32_single_1x256_b64", + "dtype": np.float32, + "format": "single", + "src_shape": (1, 256), # kGCols=256 float elements + "dst_shape": (1, 256), # kGCols=256 float elements + "valid_shape": (1, 256), + "block_len": 64, # float elements (=32 structures) + "eps": 1e-6, + }, + # Transplanted from pto-isa case_single2: TMrgsortSingle + # GCols=320 > TCols=256, global memory has padding, kernel uses TCols + # src_cols=320 float elements (global), valid_cols=256 float elements (tile) + # block_len=64 float elements = 32 structures/block + { + "name": "f32_single_1x320_b64", + "dtype": np.float32, + "format": "single", + "src_shape": (1, 320), # kGCols=320 float elements (global) + "dst_shape": (1, 320), # kGCols=320 float elements (global) + "valid_shape": (1, 256), # kTCols=256 (effective tile region) + "block_len": 64, # float elements (=32 structures) + "eps": 1e-6, + }, + # Transplanted from pto-isa case_single3: TMrgsortSingle + # cols=512 float elements = 256 structures + # block_len=64 float elements = 32 structures/block, 4 blocks total + { + "name": "f32_single_1x512_b64", + "dtype": np.float32, + "format": "single", + "src_shape": (1, 512), # kGCols=512 float elements + "dst_shape": (1, 512), # kGCols=512 float elements + "valid_shape": (1, 512), + "block_len": 64, # float elements (=32 structures) + "eps": 1e-6, + }, + # Transplanted from pto-isa case_single4: TMrgsortSingle + # kGCols=640 > kTCols=512, global memory has padding, kernel uses kTCols + # src_cols=640 float elements (global), valid_cols=512 float elements (tile) + # block_len=64 float elements = 32 structures/block + { + "name": "f32_single_1x640_b64", + "dtype": np.float32, + "format": "single", + "src_shape": (1, 640), # kGCols=640 float elements (global) + "dst_shape": (1, 640), # kGCols=640 float elements (global) + "valid_shape": (1, 512), # kTCols=512 (effective tile region) + "block_len": 64, # float elements (=32 structures) + "eps": 1e-6, + }, + # Transplanted from pto-isa case_single5: TMrgsortSingle + # uint16_t maps to float16 (half) in Ascend C + # TYPE_COEF=2: kGCols*2=512, kTCols*2=512, blockLen*2=128 (kernel internal) + # src_shape uses TYPE_COEF-adjusted counts: 512 f16 elements = 128 structures + # block_len=64 template units → 128 f16 elements in kernel = 32 structures/block + { + "name": "f16_single_1x256_b64", + "dtype": np.float16, + "format": "single", + "src_shape": (1, 512), # kGCols*TYPE_COEF=512 f16 elements = 128 structures + "dst_shape": (1, 512), # kGCols*TYPE_COEF=512 f16 elements + "valid_shape": (1, 512), + "block_len": 128, # block_len*TYPE_COEF=128 f16 elements = 32 structures + "eps": 1e-3, # f16 has lower precision + }, + # Transplanted from pto-isa case_single6: TMrgsortSingle + # TYPE_COEF=2: kGCols*2=640, kTCols*2=512, blockLen*2=128 (kernel internal) + # kGCols=320 > kTCols=256, global memory has padding + # src_shape uses TYPE_COEF-adjusted: 640 f16 elements (global), 512 f16 (valid) + { + "name": "f16_single_1x320_b64", + "dtype": np.float16, + "format": "single", + "src_shape": (1, 640), # kGCols*TYPE_COEF=640 f16 elements (global) + "dst_shape": (1, 640), # kGCols*TYPE_COEF=640 f16 elements (global) + "valid_shape": (1, 512), # kTCols*TYPE_COEF=512 (effective tile region) + "block_len": 128, # block_len*TYPE_COEF=128 f16 elements = 32 structures + "eps": 1e-3, + }, + # Transplanted from pto-isa case_single7: TMrgsortSingle + # TYPE_COEF=2: kGCols*2=1024, kTCols*2=1024, blockLen*2=128 (kernel internal) + # src_shape uses TYPE_COEF-adjusted: 1024 f16 elements = 256 structures + { + "name": "f16_single_1x512_b64", + "dtype": np.float16, + "format": "single", + "src_shape": (1, 1024), # kGCols*TYPE_COEF=1024 f16 elements = 256 structures + "dst_shape": (1, 1024), # kGCols*TYPE_COEF=1024 f16 elements + "valid_shape": (1, 1024), + "block_len": 128, # block_len*TYPE_COEF=128 f16 elements = 32 structures + "eps": 1e-3, + }, + # Transplanted from pto-isa case_single8: TMrgsortSingle + # TYPE_COEF=2: kGCols*2=2048, kTCols*2=2048, blockLen*2=512 (kernel internal) + # src_shape uses TYPE_COEF-adjusted: 2048 f16 elements = 512 structures + { + "name": "f16_single_1x1024_b256", + "dtype": np.float16, + "format": "single", + "src_shape": (1, 2048), # kGCols*TYPE_COEF=2048 f16 elements = 512 structures + "dst_shape": (1, 2048), # kGCols*TYPE_COEF=2048 f16 elements + "valid_shape": (1, 2048), + "block_len": 512, # block_len*TYPE_COEF=512 f16 elements = 128 structures + "eps": 1e-3, + }, + # Format2: multi-list merge (2-list merge) + { + "name": "f32_2list_b64_basic", + "dtype": np.float32, + "format": "multi", + "list_num": 2, + "src_cols": [128, 128], + "src_shape": [(1, 256), (1, 256)], + "dst_shape": (1, 256), + "valid_shape": (1, 256), + "topk": 128, + "exhausted": False, + "eps": 1e-6, + }, + { + "name": "f16_2list_b64_basic", + "dtype": np.float16, + "format": "multi", + "list_num": 2, + "src_cols": [64, 64], # 64 structures per list (match src_shape) + "src_shape": [(1, 256), (1, 256)], # 256 f16 elements = 64 structures + "dst_shape": (1, 256), + "valid_shape": (1, 256), + "topk": 64, # topk should match dst capacity + "exhausted": False, + "eps": 1e-3, + }, + # Format2: exhausted=true cases (aligned with pto-isa case_exhausted1) + # pto-isa template: kGCols_=64 (elements) → 32 structures per list + # TOPK=128 (elements) → 64 structures output + { + "name": "f32_2list_exhausted", + "dtype": np.float32, + "format": "multi", + "list_num": 2, + "src_cols": [32, 32], # 32 structures per list (64 elements / 2) + "src_shape": [(1, 64), (1, 64)], # 64 f32 elements = 32 structures + "dst_shape": (1, 128), # 128 f32 elements = 64 structures (=TOPK) + "valid_shape": (1, 128), # match dst_shape + "topk": 64, # topk in structures (=64 structures) + "exhausted": True, + "eps": 1e-6, + }, + # Format3: 3-list merge sort + { + "name": "f32_3list_b64_basic", + "dtype": np.float32, + "format": "multi", + "list_num": 3, + "src_cols": [64, 64, 64], # 64 structures per list + "src_shape": [(1, 128), (1, 128), (1, 128)], # 128 f32 elements = 64 structures each + "dst_shape": (1, 256), # 256 f32 elements = 128 structures + "valid_shape": (1, 256), + "topk": 128, # topk structures (192 available, output 128) + "exhausted": False, + "eps": 1e-6, + }, + # Format4: 4-list merge sort + { + "name": "f32_4list_b32_basic", + "dtype": np.float32, + "format": "multi", + "list_num": 4, + "src_cols": [64, 64, 64, 64], + "src_shape": [(1, 128), (1, 128), (1, 128), (1, 128)], + "dst_shape": (1, 512), + "valid_shape": (1, 512), + "topk": 256, + "exhausted": False, + "eps": 1e-6, + }, + { + "name": "f16_4list_b64_basic", + "dtype": np.float16, + "format": "multi", + "list_num": 4, + "src_cols": [64, 64, 64, 64], # 64 structures per list + "src_shape": [(1, 256), (1, 256), (1, 256), (1, 256)], # 256 f16 elements = 64 structures each + "dst_shape": (1, 1024), # 1024 f16 elements = 256 structures + "valid_shape": (1, 1024), + "topk": 256, # topk structures (256 available, output 256) + "exhausted": False, + "eps": 1e-3, + }, + # Format3 variants: non-uniform cols + { + "name": "f32_3list_non_uniform", + "dtype": np.float32, + "format": "multi", + "list_num": 3, + "src_cols": [64, 64, 32], # non-uniform: 64,64,32 structures + "src_shape": [(1, 128), (1, 128), (1, 64)], # f32 elements + "dst_shape": (1, 128), # 128 f32 elements = 64 structures + "valid_shape": (1, 128), + "topk": 64, # structures (total=160 available, output topk=64) + "exhausted": False, + "eps": 1e-6, + }, + # Format3 variants: f16 4-list basic + # tmp tile cols=512 can hold max 256 structures for f16 (512/2=256) + # src_cols in STRUCTURES, srcShape in ELEMENTS (f16: 4 elems/struct) + { + "name": "f16_4list_basic", + "dtype": np.float16, + "format": "multi", + "list_num": 4, + "src_cols": [64, 64, 64, 64], + "src_shape": [(1, 256), (1, 256), (1, 256), (1, 256)], + "dst_shape": (1, 1024), + "valid_shape": (1, 1024), + "topk": 256, + "exhausted": False, + "eps": 1e-3, + }, + # Format3 variants: f16 exhausted (aligned with pto-isa case_exhausted2) + # pto-isa template: kGCols_=256 (DataType=float sized), TOPK=768 (float sized) + # In f16 units: 256 float-sized * 4 / 2 = 512 f16 elements per input = 128 structures + # TOPK: 768 float-sized * 4 / 2 = 1536 f16 elements output = 384 structures + { + "name": "f16_3list_exhausted", + "dtype": np.float16, + "format": "multi", + "list_num": 3, + "src_cols": [128, 128, 128], # 128 structures per list (512 f16 elements) + "src_shape": [(1, 512), (1, 512), (1, 512)], # 512 f16 elements = 128 structures + "dst_shape": (1, 1536), # 1536 f16 elements = 384 structures (=TOPK) + "valid_shape": (1, 1536), + "topk": 384, # structures (=384) + "exhausted": True, + "eps": 1e-3, + }, + # Format4 variants: non-uniform cols + { + "name": "f32_4list_non_uniform", + "dtype": np.float32, + "format": "multi", + "list_num": 4, + "src_cols": [64, 64, 64, 32], # non-uniform: 64,64,64,32 structures + "src_shape": [(1, 128), (1, 128), (1, 128), (1, 64)], # f32 elements + "dst_shape": (1, 448), # 448 f32 elements = 224 structures + "valid_shape": (1, 448), + "topk": 224, # structures (total=224, output all) + "exhausted": False, + "eps": 1e-6, + }, + + # Format5: TopK (full sorting with top-k output) + # Following pto-isa case_topk1-6 + # Input: unsorted raw data (value-index interleaved) + # Output: top-k sorted elements + { + "name": "f32_topk_2048_1024", + "dtype": np.float32, + "format": "topk", + "src_shape": (1, 2048), # 2048 f32 elements = 1024 structs (input unsorted) + "dst_shape": (1, 1024), # 1024 f32 elements = 512 structs (output topk) + "valid_shape": (1, 2048), # full input cols + "topk": 512, # output structures count + "block_len": 64, # initial block length in elements + "eps": 1e-6, + }, + { + "name": "f32_topk_2048_2048", + "dtype": np.float32, + "format": "topk", + "src_shape": (1, 2048), # 2048 f32 elements = 1024 structs + "dst_shape": (1, 2048), # 2048 f32 elements = 1024 structs (output all) + "valid_shape": (1, 2048), + "topk": 1024, # output all structures + "block_len": 64, + "eps": 1e-6, + }, + { + "name": "f32_topk_1280_512", + "dtype": np.float32, + "format": "topk", + "src_shape": (1, 1280), # 1280 f32 elements = 640 structs + "dst_shape": (1, 512), # 512 f32 elements = 256 structs + "valid_shape": (1, 1280), + "topk": 256, # output 256 structures + "block_len": 64, + "eps": 1e-6, + }, + { + "name": "f16_topk_2048_1024", + "dtype": np.float16, + "format": "topk", + "src_shape": (1, 2048), # 2048 f16 elements = 512 structs + "dst_shape": (1, 1024), # 1024 f16 elements = 256 structs + "valid_shape": (1, 2048), + "topk": 256, # output 256 structures + "block_len": 64, + "eps": 1e-3, + }, + { + "name": "f16_topk_2048_2048", + "dtype": np.float16, + "format": "topk", + "src_shape": (1, 2048), # 2048 f16 elements = 512 structs + "dst_shape": (1, 2048), # output all + "valid_shape": (1, 2048), + "topk": 512, # output all structures + "block_len": 64, + "eps": 1e-3, + }, + { + "name": "f16_topk_1280_512", + "dtype": np.float16, + "format": "topk", + "src_shape": (1, 1280), # 1280 f16 elements = 320 structs + "dst_shape": (1, 512), # 512 f16 elements = 128 structs + "valid_shape": (1, 1280), + "topk": 128, # output 128 structures + "block_len": 64, + "eps": 1e-3, + } +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/compare.py new file mode 100644 index 000000000..f0e49169c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/compare.py @@ -0,0 +1,221 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np +import struct + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def read_value_index_pairs(filepath, dtype, count): + """Read interleaved (value, index) pairs from file. + + Format: value followed by index (uint32). + For f16: value (2 bytes) + padding (2 bytes) + index (4 bytes) = 8 bytes per pair. + For f32: value (4 bytes) + index (4 bytes) = 8 bytes per pair. + """ + values = [] + indices = [] + + struct_fmt = 'fI' if dtype == np.float32 else 'e2xI' + struct_size = struct.calcsize(struct_fmt) + + with open(filepath, 'rb') as f: + for _ in range(count): + data = f.read(struct_size) + if not data: + break + unpacked = struct.unpack(struct_fmt, data) + values.append(unpacked[0]) + indices.append(unpacked[1]) + + return np.array(values, dtype=dtype), np.array(indices, dtype=np.uint32) + + +def handle_output_data(golden_vals, golden_idx, output_vals, output_idx): + """Handle exhausted case: zero output values and indices where golden values are 0. + + Following pto-isa HandleOutputData logic: + - Scan from end, find first non-zero golden value + - Zero output values where golden values are 0 + + Also zero output indices where golden indices are 0 (matching gen_data.py behavior). + """ + size = len(golden_vals) + i = size - 1 + while i > 0: + if golden_vals[i] == 0.0: + output_vals[i] = 0.0 + if golden_idx[i] == 0: + output_idx[i] = 0 + i -= 1 + else: + return + + +def compare_multilist(case): + """Compare multi-list merge sort output. + + For multi-list format: + - Read input0.bin, input1.bin, etc. + - Read output.bin + - Compare top-k elements with golden.bin + """ + dtype = case["dtype"] + list_num = case["list_num"] + src_cols = case["src_cols"] + topk = case["topk"] + exhausted = case.get("exhausted", False) + + # Calculate element divisor + if dtype == np.float16: + elem_divisor = 4 + else: + elem_divisor = 2 + + # Total structures to compare + total_structures = sum(src_cols) + + # Read golden output + golden_vals, golden_indices = read_value_index_pairs( + os.path.join(case["name"], "golden.bin"), dtype, total_structures + ) + + # Read actual output + output_vals, output_indices = read_value_index_pairs( + os.path.join(case["name"], "output.bin"), dtype, total_structures + ) + + if exhausted: + handle_output_data(golden_vals, golden_indices, output_vals, output_indices) + + # Compare top-k elements (only compare the valid output) + vals_ok = result_cmp(golden_vals[:topk], output_vals[:topk], case["eps"]) + indices_ok = np.allclose(golden_indices[:topk], output_indices[:topk], atol=0, rtol=0) + + return vals_ok and indices_ok + + +def compare_topk(case): + """Compare TopK output. + + For TopK format: + - Read input0.bin (unsorted raw data) + - Read output.bin (top-k sorted data) + - Compare with golden.bin + """ + dtype = case["dtype"] + valid_shape = case["valid_shape"] + valid_rows, valid_cols = valid_shape + topk = case["topk"] + + # Get element divisor based on dtype + if dtype == np.float16: + elem_divisor = 4 + else: + elem_divisor = 2 + + # Total structures in input + total_structures = valid_cols // elem_divisor + + # Read golden output + golden_vals, golden_indices = read_value_index_pairs( + os.path.join(case["name"], "golden.bin"), dtype, total_structures + ) + + # Read actual output + output_vals, output_indices = read_value_index_pairs( + os.path.join(case["name"], "output.bin"), dtype, topk + ) + + # Compare top-k elements + vals_ok = result_cmp(golden_vals[:topk], output_vals[:topk], case["eps"]) + indices_ok = np.allclose(golden_indices[:topk], output_indices[:topk], atol=0, rtol=0) + + return vals_ok and indices_ok + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + format_type = case.get("format", "single") + + if format_type == "single": + dtype = case["dtype"] + valid_shape = case["valid_shape"] + valid_rows, valid_cols = valid_shape + block_len = case["block_len"] + + # Get element divisor based on dtype + if dtype == np.float16: + elem_divisor = 4 + else: + elem_divisor = 2 + + cols = valid_cols // elem_divisor + + golden_vals, golden_indices = read_value_index_pairs( + os.path.join(case["name"], "golden.bin"), dtype, cols + ) + output_vals, output_indices = read_value_index_pairs( + os.path.join(case["name"], "output.bin"), dtype, cols + ) + + vals_ok = result_cmp(golden_vals, output_vals, case["eps"]) + indices_ok = np.allclose(golden_indices, output_indices, atol=0, rtol=0) + + if vals_ok and indices_ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + if not vals_ok: + print(style_fail(f"[ERROR] {case['name']}: values mismatch")) + if not indices_ok: + print(style_fail(f"[ERROR] {case['name']}: indices mismatch")) + all_passed = False + + elif format_type == "multi": + ok = compare_multilist(case) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: values or indices mismatch")) + all_passed = False + + elif format_type == "topk": + ok = compare_topk(case) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: values or indices mismatch")) + all_passed = False + + else: + print(style_fail(f"[ERROR] {case['name']}: unsupported format {format_type}")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/gen_data.py new file mode 100644 index 000000000..beb7ff3ac --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/gen_data.py @@ -0,0 +1,418 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys +import struct +import ctypes + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + +BLOCK_NUM = 4 +STRUCT_SIZE = 8 # bytes per structure (value + index) + + +def find_and_zero(arr, tar): + for item in arr: + if not isinstance(item, (np.floating)): + return -1 + if not all(isinstance(x, (np.floating)) for x in arr): + raise ValueError("The input must be a list of numbers.") + if not isinstance(tar, (np.floating)): + return -1 + + n = len(arr) + for i in range(n - 1, -1, -1): + if arr[i] == tar: + for j in range(i + 1, n): + arr[j] = 0 + return i + return -1 + + +def zero_after_index(arr, i): + if i < 0 or i >= len(arr): + return + for j in range(i + 1, len(arr)): + arr[j] = 0 + + +def handle_exhausted_list(input_num, topk_sorted_output_global, topk_sorted_idx_global, last_data): + for i in range(input_num): + zero_index = find_and_zero(topk_sorted_output_global, last_data[i]) + zero_after_index(topk_sorted_idx_global, zero_index) + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +def get_elem_divisor(dtype): + """Get element divisor based on dtype. + + A structure is 8 bytes: + - f32 (4 bytes): 8 / 4 = 2 elems per struct + - f16 (2 bytes): 8 / 2 = 4 elems per struct + """ + if dtype == np.float16: + return 4 + return 2 + + +def write_value_index_pair(f, value, index, dtype): + """Write a (value, index) pair to file. + + Format: value followed by index (uint32). + For f16: value (2 bytes) + padding (2 bytes) + index (4 bytes). + For f32: value (4 bytes) + index (4 bytes). + """ + if dtype == np.float32: + packed_data = struct.pack('fI', float(value), ctypes.c_uint32(index).value) + f.write(packed_data) + elif dtype == np.float16: + # f16: directly pack value (np.float16), not float(value) + # Following pto-isa: struct.pack('e2xI', value, ...) + packed_data = struct.pack('e2xI', value, ctypes.c_uint32(index).value) + f.write(packed_data) + + +def gen_golden_single(case): + """Generate golden data for Format1 (single list internal block sorting). + + Following pto-isa gen_data.py logic exactly: + - cols = src_cols // elem_divisor (STRUCTURE count, using full src_cols not valid_cols) + - list_col = block_len // elem_divisor (STRUCTURES per block) + - block_lens = list_col * 4 (STRUCTURES per vmrgsort4 call) + - block_lens_floats = block_len * 4 (FLOATS per vmrgsort4 call) + + Process: + 1. Generate random data (cols structures) + 2. Reshape into blocks (each list_col structures) + 3. Sort each block internally -> input0.bin + 4. Reshape into groups (each block_lens structures) + 5. Globally sort each group -> golden.bin + + For cases where src_cols > valid_cols (padding), generate full src_cols with zeros padding. + """ + dtype = case["dtype"] + src_shape = _to_tuple(case["src_shape"]) + dst_shape = _to_tuple(case["dst_shape"]) + valid_shape = _to_tuple(case["valid_shape"]) + block_len = case["block_len"] + + src_rows, src_cols = src_shape + valid_rows, valid_cols = valid_shape + + # Get element divisor based on dtype (2 for f32, 4 for f16) + elem_divisor = get_elem_divisor(dtype) + + # Use FULL src_cols for file size (matching pto-isa kGCols) + cols = src_cols // elem_divisor # total structures in file + valid_structs = valid_cols // elem_divisor # valid structures for computation + list_col = block_len // elem_divisor # structures per block + block_lens = list_col * 4 # structures per vmrgsort4 call + block_lens_floats = block_len * 4 # floats per vmrgsort4 call + + repeat_times = valid_structs // block_lens # vmrgsort4 call times (use valid_structs) + + # Generate random data only for valid portion (matching pto-isa which uses kTCols for computation) + input_arr = np.random.uniform(low=0.0, high=1.0, size=(1, valid_structs)).astype(dtype) + idx_arr = np.arange(valid_structs, dtype=np.uint32) + + # Step 1: Sort each block internally + # Reshape to (total_blocks, list_col) + input_reshaped = input_arr.reshape(-1, list_col) + idx_reshaped = idx_arr.reshape(-1, list_col) + + # Sort each block descending + sorted_indices = np.argsort(-input_reshaped, kind='stable', axis=1) + sorted_input = np.take_along_axis(input_reshaped, sorted_indices, axis=1) + sorted_idx = np.take_along_axis(idx_reshaped, sorted_indices, axis=1) + + # Flatten back -> input0.bin (needs padding if cols > valid_structs) + flat_input = sorted_input.flatten() + flat_idx = sorted_idx.flatten() + + # Pad input with zeros if needed (for src_cols > valid_cols cases) + if cols > valid_structs: + pad_input = np.zeros(cols - valid_structs, dtype=dtype) + pad_idx = np.zeros(cols - valid_structs, dtype=np.uint32) + flat_input = np.concatenate((flat_input, pad_input)) + flat_idx = np.concatenate((flat_idx, pad_idx)) + + # Step 2: Generate golden (globally sort each group, using valid_structs) + # Take complete groups from valid portion + input_group = flat_input[:valid_structs // block_lens * block_lens] + idx_group = flat_idx[:valid_structs // block_lens * block_lens] + + # Reshape to (repeat_times, block_lens) + single_output_reshape = input_group.reshape(-1, block_lens) + single_idx_reshape = idx_group.reshape(-1, block_lens) + + # Globally sort each group descending + single_sorted_indices = np.argsort(-single_output_reshape, kind='stable', axis=1) + golden_values = np.take_along_axis(single_output_reshape, single_sorted_indices, axis=1).flatten() + golden_indices = np.take_along_axis(single_idx_reshape, single_sorted_indices, axis=1).flatten() + + # Handle remaining elements from valid portion + if valid_structs % block_lens != 0: + zeros_output = np.zeros(valid_structs % block_lens, dtype=golden_values.dtype) + zeros_index = np.zeros(valid_structs % block_lens, dtype=np.uint32) + golden_values = np.concatenate((golden_values, zeros_output)) + golden_indices = np.concatenate((golden_indices, zeros_index)) + + # Pad golden with zeros for full file size (cols > valid_structs) + if cols > valid_structs: + pad_output = np.zeros(cols - valid_structs, dtype=golden_values.dtype) + pad_index = np.zeros(cols - valid_structs, dtype=np.uint32) + golden_values = np.concatenate((golden_values, pad_output)) + golden_indices = np.concatenate((golden_indices, pad_index)) + + os.makedirs(case["name"], exist_ok=True) + with open(os.path.join(case["name"], "input0.bin"), 'wb') as f: + for val, idx in zip(flat_input, flat_idx): + write_value_index_pair(f, val, idx, dtype) + + with open(os.path.join(case["name"], "golden.bin"), 'wb') as f: + for val, idx in zip(golden_values, golden_indices): + write_value_index_pair(f, val, idx, dtype) + + print(f"[INFO] gen_data: {case['name']} src_cols={src_cols} valid_cols={valid_cols} " + f"cols={cols} list_col={list_col} block_lens={block_lens} repeat_times={repeat_times}") + + +def gen_golden_multilist(case): + """Generate golden data for Format2 (multi-list merge sort). + + Following pto-isa gen_data.py logic for multi-list: + 1. Generate sorted data for each input list (descending order) + 2. Concatenate all lists and globally sort (descending) + 3. Take top-k elements + 4. If exhausted=true, handle special termination logic + + Each input list is pre-sorted in descending order. + Output is top-k merged sorted elements. + """ + dtype = case["dtype"] + list_num = case["list_num"] + src_cols = case["src_cols"] # structures per list + topk = case["topk"] + exhausted = case.get("exhausted", False) + + # Calculate actual cols (in elements) per src + # Each structure = (value, index) pair = 8 bytes + # For f32: 2 elements per structure (4 bytes value + 4 bytes index) + # For f16: 4 elements per structure (2 bytes value + 2 bytes padding + 4 bytes index) + elem_divisor = get_elem_divisor(dtype) + + # Generate sorted data for each input list + output_arr_list = [] + output_idx_list = [] + last_data = [] + + total_structures = sum(src_cols) + + for i in range(list_num): + cols_i = src_cols[i] + # Generate random data for this list + input_arr = np.random.uniform(low=0.0, high=1.0, size=(1, cols_i)).astype(dtype) + idx_arr = np.arange(cols_i, dtype=np.uint32).reshape(1, cols_i) # Reshape to match input_arr + + # Sort in descending order + sorted_indices = np.argsort(-input_arr, kind='stable', axis=1) + sorted_input = np.take_along_axis(input_arr, sorted_indices, axis=1) + sorted_idx = np.take_along_axis(idx_arr, sorted_indices, axis=1) + + # Flatten + flat_input_i = sorted_input.flatten() + flat_idx_i = sorted_idx.flatten() + + output_arr_list.append(flat_input_i) + output_idx_list.append(flat_idx_i) + + # Track last element for exhausted case + if cols_i > 0: + last_data.append(flat_input_i[-1]) + else: + last_data.append(0) + + # Concatenate and globally sort (descending) + flat_input_group = np.concatenate(output_arr_list).flatten() + flat_idx_group = np.concatenate(output_idx_list).flatten() + + sorted_indices_global = np.argsort(-flat_input_group, kind='stable') + sorted_output_global = flat_input_group[sorted_indices_global] + sorted_idx_global = flat_idx_group[sorted_indices_global] + + # Take top-k + topk_sorted_output = sorted_output_global[:topk] + topk_sorted_idx = sorted_idx_global[:topk] + + # Pad zeros if needed + zeros_output = np.zeros(total_structures - topk, dtype=topk_sorted_output.dtype) + zeros_index = np.zeros(total_structures - topk, dtype=np.uint32) + topk_sorted_output_global = np.concatenate((topk_sorted_output, zeros_output)) + topk_sorted_idx_global = np.concatenate((topk_sorted_idx, zeros_index)) + + if exhausted: + handle_exhausted_list(list_num, topk_sorted_output_global, topk_sorted_idx_global, last_data) + + # Write input files (input0.bin, input1.bin, etc.) + os.makedirs(case["name"], exist_ok=True) + for i in range(list_num): + input_file = os.path.join(case["name"], f"input{i}.bin") + with open(input_file, 'wb') as f: + for val, idx in zip(output_arr_list[i], output_idx_list[i]): + write_value_index_pair(f, val, idx, dtype) + + # Write golden output file + with open(os.path.join(case["name"], "golden.bin"), 'wb') as f: + for val, idx in zip(topk_sorted_output_global, topk_sorted_idx_global): + write_value_index_pair(f, val, idx, dtype) + + print(f"[INFO] gen_data: {case['name']} list_num={list_num} " + f"src_cols={src_cols} total_structures={total_structures} topk={topk} exhausted={exhausted}") + + +def gen_golden_topk(case): + """Generate golden data for TopK (full iterative merge). + + Following pto-isa RunTMrgsortTopk logic: + 1. Generate unsorted raw data -> input0.bin + 2. Initial: sort each block internally + 3. Iterative merge loop: blockLen *= 4 + - Each iteration: Format1 merge (4 blocks -> 1) + - Copy result back for next iteration + 4. Final: take top-k from globally sorted data -> golden.bin + + This matches the full TopK template implementation. + """ + dtype = case["dtype"] + src_shape = _to_tuple(case["src_shape"]) + dst_shape = _to_tuple(case["dst_shape"]) + valid_shape = _to_tuple(case["valid_shape"]) + topk = case["topk"] # output structures count + block_len = case["block_len"] + + src_rows, src_cols = src_shape + valid_rows, valid_cols = valid_shape + + # Get element divisor based on dtype + elem_divisor = get_elem_divisor(dtype) + + # Structure units (following pto-isa) + cols = valid_cols // elem_divisor # total structures + list_col = block_len // elem_divisor # structures per block + + # Generate unsorted raw data + input_arr = np.random.uniform(low=0.0, high=1.0, size=(1, cols)).astype(dtype) + idx_arr = np.arange(cols, dtype=np.uint32) + + # Step 1: Sort each block internally (Format1 preparation) + input_reshaped = input_arr.reshape(-1, list_col) + idx_reshaped = idx_arr.reshape(-1, list_col) + + sorted_indices = np.argsort(-input_reshaped, kind='stable', axis=1) + sorted_input = np.take_along_axis(input_reshaped, sorted_indices, axis=1) + sorted_idx = np.take_along_axis(idx_reshaped, sorted_indices, axis=1) + + # Flatten -> input0.bin (block-wise sorted) + flat_input = sorted_input.flatten() + flat_idx = sorted_idx.flatten() + + # Step 2: Iterative merge (blockLen *= 4 loop) + current_data = flat_input.copy() + current_idx = flat_idx.copy() + current_block_len = list_col # structures per block + + iteration = 1 + while current_block_len * 4 <= cols: + # Format1 merge at this block length + # Merge groups of 4 blocks into 1 + block_lens = current_block_len * 4 # structures per merge group + num_groups = cols // block_lens + + # Process each group + for g in range(num_groups): + start = g * block_lens + end = start + block_lens + group_vals = current_data[start:end] + group_idx = current_idx[start:end] + + # Sort this group descending + sort_indices = np.argsort(-group_vals, kind='stable') + current_data[start:end] = group_vals[sort_indices] + current_idx[start:end] = group_idx[sort_indices] + + # Update block length for next iteration + current_block_len = current_block_len * 4 + iteration += 1 + + # Step 3: Handle tail blocks (if current_block_len < cols) + # Simplified: just globally sort the remaining data + if current_block_len < cols: + # Global sort for tail handling + sort_indices = np.argsort(-current_data, kind='stable') + current_data = current_data[sort_indices] + current_idx = current_idx[sort_indices] + + # Step 4: Take top-k + golden_values = current_data[:topk] + golden_indices = current_idx[:topk] + + # Write files + os.makedirs(case["name"], exist_ok=True) + with open(os.path.join(case["name"], "input0.bin"), 'wb') as f: + for val, idx in zip(flat_input, flat_idx): + write_value_index_pair(f, val, idx, dtype) + + # Pad zeros if needed (to match dst capacity) + dst_structures = dst_shape[1] // elem_divisor + zeros_values = np.zeros(dst_structures - topk, dtype=golden_values.dtype) + zeros_indices = np.zeros(dst_structures - topk, dtype=np.uint32) + golden_values_padded = np.concatenate((golden_values, zeros_values)) + golden_indices_padded = np.concatenate((golden_indices, zeros_indices)) + + with open(os.path.join(case["name"], "golden.bin"), 'wb') as f: + for val, idx in zip(golden_values_padded, golden_indices_padded): + write_value_index_pair(f, val, idx, dtype) + + print(f"[INFO] gen_data: {case['name']} src_cols={src_cols} valid_cols={valid_cols} " + f"cols={cols} structures topk={topk} structures block_len={block_len} " + f"iterations={iteration}") + + +def gen_golden_data(): + """Generate golden data for all cases.""" + for case in CASES: + setup_case_rng(case) + + format_type = case.get("format", "single") + if format_type == "single": + gen_golden_single(case) + elif format_type == "multi": + gen_golden_multilist(case) + elif format_type == "topk": + gen_golden_topk(case) + else: + print(f"[WARN] Unsupported format: {format_type} for case {case['name']}") + + +if __name__ == "__main__": + gen_golden_data() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/launch.cpp new file mode 100644 index 000000000..d4e0f6ddc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/launch.cpp @@ -0,0 +1,177 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case: f32_single_1x256_b64 (transplanted from pto-isa case_single1) +extern "C" __global__ AICORE void TMRGSORT_f32_single_1x256_b64(__gm__ float *src, __gm__ float *dst); + +void LaunchTMRGSORT_f32_single_1x256_b64(float *src, float *dst, void *stream) { + TMRGSORT_f32_single_1x256_b64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case: f32_single_1x320_b64 (transplanted from pto-isa case_single2) +extern "C" __global__ AICORE void TMRGSORT_f32_single_1x320_b64(__gm__ float *src, __gm__ float *dst); + +void LaunchTMRGSORT_f32_single_1x320_b64(float *src, float *dst, void *stream) { + TMRGSORT_f32_single_1x320_b64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case: f32_single_1x512_b64 (transplanted from pto-isa case_single3) +extern "C" __global__ AICORE void TMRGSORT_f32_single_1x512_b64(__gm__ float *src, __gm__ float *dst); + +void LaunchTMRGSORT_f32_single_1x512_b64(float *src, float *dst, void *stream) { + TMRGSORT_f32_single_1x512_b64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case: f32_single_1x640_b64 (transplanted from pto-isa case_single4) +extern "C" __global__ AICORE void TMRGSORT_f32_single_1x640_b64(__gm__ float *src, __gm__ float *dst); + +void LaunchTMRGSORT_f32_single_1x640_b64(float *src, float *dst, void *stream) { + TMRGSORT_f32_single_1x640_b64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// Case: f16_single_1x256_b64 (transplanted from pto-isa case_single5) +extern "C" __global__ AICORE void TMRGSORT_f16_single_1x256_b64(__gm__ half *src, __gm__ half *dst); + +void LaunchTMRGSORT_f16_single_1x256_b64(uint16_t *src, uint16_t *dst, void *stream) { + TMRGSORT_f16_single_1x256_b64<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ half *)dst); +} + +// Case: f16_single_1x320_b64 (transplanted from pto-isa case_single6) +extern "C" __global__ AICORE void TMRGSORT_f16_single_1x320_b64(__gm__ half *src, __gm__ half *dst); + +void LaunchTMRGSORT_f16_single_1x320_b64(uint16_t *src, uint16_t *dst, void *stream) { + TMRGSORT_f16_single_1x320_b64<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ half *)dst); +} + +// Case: f16_single_1x512_b64 (transplanted from pto-isa case_single7) +extern "C" __global__ AICORE void TMRGSORT_f16_single_1x512_b64(__gm__ half *src, __gm__ half *dst); + +void LaunchTMRGSORT_f16_single_1x512_b64(uint16_t *src, uint16_t *dst, void *stream) { + TMRGSORT_f16_single_1x512_b64<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ half *)dst); +} + +// Case: f16_single_1x1024_b256 (transplanted from pto-isa case_single8) +extern "C" __global__ AICORE void TMRGSORT_f16_single_1x1024_b256(__gm__ half *src, __gm__ half *dst); + +void LaunchTMRGSORT_f16_single_1x1024_b256(uint16_t *src, uint16_t *dst, void *stream) { + TMRGSORT_f16_single_1x1024_b256<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ half *)dst); +} + +// Multi-list cases +extern "C" __global__ AICORE void TMRGSORT_f32_2list_b64_basic(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTMRGSORT_f32_2list_b64_basic(float *src0, float *src1, float *dst, void *stream) { + TMRGSORT_f32_2list_b64_basic<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TMRGSORT_f16_2list_b64_basic(__gm__ half *src0, __gm__ half *src1, __gm__ half *dst); + +void LaunchTMRGSORT_f16_2list_b64_basic(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TMRGSORT_f16_2list_b64_basic<<<1, nullptr, stream>>>((__gm__ half *)src0, (__gm__ half *)src1, (__gm__ half *)dst); +} + +extern "C" __global__ AICORE void TMRGSORT_f32_2list_exhausted(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTMRGSORT_f32_2list_exhausted(float *src0, float *src1, float *dst, void *stream) { + TMRGSORT_f32_2list_exhausted<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TMRGSORT_f32_3list_b64_basic(__gm__ float *src0, __gm__ float *src1, __gm__ float *src2, __gm__ float *dst); + +void LaunchTMRGSORT_f32_3list_b64_basic(float *src0, float *src1, float *src2, float *dst, void *stream) { + TMRGSORT_f32_3list_b64_basic<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)src2, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TMRGSORT_f32_4list_b32_basic(__gm__ float *src0, __gm__ float *src1, __gm__ float *src2, __gm__ float *src3, __gm__ float *dst); + +void LaunchTMRGSORT_f32_4list_b32_basic(float *src0, float *src1, float *src2, float *src3, float *dst, void *stream) { + TMRGSORT_f32_4list_b32_basic<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)src2, (__gm__ float *)src3, (__gm__ float *)dst); +} + +// 4-list case: f16_4list_b64_basic +extern "C" __global__ AICORE void TMRGSORT_f16_4list_b64_basic(__gm__ half *src0, __gm__ half *src1, __gm__ half *src2, __gm__ half *src3, __gm__ half *dst); + +void LaunchTMRGSORT_f16_4list_b64_basic(uint16_t *src0, uint16_t *src1, uint16_t *src2, uint16_t *src3, uint16_t *dst, void *stream) { + TMRGSORT_f16_4list_b64_basic<<<1, nullptr, stream>>>((__gm__ half *)src0, (__gm__ half *)src1, (__gm__ half *)src2, (__gm__ half *)src3, (__gm__ half *)dst); +} + +// 4-list case: f16_4list_basic (pto-isa case_multi2) +extern "C" __global__ AICORE void TMRGSORT_f16_4list_basic(__gm__ half *src0, __gm__ half *src1, __gm__ half *src2, __gm__ half *src3, __gm__ half *dst); + +void LaunchTMRGSORT_f16_4list_basic(uint16_t *src0, uint16_t *src1, uint16_t *src2, uint16_t *src3, uint16_t *dst, void *stream) { + TMRGSORT_f16_4list_basic<<<1, nullptr, stream>>>((__gm__ half *)src0, (__gm__ half *)src1, (__gm__ half *)src2, (__gm__ half *)src3, (__gm__ half *)dst); +} + +// 3-list non-uniform: f32_3list_non_uniform +extern "C" __global__ AICORE void TMRGSORT_f32_3list_non_uniform(__gm__ float *src0, __gm__ float *src1, __gm__ float *src2, __gm__ float *dst); + +void LaunchTMRGSORT_f32_3list_non_uniform(float *src0, float *src1, float *src2, float *dst, void *stream) { + TMRGSORT_f32_3list_non_uniform<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)src2, (__gm__ float *)dst); +} + +// 3-list f16 exhausted: f16_3list_exhausted +extern "C" __global__ AICORE void TMRGSORT_f16_3list_exhausted(__gm__ half *src0, __gm__ half *src1, __gm__ half *src2, __gm__ half *dst); + +void LaunchTMRGSORT_f16_3list_exhausted(uint16_t *src0, uint16_t *src1, uint16_t *src2, uint16_t *dst, void *stream) { + TMRGSORT_f16_3list_exhausted<<<1, nullptr, stream>>>((__gm__ half *)src0, (__gm__ half *)src1, (__gm__ half *)src2, (__gm__ half *)dst); +} + +// 4-list non-uniform: f32_4list_non_uniform +extern "C" __global__ AICORE void TMRGSORT_f32_4list_non_uniform(__gm__ float *src0, __gm__ float *src1, __gm__ float *src2, __gm__ float *src3, __gm__ float *dst); + +void LaunchTMRGSORT_f32_4list_non_uniform(float *src0, float *src1, float *src2, float *src3, float *dst, void *stream) { + TMRGSORT_f32_4list_non_uniform<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)src2, (__gm__ float *)src3, (__gm__ float *)dst); +} + +// TopK cases: f32_topk_2048_1024 +extern "C" __global__ AICORE void TMRGSORT_f32_topk_2048_1024(__gm__ float *src, __gm__ float *dst); + +void LaunchTMRGSORT_f32_topk_2048_1024(float *src, float *dst, void *stream) { + TMRGSORT_f32_topk_2048_1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// TopK cases: f32_topk_2048_2048 +extern "C" __global__ AICORE void TMRGSORT_f32_topk_2048_2048(__gm__ float *src, __gm__ float *dst); + +void LaunchTMRGSORT_f32_topk_2048_2048(float *src, float *dst, void *stream) { + TMRGSORT_f32_topk_2048_2048<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// TopK cases: f32_topk_1280_512 +extern "C" __global__ AICORE void TMRGSORT_f32_topk_1280_512(__gm__ float *src, __gm__ float *dst); + +void LaunchTMRGSORT_f32_topk_1280_512(float *src, float *dst, void *stream) { + TMRGSORT_f32_topk_1280_512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// TopK cases: f16_topk_2048_1024 +extern "C" __global__ AICORE void TMRGSORT_f16_topk_2048_1024(__gm__ half *src, __gm__ half *dst); + +void LaunchTMRGSORT_f16_topk_2048_1024(uint16_t *src, uint16_t *dst, void *stream) { + TMRGSORT_f16_topk_2048_1024<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ half *)dst); +} + +// TopK cases: f16_topk_2048_2048 +extern "C" __global__ AICORE void TMRGSORT_f16_topk_2048_2048(__gm__ half *src, __gm__ half *dst); + +void LaunchTMRGSORT_f16_topk_2048_2048(uint16_t *src, uint16_t *dst, void *stream) { + TMRGSORT_f16_topk_2048_2048<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ half *)dst); +} + +// TopK cases: f16_topk_1280_512 +extern "C" __global__ AICORE void TMRGSORT_f16_topk_1280_512(__gm__ half *src, __gm__ half *dst); + +void LaunchTMRGSORT_f16_topk_1280_512(uint16_t *src, uint16_t *dst, void *stream) { + TMRGSORT_f16_topk_1280_512<<<1, nullptr, stream>>>((__gm__ half *)src, (__gm__ half *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/main.cpp new file mode 100644 index 000000000..8f8e2c758 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/main.cpp @@ -0,0 +1,432 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmrgsort ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMRGSORT_f32_single_1x256_b64(float *src, float *dst, void *stream); +void LaunchTMRGSORT_f32_single_1x320_b64(float *src, float *dst, void *stream); +void LaunchTMRGSORT_f32_single_1x512_b64(float *src, float *dst, void *stream); +void LaunchTMRGSORT_f32_single_1x640_b64(float *src, float *dst, void *stream); +void LaunchTMRGSORT_f16_single_1x256_b64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f16_single_1x320_b64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f16_single_1x512_b64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f16_single_1x1024_b256(uint16_t *src, uint16_t *dst, void *stream); + +// Multi-list launch wrappers +void LaunchTMRGSORT_f32_2list_b64_basic(float *src0, float *src1, float *dst, void *stream); +void LaunchTMRGSORT_f16_2list_b64_basic(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f32_2list_exhausted(float *src0, float *src1, float *dst, void *stream); +void LaunchTMRGSORT_f32_3list_b64_basic(float *src0, float *src1, float *src2, float *dst, void *stream); +void LaunchTMRGSORT_f32_3list_non_uniform(float *src0, float *src1, float *src2, float *dst, void *stream); +void LaunchTMRGSORT_f16_3list_exhausted(uint16_t *src0, uint16_t *src1, uint16_t *src2, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f32_4list_b32_basic(float *src0, float *src1, float *src2, float *src3, float *dst, void *stream); +void LaunchTMRGSORT_f32_4list_non_uniform(float *src0, float *src1, float *src2, float *src3, float *dst, void *stream); +void LaunchTMRGSORT_f16_4list_b64_basic(uint16_t *src0, uint16_t *src1, uint16_t *src2, uint16_t *src3, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f16_4list_basic(uint16_t *src0, uint16_t *src1, uint16_t *src2, uint16_t *src3, uint16_t *dst, void *stream); + +// TopK launch wrappers +void LaunchTMRGSORT_f32_topk_2048_1024(float *src, float *dst, void *stream); +void LaunchTMRGSORT_f32_topk_2048_2048(float *src, float *dst, void *stream); +void LaunchTMRGSORT_f32_topk_1280_512(float *src, float *dst, void *stream); +void LaunchTMRGSORT_f16_topk_2048_1024(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f16_topk_2048_2048(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMRGSORT_f16_topk_1280_512(uint16_t *src, uint16_t *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); +using LaunchFn2 = void (*)(void *, void *, void *, void *); +using LaunchFn3 = void (*)(void *, void *, void *, void *, void *); +using LaunchFn4 = void (*)(void *, void *, void *, void *, void *, void *); + +struct TestCase { + const char *name; + int listNum; // 1 for single-list, 2/3/4 for multi-list + LaunchFn launch; // for single-list + LaunchFn2 launch2; // for 2-list + LaunchFn3 launch3; // for 3-list + LaunchFn4 launch4; // for 4-list + size_t srcRows; + size_t srcCols; // for single-list: element count + size_t srcCols0; // for multi-list: src0 element count + size_t srcCols1; // for multi-list: src1 element count + size_t srcCols2; // for multi-list: src2 element count (for 3/4-list) + size_t srcCols3; // for multi-list: src3 element count (for 4-list) + size_t dstRows; + size_t dstCols; // element count + size_t elemSize; // bytes per element (4 for f32, 2 for f16) + size_t structSize; // 8 bytes per (value, index) pair + size_t elemsPerStruct; // structSize / elemSize (2 for f32, 4 for f16) +}; + +static const TestCase kCases[] = { + // Single-list cases (Format1) + {"f32_single_1x256_b64", 1, reinterpret_cast(LaunchTMRGSORT_f32_single_1x256_b64), nullptr, nullptr, nullptr, 1, 256, 0, 0, 0, 0, 1, 256, sizeof(float), 8, 2}, + {"f32_single_1x320_b64", 1, reinterpret_cast(LaunchTMRGSORT_f32_single_1x320_b64), nullptr, nullptr, nullptr, 1, 320, 0, 0, 0, 0, 1, 320, sizeof(float), 8, 2}, + {"f32_single_1x512_b64", 1, reinterpret_cast(LaunchTMRGSORT_f32_single_1x512_b64), nullptr, nullptr, nullptr, 1, 512, 0, 0, 0, 0, 1, 512, sizeof(float), 8, 2}, + {"f32_single_1x640_b64", 1, reinterpret_cast(LaunchTMRGSORT_f32_single_1x640_b64), nullptr, nullptr, nullptr, 1, 640, 0, 0, 0, 0, 1, 640, sizeof(float), 8, 2}, + {"f16_single_1x256_b64", 1, reinterpret_cast(LaunchTMRGSORT_f16_single_1x256_b64), nullptr, nullptr, nullptr, 1, 512, 0, 0, 0, 0, 1, 512, sizeof(uint16_t), 8, 4}, + {"f16_single_1x320_b64", 1, reinterpret_cast(LaunchTMRGSORT_f16_single_1x320_b64), nullptr, nullptr, nullptr, 1, 640, 0, 0, 0, 0, 1, 640, sizeof(uint16_t), 8, 4}, + {"f16_single_1x512_b64", 1, reinterpret_cast(LaunchTMRGSORT_f16_single_1x512_b64), nullptr, nullptr, nullptr, 1, 1024, 0, 0, 0, 0, 1, 1024, sizeof(uint16_t), 8, 4}, + {"f16_single_1x1024_b256", 1, reinterpret_cast(LaunchTMRGSORT_f16_single_1x1024_b256), nullptr, nullptr, nullptr, 1, 2048, 0, 0, 0, 0, 1, 2048, sizeof(uint16_t), 8, 4}, + + // Multi-list cases (Format2) + {"f32_2list_b64_basic", 2, nullptr, reinterpret_cast(LaunchTMRGSORT_f32_2list_b64_basic), nullptr, nullptr, 1, 0, 256, 256, 0, 0, 1, 256, sizeof(float), 8, 2}, + {"f16_2list_b64_basic", 2, nullptr, reinterpret_cast(LaunchTMRGSORT_f16_2list_b64_basic), nullptr, nullptr, 1, 0, 256, 256, 0, 0, 1, 256, sizeof(uint16_t), 8, 4}, + + // Exhausted cases (aligned with pto-isa case_exhausted1: kGCols_=64) + {"f32_2list_exhausted", 2, nullptr, reinterpret_cast(LaunchTMRGSORT_f32_2list_exhausted), nullptr, nullptr, 1, 0, 64, 64, 0, 0, 1, 128, sizeof(float), 8, 2}, + + // 3-list and 4-list cases + {"f32_3list_b64_basic", 3, nullptr, nullptr, reinterpret_cast(LaunchTMRGSORT_f32_3list_b64_basic), nullptr, 1, 0, 128, 128, 128, 0, 1, 256, sizeof(float), 8, 2}, + {"f32_3list_non_uniform", 3, nullptr, nullptr, reinterpret_cast(LaunchTMRGSORT_f32_3list_non_uniform), nullptr, 1, 0, 128, 128, 64, 0, 1, 128, sizeof(float), 8, 2}, + {"f16_3list_exhausted", 3, nullptr, nullptr, reinterpret_cast(LaunchTMRGSORT_f16_3list_exhausted), nullptr, 1, 0, 512, 512, 512, 0, 1, 1536, sizeof(uint16_t), 8, 4}, + {"f32_4list_b32_basic", 4, nullptr, nullptr, nullptr, reinterpret_cast(LaunchTMRGSORT_f32_4list_b32_basic), 1, 0, 128, 128, 128, 128, 1, 512, sizeof(float), 8, 2}, + {"f32_4list_non_uniform", 4, nullptr, nullptr, nullptr, reinterpret_cast(LaunchTMRGSORT_f32_4list_non_uniform), 1, 0, 128, 128, 128, 64, 1, 448, sizeof(float), 8, 2}, + {"f16_4list_b64_basic", 4, nullptr, nullptr, nullptr, reinterpret_cast(LaunchTMRGSORT_f16_4list_b64_basic), 1, 0, 256, 256, 256, 256, 1, 1024, sizeof(uint16_t), 8, 4}, + {"f16_4list_basic", 4, nullptr, nullptr, nullptr, reinterpret_cast(LaunchTMRGSORT_f16_4list_basic), 1, 0, 256, 256, 256, 256, 1, 1024, sizeof(uint16_t), 8, 4}, + + // TopK cases (Format5) + {"f32_topk_2048_1024", 1, reinterpret_cast(LaunchTMRGSORT_f32_topk_2048_1024), nullptr, nullptr, nullptr, 1, 2048, 0, 0, 0, 0, 1, 1024, sizeof(float), 8, 2}, + {"f32_topk_2048_2048", 1, reinterpret_cast(LaunchTMRGSORT_f32_topk_2048_2048), nullptr, nullptr, nullptr, 1, 2048, 0, 0, 0, 0, 1, 2048, sizeof(float), 8, 2}, + {"f32_topk_1280_512", 1, reinterpret_cast(LaunchTMRGSORT_f32_topk_1280_512), nullptr, nullptr, nullptr, 1, 1280, 0, 0, 0, 0, 1, 512, sizeof(float), 8, 2}, + {"f16_topk_2048_1024", 1, reinterpret_cast(LaunchTMRGSORT_f16_topk_2048_1024), nullptr, nullptr, nullptr, 1, 2048, 0, 0, 0, 0, 1, 1024, sizeof(uint16_t), 8, 4}, + {"f16_topk_2048_2048", 1, reinterpret_cast(LaunchTMRGSORT_f16_topk_2048_2048), nullptr, nullptr, nullptr, 1, 2048, 0, 0, 0, 0, 1, 2048, sizeof(uint16_t), 8, 4}, + {"f16_topk_1280_512", 1, reinterpret_cast(LaunchTMRGSORT_f16_topk_1280_512), nullptr, nullptr, nullptr, 1, 1280, 0, 0, 0, 0, 1, 512, sizeof(uint16_t), 8, 4}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, aclrtStream stream) { + int rc = 0; + std::string caseDir = std::string("./") + tc.name; + + // Single-list case (Format1) + if (tc.listNum == 1) { + // srcCols/dstCols are in ELEMENTS, need to convert to STRUCTURE count + // elemsPerStruct = structSize / elemSize (2 for f32, 4 for f16) + size_t srcStructs = tc.srcCols / tc.elemsPerStruct; + size_t dstStructs = tc.dstCols / tc.elemsPerStruct; + + // File sizes in bytes + size_t srcFileSize = tc.srcRows * srcStructs * tc.structSize; + size_t dstFileSize = tc.dstRows * dstStructs * tc.structSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols); + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), srcFileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), srcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + } + + // Multi-list case (Format2) + else if (tc.listNum == 2) { + // For 2-list: src0, src1, dst + // srcCols0, srcCols1 are in ELEMENTS, dstCols in ELEMENTS + // elemsPerStruct = structSize / elemSize (2 for f32, 4 for f16) + size_t src0Structs = tc.srcCols0 / tc.elemsPerStruct; + size_t src1Structs = tc.srcCols1 / tc.elemsPerStruct; + size_t dstStructs = tc.dstCols / tc.elemsPerStruct; + + size_t src0FileSize = tc.srcRows * src0Structs * tc.structSize; + size_t src1FileSize = tc.srcRows * src1Structs * tc.structSize; + size_t dstFileSize = tc.dstRows * dstStructs * tc.structSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols0, tc.srcRows, tc.srcCols1, tc.dstRows, tc.dstCols); + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + // Read input0.bin and input1.bin + if (!ReadFile((caseDir + "/input0.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch2(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) aclrtFree(src0Device); + if (src1Device != nullptr) aclrtFree(src1Device); + if (dstDevice != nullptr) aclrtFree(dstDevice); + if (src0Host != nullptr) aclrtFreeHost(src0Host); + if (src1Host != nullptr) aclrtFreeHost(src1Host); + if (dstHost != nullptr) aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + } + + // 3-list case (Format3) + else if (tc.listNum == 3) { + size_t src0Structs = tc.srcCols0 / tc.elemsPerStruct; + size_t src1Structs = tc.srcCols1 / tc.elemsPerStruct; + size_t src2Structs = tc.srcCols2 / tc.elemsPerStruct; + size_t dstStructs = tc.dstCols / tc.elemsPerStruct; + + size_t src0FileSize = tc.srcRows * src0Structs * tc.structSize; + size_t src1FileSize = tc.srcRows * src1Structs * tc.structSize; + size_t src2FileSize = tc.srcRows * src2Structs * tc.structSize; + size_t dstFileSize = tc.dstRows * dstStructs * tc.structSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, src2=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols0, tc.srcRows, tc.srcCols1, + tc.srcRows, tc.srcCols2, tc.dstRows, tc.dstCols); + + void *src0Host = nullptr, *src1Host = nullptr, *src2Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *src2Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&src2Host), src2FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src2Device, src2FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src2FileSize, src2Host, src2FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src2Device, src2FileSize, src2Host, src2FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch3(src0Device, src1Device, src2Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) aclrtFree(src0Device); + if (src1Device != nullptr) aclrtFree(src1Device); + if (src2Device != nullptr) aclrtFree(src2Device); + if (dstDevice != nullptr) aclrtFree(dstDevice); + if (src0Host != nullptr) aclrtFreeHost(src0Host); + if (src1Host != nullptr) aclrtFreeHost(src1Host); + if (src2Host != nullptr) aclrtFreeHost(src2Host); + if (dstHost != nullptr) aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + } + + // 4-list case (Format4) + else if (tc.listNum == 4) { + size_t src0Structs = tc.srcCols0 / tc.elemsPerStruct; + size_t src1Structs = tc.srcCols1 / tc.elemsPerStruct; + size_t src2Structs = tc.srcCols2 / tc.elemsPerStruct; + size_t src3Structs = tc.srcCols3 / tc.elemsPerStruct; + size_t dstStructs = tc.dstCols / tc.elemsPerStruct; + + size_t src0FileSize = tc.srcRows * src0Structs * tc.structSize; + size_t src1FileSize = tc.srcRows * src1Structs * tc.structSize; + size_t src2FileSize = tc.srcRows * src2Structs * tc.structSize; + size_t src3FileSize = tc.srcRows * src3Structs * tc.structSize; + size_t dstFileSize = tc.dstRows * dstStructs * tc.structSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, src2=%zux%zu, src3=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols0, tc.srcRows, tc.srcCols1, + tc.srcRows, tc.srcCols2, tc.srcRows, tc.srcCols3, + tc.dstRows, tc.dstCols); + + void *src0Host = nullptr, *src1Host = nullptr, *src2Host = nullptr, *src3Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *src2Device = nullptr, *src3Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&src2Host), src2FileSize); + aclrtMallocHost((void **)(&src3Host), src3FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src2Device, src2FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src3Device, src3FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src2FileSize, src2Host, src2FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), src3FileSize, src3Host, src3FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src2Device, src2FileSize, src2Host, src2FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src3Device, src3FileSize, src3Host, src3FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch4(src0Device, src1Device, src2Device, src3Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) aclrtFree(src0Device); + if (src1Device != nullptr) aclrtFree(src1Device); + if (src2Device != nullptr) aclrtFree(src2Device); + if (src3Device != nullptr) aclrtFree(src3Device); + if (dstDevice != nullptr) aclrtFree(dstDevice); + if (src0Host != nullptr) aclrtFreeHost(src0Host); + if (src1Host != nullptr) aclrtFreeHost(src1Host); + if (src2Host != nullptr) aclrtFreeHost(src2Host); + if (src3Host != nullptr) aclrtFreeHost(src3Host); + if (dstHost != nullptr) aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + } + + else { + std::fprintf(stderr, "[ERROR] Unsupported listNum=%d for case %s\n", tc.listNum, tc.name); + rc = 1; + } + + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto new file mode 100644 index 000000000..76eeaaf18 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmrgsort/tmrgsort.pto @@ -0,0 +1,1857 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmrgsort Format1: single list internal block sorting. +// Input is divided into 4 blocks, each block sorted, then merged. +// Output: interleaved (sorted_value, original_index) pairs. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --vpto-emit-hivm-llvm +// to produce LLVM IR. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case: f32_single_1x256_b64 - transplanted from pto-isa case_single1 + // TMrgsortSingle + // cols=256 float elements = 128 structures + // block_len=64 float elements = 32 structures/block + func.func @TMRGSORT_f32_single_1x256_b64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // block_len (float elements) + %c256 = arith.constant 256 : index // kGCols = total float elements + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c64 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + return + } + + // Case: f32_single_1x320_b64 - transplanted from pto-isa case_single2 + // TMrgsortSingle + // kGCols=320 (global memory), kTCols=256 (effective tile region) + // cols=256 float elements = 128 structures + // block_len=64 float elements = 32 structures/block + func.func @TMRGSORT_f32_single_1x320_b64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // block_len (float elements) + %c320 = arith.constant 320 : index // kGCols (global memory stride) + %c256 = arith.constant 256 : index // kTCols (effective tile cols) + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c320], + strides = [%c320, %c320, %c320, %c320, %c1] + : !pto.tensor_view<1x1x1x1x320xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c320], + strides = [%c320, %c320, %c320, %c320, %c1] + : !pto.tensor_view<1x1x1x1x320xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x320xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x320xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c64 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + return + } + + // Case: f32_single_1x512_b64 - transplanted from pto-isa case_single3 + // TMrgsortSingle + // cols=512 float elements = 256 structures + // block_len=64 float elements = 32 structures/block + func.func @TMRGSORT_f32_single_1x512_b64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // block_len (float elements) + %c512 = arith.constant 512 : index // kGCols = total float elements + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x512xf32> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c64 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x512xf32>) + return + } + + // Case: f32_single_1x640_b64 - transplanted from pto-isa case_single4 + // TMrgsortSingle + // kGCols=640 (global memory), kTCols=512 (effective tile region) + // cols=512 float elements = 256 structures + // block_len=64 float elements = 32 structures/block + func.func @TMRGSORT_f32_single_1x640_b64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // block_len (float elements) + %c640 = arith.constant 640 : index // kGCols (global memory stride) + %c512 = arith.constant 512 : index // kTCols (effective tile cols) + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c640], + strides = [%c640, %c640, %c640, %c640, %c1] + : !pto.tensor_view<1x1x1x1x640xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c640], + strides = [%c640, %c640, %c640, %c640, %c1] + : !pto.tensor_view<1x1x1x1x640xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x640xf32> -> !pto.partition_tensor_view<1x1x1x1x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x640xf32> -> !pto.partition_tensor_view<1x1x1x1x512xf32> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf32>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c64 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x512xf32>) + return + } + + // Case: f16_single_1x256_b64 - transplanted from pto-isa case_single5 + // TMrgsortSingle + // TYPE_COEF = sizeof(float)/sizeof(uint16_t) = 2 + // Kernel params (TYPE_COEF applied): + // kGCols*2=512, kTCols*2=512, blockLen*2=128 (in half elements) + // cols=512 half elements = 128 structures + // block_len=128 half elements = 32 structures/block + func.func @TMRGSORT_f16_single_1x256_b64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : i32 // block_len * TYPE_COEF (half elements) + %c512 = arith.constant 512 : index // kGCols * TYPE_COEF (half elements) + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c128 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + return + } + + // Case: f16_single_1x320_b64 - transplanted from pto-isa case_single6 + // TMrgsortSingle + // TYPE_COEF=2: kGCols*2=640, kTCols*2=512, blockLen*2=128 (kernel internal) + // cols=512 half elements = 128 structures + // block_len=128 half elements = 32 structures/block + func.func @TMRGSORT_f16_single_1x320_b64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : i32 // block_len * TYPE_COEF (half elements) + %c640 = arith.constant 640 : index // kGCols * TYPE_COEF (global memory stride) + %c512 = arith.constant 512 : index // kTCols * TYPE_COEF (effective tile cols) + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c640], + strides = [%c640, %c640, %c640, %c640, %c1] + : !pto.tensor_view<1x1x1x1x640xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c640], + strides = [%c640, %c640, %c640, %c640, %c1] + : !pto.tensor_view<1x1x1x1x640xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x640xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x640xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c128 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + return + } + + // Case: f16_single_1x512_b64 - transplanted from pto-isa case_single7 + // TMrgsortSingle + // TYPE_COEF=2: kGCols*2=1024, kTCols*2=1024, blockLen*2=128 + // cols=1024 half elements = 256 structures (at tile limit), repeat_times=2 + func.func @TMRGSORT_f16_single_1x512_b64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : i32 // block_len * TYPE_COEF + %c1024 = arith.constant 1024 : index // kGCols * TYPE_COEF + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c1024], + strides = [%c1024, %c1024, %c1024, %c1024, %c1] + : !pto.tensor_view<1x1x1x1x1024xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c1024], + strides = [%c1024, %c1024, %c1024, %c1024, %c1] + : !pto.tensor_view<1x1x1x1x1024xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x1024xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x1024xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c128 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + return + } + + // Case: f16_single_1x1024_b256 - transplanted from pto-isa case_single8 + // TMrgsortSingle + // TYPE_COEF=2: kGCols*2=2048, kTCols*2=2048, blockLen*2=512 + // cols=2048 half elements = 512 structures, repeat_times=1 (larger block) + func.func @TMRGSORT_f16_single_1x1024_b256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c512 = arith.constant 512 : i32 // block_len * TYPE_COEF (256*2) + %c2048 = arith.constant 2048 : index // kGCols * TYPE_COEF (1024*2) + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x2048xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x2048xf16> + + %src_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf16>) + outs(%src_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src_tile, %c512 : !pto.tile_buf, i32) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x2048xf16>) + return + } + + // Format2: 2-list merge sort (f32_2list_b64_basic) + func.func @TMRGSORT_f32_2list_b64_basic(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i32 = arith.constant 0 : i32 + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + return + } + + // Format2: 2-list merge sort (f16_2list_b64_basic) + func.func @TMRGSORT_f16_2list_b64_basic(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i32 = arith.constant 0 : i32 + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + return + } + + // Format2: 2-list merge sort with exhausted=true (f32_2list_exhausted) + // Aligned with pto-isa case_exhausted1: TMrgsortMulti + // kGCols_=64 (elements) → 32 structures per input + // TOPK=128 (elements) → 64 structures output + func.func @TMRGSORT_f32_2list_exhausted(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index // kGCols (elements per input) + %c128 = arith.constant 128 : index // TOPK (elements for output) + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %tmp_tile {exhausted = true} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Format3: 3-list merge sort (f32_3list_b64_basic) + // src0: 128 f32 elements = 64 structures + // src1: 128 f32 elements = 64 structures + // src2: 128 f32 elements = 64 structures + // dst: 256 f32 elements = 128 structures + func.func @TMRGSORT_f32_3list_b64_basic(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %src2_ptr: !pto.ptr, %dst_ptr: !pto.ptr, + %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src2_view = pto.make_tensor_view %src2_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src2_part = pto.partition_view %src2_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %src2_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + pto.tload ins(%src2_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src2_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %src2_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + return + } + + // Format4: 4-list merge sort (f32_4list_b32_basic - pto-isa case_multi1) + // src0-3: 128 f32 elements = 64 structures each + // dst: 512 f32 elements = 256 structures (topk=256) + func.func @TMRGSORT_f32_4list_b32_basic(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %src2_ptr: !pto.ptr, %src3_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src2_view = pto.make_tensor_view %src2_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src3_view = pto.make_tensor_view %src3_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src2_part = pto.partition_view %src2_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src3_part = pto.partition_view %src3_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x512xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %src2_tile = pto.alloc_tile + : !pto.tile_buf + %src3_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + pto.tload ins(%src2_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src2_tile : !pto.tile_buf) + pto.tload ins(%src3_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src3_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %src2_tile, %src3_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x512xf32>) + return + } + + // Format4: 4-list merge sort (f16_4list_b64_basic) + // src0: 256 f16 elements = 64 structures + // src1: 256 f16 elements = 64 structures + // src2: 256 f16 elements = 64 structures + // src3: 256 f16 elements = 64 structures + // dst: 1024 f16 elements = 256 structures + func.func @TMRGSORT_f16_4list_b64_basic(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %src2_ptr: !pto.ptr, %src3_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %src2_view = pto.make_tensor_view %src2_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %src3_view = pto.make_tensor_view %src3_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c1024], + strides = [%c1024, %c1024, %c1024, %c1024, %c1] + : !pto.tensor_view<1x1x1x1x1024xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %src2_part = pto.partition_view %src2_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %src3_part = pto.partition_view %src3_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x1024xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %src2_tile = pto.alloc_tile + : !pto.tile_buf + %src3_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1_tile : !pto.tile_buf) + pto.tload ins(%src2_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src2_tile : !pto.tile_buf) + pto.tload ins(%src3_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src3_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %src2_tile, %src3_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + +pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + return + } + + // Format4: 4-list merge sort (f16_4list_basic - pto-isa case_multi2) + // src0-3: 256 f16 elements = 64 structures each + // dst: 1024 f16 elements = 256 structures (topk=256) + func.func @TMRGSORT_f16_4list_basic(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %src2_ptr: !pto.ptr, %src3_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %src2_view = pto.make_tensor_view %src2_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %src3_view = pto.make_tensor_view %src3_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c1024], + strides = [%c1024, %c1024, %c1024, %c1024, %c1] + : !pto.tensor_view<1x1x1x1x1024xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %src2_part = pto.partition_view %src2_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %src3_part = pto.partition_view %src3_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x1024xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %src2_tile = pto.alloc_tile + : !pto.tile_buf + %src3_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src1_tile : !pto.tile_buf) + pto.tload ins(%src2_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src2_tile : !pto.tile_buf) + pto.tload ins(%src3_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%src3_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %src2_tile, %src3_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + return + } + + // Format3: 3-list non-uniform cols (f32_3list_non_uniform) + // src0: 128 f32 elements = 64 structures + // src1: 128 f32 elements = 64 structures + // src2: 64 f32 elements = 32 structures + // dst: 128 f32 elements = 64 structures (topk) + func.func @TMRGSORT_f32_3list_non_uniform(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %src2_ptr: !pto.ptr, %dst_ptr: !pto.ptr, + %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src2_view = pto.make_tensor_view %src2_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src2_part = pto.partition_view %src2_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %src2_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + pto.tload ins(%src2_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src2_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %src2_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + + // Format3: 3-list f16 exhausted (pto-isa case_exhausted2) + // src0-2: 512 f16 elements = 128 structures each (total=384) + // dst: 1536 f16 elements = 384 structures (topk=384) + func.func @TMRGSORT_f16_3list_exhausted(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %src2_ptr: !pto.ptr, %dst_ptr: !pto.ptr, + %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c512 = arith.constant 512 : index + %c1536 = arith.constant 1536 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + %src2_view = pto.make_tensor_view %src2_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c1536], + strides = [%c1536, %c1536, %c1536, %c1536, %c1] + : !pto.tensor_view<1x1x1x1x1536xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + %src2_part = pto.partition_view %src2_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1536] + : !pto.tensor_view<1x1x1x1x1536xf16> -> !pto.partition_tensor_view<1x1x1x1x1536xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %src2_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src1_tile : !pto.tile_buf) + pto.tload ins(%src2_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + outs(%src2_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %src2_tile, %tmp_tile {exhausted = true} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1536xf16>) + return + } + + + // Format4: 4-list non-uniform cols (f32_4list_non_uniform) + // src0-2: 128 f32 elements = 64 structures each + // src3: 64 f32 elements = 32 structures + // dst: 448 f32 elements = 224 structures + func.func @TMRGSORT_f32_4list_non_uniform(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, + %src2_ptr: !pto.ptr, %src3_ptr: !pto.ptr, + %dst_ptr: !pto.ptr, %ex_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c448 = arith.constant 448 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src2_view = pto.make_tensor_view %src2_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %src3_view = pto.make_tensor_view %src3_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c448], + strides = [%c448, %c448, %c448, %c448, %c1] + : !pto.tensor_view<1x1x1x1x448xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src2_part = pto.partition_view %src2_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %src3_part = pto.partition_view %src3_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c448] + : !pto.tensor_view<1x1x1x1x448xf32> -> !pto.partition_tensor_view<1x1x1x1x448xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %src2_tile = pto.alloc_tile + : !pto.tile_buf + %src3_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src1_tile : !pto.tile_buf) + pto.tload ins(%src2_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src2_tile : !pto.tile_buf) + pto.tload ins(%src3_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src3_tile : !pto.tile_buf) + + pto.tmrgsort ins(%src0_tile, %src1_tile, %src2_tile, %src3_tile, %tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x448xf32>) + return + } + +// Format5: TopK (iterative merge sort for top-k selection) +// Case: f32_topk_2048_1024 +// Iteration logic: block_len starts at 32 structures (64 elements) +// Iteration 1: 32*4=128 ≤ 1024 ✓ +// Iteration 2: 128*4=512 ≤ 1024 ✓ +// Iteration 3: 512*4=2048 > 1024 ✗ STOP +// After 2 iterations: 2 blocks of 512 structures each remain +// Use Format2 merge: store to src memory (reuse as intermediate), load into 2 tiles, merge +func.func @TMRGSORT_f32_topk_2048_1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // blockLen iteration 1 + %c256 = arith.constant 256 : i32 // blockLen iteration 2 + %c1024 = arith.constant 1024 : index // half tile cols (offset for block1) + %c1024_idx = arith.constant 1024 : index // dst cols (512 structures * 2 elems) + %c2048 = arith.constant 2048 : index // src cols (1024 structures * 2 elems) + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c1024_idx], + strides = [%c1024_idx, %c1024_idx, %c1024_idx, %c1024_idx, %c1] + : !pto.tensor_view<1x1x1x1x1024xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf32> -> !pto.partition_tensor_view<1x1x1x1x2048xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024_idx] + : !pto.tensor_view<1x1x1x1x1024xf32> -> !pto.partition_tensor_view<1x1x1x1x1024xf32> + + // Allocate tiles for iterative merge + %src_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + // Tiles for Format2 merge (2 blocks of 512 structures each = 1024 elements each) + %block0_tile = pto.alloc_tile + : !pto.tile_buf + %block1_tile = pto.alloc_tile + : !pto.tile_buf + %merge_tmp_tile = pto.alloc_tile + : !pto.tile_buf + %merge_dst_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + // Load unsorted data to src_tile + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf32>) + outs(%src_tile : !pto.tile_buf) + + // Iteration 1: blockLen=64, merge 4 blocks (64*4=256 elements per group) + pto.tmrgsort ins(%src_tile, %c64 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Copy result back to src_tile for next iteration + pto.tmov ins(%tmp_tile : !pto.tile_buf) + outs(%src_tile : !pto.tile_buf) + + // Iteration 2: blockLen=256, merge 4 blocks (256*4=1024 elements per group) + // After this: 2 sorted blocks of 512 structures each (1024 elems each) + pto.tmrgsort ins(%src_tile, %c256 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Store tmp_tile back to src memory (reuse src as intermediate buffer) + // src memory now contains: block0 (0-1023), block1 (1024-2047) + pto.tstore ins(%tmp_tile : !pto.tile_buf) + outs(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf32>) + + // Load first block (0-1023) into block0_tile + %block0_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf32> -> !pto.partition_tensor_view<1x1x1x1x1024xf32> + pto.tload ins(%block0_part : !pto.partition_tensor_view<1x1x1x1x1024xf32>) + outs(%block0_tile : !pto.tile_buf) + + // Load second block (1024-2047) into block1_tile + %block1_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c1024], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf32> -> !pto.partition_tensor_view<1x1x1x1x1024xf32> + pto.tload ins(%block1_part : !pto.partition_tensor_view<1x1x1x1x1024xf32>) + outs(%block1_tile : !pto.tile_buf) + + // Format2 merge: merge block0 and block1 (512 structures each = 1024 elems each) + // Output: 1024 structures total = 2048 elements, dst takes top-k=512 structures + pto.tmrgsort ins(%block0_tile, %block1_tile, %merge_tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%merge_dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + // Take top-k from merged result (merge_dst_tile already has top-k=512 structures) + pto.tmov ins(%merge_dst_tile : !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + // Store top-k result + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1024xf32>) + return + } + +// Case: f32_topk_2048_2048 (full global sort) +// Iterative merge: 64→256 (2 iterations), then Format2 merge for remaining 2 blocks +// cols=1024 structures, after 2 iterations: 2 blocks of 512 structures each +func.func @TMRGSORT_f32_topk_2048_2048(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // blockLen iteration 1 + %c256 = arith.constant 256 : i32 // blockLen iteration 2 + %c1024 = arith.constant 1024 : index // offset for block1 (512 structures * 2 elems) + %c2048 = arith.constant 2048 : index // src/dst cols + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf32> -> !pto.partition_tensor_view<1x1x1x1x2048xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf32> -> !pto.partition_tensor_view<1x1x1x1x2048xf32> + + // Allocate tiles for iterative merge + %src_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + // Tiles for Format2 merge (2 blocks of 512 structures each = 1024 elements each) + %block0_tile = pto.alloc_tile + : !pto.tile_buf + %block1_tile = pto.alloc_tile + : !pto.tile_buf + %merge_tmp_tile = pto.alloc_tile + : !pto.tile_buf + %merge_dst_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + // Load unsorted data to src_tile + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf32>) + outs(%src_tile : !pto.tile_buf) + + // Iteration 1: blockLen=64, merge 4 blocks (64*4=256 elements per group) + pto.tmrgsort ins(%src_tile, %c64 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Copy result back to src_tile for next iteration + pto.tmov ins(%tmp_tile : !pto.tile_buf) + outs(%src_tile : !pto.tile_buf) + + // Iteration 2: blockLen=256, merge 4 blocks (256*4=1024 elements per group) + // After this: 2 sorted blocks of 512 structures each (1024 elems each) + pto.tmrgsort ins(%src_tile, %c256 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Store tmp_tile back to src memory (reuse src as intermediate buffer) + // src memory now contains: block0 (0-1023), block1 (1024-2047) + pto.tstore ins(%tmp_tile : !pto.tile_buf) + outs(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf32>) + + // Load first block (0-1023) into block0_tile + %block0_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf32> -> !pto.partition_tensor_view<1x1x1x1x1024xf32> + pto.tload ins(%block0_part : !pto.partition_tensor_view<1x1x1x1x1024xf32>) + outs(%block0_tile : !pto.tile_buf) + + // Load second block (1024-2047) into block1_tile + %block1_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c1024], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf32> -> !pto.partition_tensor_view<1x1x1x1x1024xf32> + pto.tload ins(%block1_part : !pto.partition_tensor_view<1x1x1x1x1024xf32>) + outs(%block1_tile : !pto.tile_buf) + + // Format2 merge: merge block0 and block1 (512 structures each = 1024 elems each) + // Output: 1024 structures total = 2048 elements (full sort) + pto.tmrgsort ins(%block0_tile, %block1_tile, %merge_tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%merge_dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + // Copy merged result to dst_tile (full sort, topk=2048) + pto.tmov ins(%merge_dst_tile : !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + // Store globally sorted result + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x2048xf32>) + return + } + +// Case: f32_topk_1280_512 (iterative merge for global sort, then take topk) +// cols=640 structures, after 2 iterations: 1 block of 512 structures + 1 block of 128 structures +// Need Format2 merge to combine these two sorted blocks +func.func @TMRGSORT_f32_topk_1280_512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // blockLen iteration 1 + %c256 = arith.constant 256 : i32 // blockLen iteration 2 + %c256_idx = arith.constant 256 : index // block1 size in elements (128 structures * 2) + %c512 = arith.constant 512 : index // dst cols (topk), also offset for block1 + %c1024 = arith.constant 1024 : index // block0 size (512 structures * 2 elems) + %c1280 = arith.constant 1280 : index // src cols + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c1280], + strides = [%c1280, %c1280, %c1280, %c1280, %c1] + : !pto.tensor_view<1x1x1x1x1280xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1280] + : !pto.tensor_view<1x1x1x1x1280xf32> -> !pto.partition_tensor_view<1x1x1x1x1280xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf32> -> !pto.partition_tensor_view<1x1x1x1x512xf32> + + // Allocate tiles for iterative merge + %src_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + // Tiles for Format2 merge + // block0: 512 structures = 1024 elements + // block1: 128 structures = 256 elements (use cols=1024 tile with v_col=256 for valid portion) + %block0_tile = pto.alloc_tile + : !pto.tile_buf + %block1_tile = pto.alloc_tile + : !pto.tile_buf + %merge_tmp_tile = pto.alloc_tile + : !pto.tile_buf + %merge_dst_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + // Load unsorted data to src_tile + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x1280xf32>) + outs(%src_tile : !pto.tile_buf) + + // Iteration 1: blockLen=64, merge 4 blocks (64*4=256 elements per group) + pto.tmrgsort ins(%src_tile, %c64 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Copy result back to src_tile for next iteration + pto.tmov ins(%tmp_tile : !pto.tile_buf) + outs(%src_tile : !pto.tile_buf) + + // Iteration 2: blockLen=256, merge 4 blocks (256*4=1024 elements per group) + // After this: 1 sorted block of 512 structures (1024 elems) + 1 tail of 128 structures (256 elems) + pto.tmrgsort ins(%src_tile, %c256 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Store tmp_tile back to src memory (reuse as intermediate buffer) + pto.tstore ins(%tmp_tile : !pto.tile_buf) + outs(%src_part : !pto.partition_tensor_view<1x1x1x1x1280xf32>) + + // Load block0 (0-1023) into block0_tile (512 structures) + %block0_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x1280xf32> -> !pto.partition_tensor_view<1x1x1x1x1024xf32> + pto.tload ins(%block0_part : !pto.partition_tensor_view<1x1x1x1x1024xf32>) + outs(%block0_tile : !pto.tile_buf) + + // Load block1 (1024-1279) into block1_tile (128 structures = 256 elems) + %block1_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c1024], + sizes = [%c1, %c1, %c1, %c1, %c256_idx] + : !pto.tensor_view<1x1x1x1x1280xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + pto.tload ins(%block1_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + outs(%block1_tile : !pto.tile_buf) + + // Format2 merge: merge block0 (512 structures) and block1 (128 structures) + // Output: 640 sorted structures, dst takes topk=256 structures (512 elems) + pto.tmrgsort ins(%block0_tile, %block1_tile, %merge_tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%merge_dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + // Take top-k from merged result (merge_dst_tile already has topk=256 structures) + pto.tmov ins(%merge_dst_tile : !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + // Store top-k result + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x512xf32>) + return + } + +// Case: f16_topk_2048_1024 (iterative merge) +// cols=512 structures (2048 f16 elements), block_len=64 f16 elements = 16 structures +// Iteration 1: 16*4=64 ≤ 512 → 8 blocks of 64 structures +// Iteration 2: 64*4=256 ≤ 512 → 2 blocks of 256 structures +// After 2 iterations: need Format2 merge for 2 blocks of 256 structures each +func.func @TMRGSORT_f16_topk_2048_1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // blockLen iteration 1 (f16 elements) + %c256 = arith.constant 256 : i32 // blockLen iteration 2 (f16 elements) + %c1024 = arith.constant 1024 : index // block0 size in elements (256 structures * 4) + %c1024_dst = arith.constant 1024 : index // dst cols (topk) + %c2048 = arith.constant 2048 : index // src cols + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c1024_dst], + strides = [%c1024_dst, %c1024_dst, %c1024_dst, %c1024_dst, %c1] + : !pto.tensor_view<1x1x1x1x1024xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x2048xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024_dst] + : !pto.tensor_view<1x1x1x1x1024xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + + // Allocate tiles for iterative merge + %src_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + // Tiles for Format2 merge (2 blocks of 256 structures each = 1024 elements each) + %block0_tile = pto.alloc_tile + : !pto.tile_buf + %block1_tile = pto.alloc_tile + : !pto.tile_buf + %merge_tmp_tile = pto.alloc_tile + : !pto.tile_buf + %merge_dst_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + // Load unsorted data to src_tile + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf16>) + outs(%src_tile : !pto.tile_buf) + + // Iteration 1: blockLen=64 f16 elements (=16 structures) + // After: 8 blocks of 64 structures each + pto.tmrgsort ins(%src_tile, %c64 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Copy result back to src_tile for next iteration + pto.tmov ins(%tmp_tile : !pto.tile_buf) + outs(%src_tile : !pto.tile_buf) + + // Iteration 2: blockLen=256 f16 elements (=64 structures) + // After: 2 blocks of 256 structures each + pto.tmrgsort ins(%src_tile, %c256 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Store tmp_tile back to src memory (reuse as intermediate buffer) + pto.tstore ins(%tmp_tile : !pto.tile_buf) + outs(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf16>) + + // Load first block (0-1023) into block0_tile (256 structures = 1024 elements) + %block0_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + pto.tload ins(%block0_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + outs(%block0_tile : !pto.tile_buf) + + // Load second block (1024-2047) into block1_tile (256 structures = 1024 elements) + %block1_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c1024], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + pto.tload ins(%block1_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + outs(%block1_tile : !pto.tile_buf) + + // Format2 merge: merge block0 and block1 (256 structures each) + // Output: 512 sorted structures, dst takes topk=256 structures (1024 elems) + pto.tmrgsort ins(%block0_tile, %block1_tile, %merge_tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%merge_dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + // Take top-k=256 structures (1024 elements) from merged result + pto.tmov ins(%merge_dst_tile : !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + // Store top-k result + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + return + } + +// Case: f16_topk_2048_2048 (full global sort) +// cols=512 structures (2048 f16 elements), block_len=64 f16 elements = 16 structures +// Iteration 1: 16*4=64 ≤ 512 → 8 blocks of 64 structures +// Iteration 2: 64*4=256 ≤ 512 → 2 blocks of 256 structures +// After 2 iterations: need Format2 merge for 2 blocks, output all 512 structures +func.func @TMRGSORT_f16_topk_2048_2048(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // blockLen iteration 1 (f16 elements) + %c256 = arith.constant 256 : i32 // blockLen iteration 2 (f16 elements) + %c1024 = arith.constant 1024 : index // block size in elements (256 structures * 4) + %c2048 = arith.constant 2048 : index // src/dst cols + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c2048], + strides = [%c2048, %c2048, %c2048, %c2048, %c1] + : !pto.tensor_view<1x1x1x1x2048xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x2048xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c2048] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x2048xf16> + + // Allocate tiles for iterative merge + %src_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + // Tiles for Format2 merge (2 blocks of 256 structures each = 1024 elements each) + %block0_tile = pto.alloc_tile + : !pto.tile_buf + %block1_tile = pto.alloc_tile + : !pto.tile_buf + %merge_tmp_tile = pto.alloc_tile + : !pto.tile_buf + %merge_dst_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + // Load unsorted data to src_tile + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf16>) + outs(%src_tile : !pto.tile_buf) + + // Iteration 1: blockLen=64 f16 elements (=16 structures) + pto.tmrgsort ins(%src_tile, %c64 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Copy result back to src_tile for next iteration + pto.tmov ins(%tmp_tile : !pto.tile_buf) + outs(%src_tile : !pto.tile_buf) + + // Iteration 2: blockLen=256 f16 elements (=64 structures) + // After: 2 blocks of 256 structures each + pto.tmrgsort ins(%src_tile, %c256 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Store tmp_tile back to src memory (reuse as intermediate buffer) + pto.tstore ins(%tmp_tile : !pto.tile_buf) + outs(%src_part : !pto.partition_tensor_view<1x1x1x1x2048xf16>) + + // Load first block (0-1023) into block0_tile (256 structures = 1024 elements) + %block0_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + pto.tload ins(%block0_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + outs(%block0_tile : !pto.tile_buf) + + // Load second block (1024-2047) into block1_tile (256 structures = 1024 elements) + %block1_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c1024], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x2048xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + pto.tload ins(%block1_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + outs(%block1_tile : !pto.tile_buf) + + // Format2 merge: merge block0 and block1 (256 structures each) + // Output: 512 sorted structures (all) + pto.tmrgsort ins(%block0_tile, %block1_tile, %merge_tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%merge_dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + // Copy all merged data to dst_tile (topk=512 structures = all) + pto.tmov ins(%merge_dst_tile : !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + // Store globally sorted result + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x2048xf16>) + return + } + +// Case: f16_topk_1280_512 +// cols=320 structures (1280 f16 elements), block_len=64 f16 elements = 16 structures +// Iteration 1: 16*4=64 ≤ 320 → 5 blocks of 64 structures +// Iteration 2: 64*4=256 ≤ 320 → 1 block of 256 + 1 tail of 64 structures +// After 2 iterations: Format2 merge for 256 + 64 structures +func.func @TMRGSORT_f16_topk_1280_512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : i32 // blockLen iteration 1 (f16 elements) + %c256 = arith.constant 256 : i32 // blockLen iteration 2 (f16 elements) + %c256_idx = arith.constant 256 : index // block1 size in elements (64 structures * 4) + %c512 = arith.constant 512 : index // block0 size (256 structures * 4), also dst cols (topk) + %c768 = arith.constant 768 : index // block1 offset (512 + 256) + %c1024 = arith.constant 1024 : index // block0 end offset (for partition) + %c1280 = arith.constant 1280 : index // src cols + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c1280], + strides = [%c1280, %c1280, %c1280, %c1280, %c1] + : !pto.tensor_view<1x1x1x1x1280xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c512], + strides = [%c512, %c512, %c512, %c512, %c1] + : !pto.tensor_view<1x1x1x1x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1280] + : !pto.tensor_view<1x1x1x1x1280xf16> -> !pto.partition_tensor_view<1x1x1x1x1280xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c512] + : !pto.tensor_view<1x1x1x1x512xf16> -> !pto.partition_tensor_view<1x1x1x1x512xf16> + + // Allocate tiles for iterative merge + %src_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + // Tiles for Format2 merge + // block0: 256 structures = 1024 f16 elements + // block1: 64 structures = 256 f16 elements + %block0_tile = pto.alloc_tile + : !pto.tile_buf + %block1_tile = pto.alloc_tile + : !pto.tile_buf + %merge_tmp_tile = pto.alloc_tile + : !pto.tile_buf + %merge_dst_tile = pto.alloc_tile + : !pto.tile_buf + %ex_vec = arith.constant dense<0> : vector<4xi16> + + // Load unsorted data to src_tile + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x1280xf16>) + outs(%src_tile : !pto.tile_buf) + + // Iteration 1: blockLen=64 f16 elements (=16 structures) + // After: 5 blocks of 64 structures each + pto.tmrgsort ins(%src_tile, %c64 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Copy result back to src_tile for next iteration + pto.tmov ins(%tmp_tile : !pto.tile_buf) + outs(%src_tile : !pto.tile_buf) + + // Iteration 2: blockLen=256 f16 elements (=64 structures) + // After: 1 block of 256 structures + 1 tail of 64 structures + pto.tmrgsort ins(%src_tile, %c256 : + !pto.tile_buf, + i32) + outs(%tmp_tile : !pto.tile_buf) + + // Store tmp_tile back to src memory (reuse as intermediate buffer) + pto.tstore ins(%tmp_tile : !pto.tile_buf) + outs(%src_part : !pto.partition_tensor_view<1x1x1x1x1280xf16>) + + // Load block0 (0-1023) into block0_tile (256 structures = 1024 f16 elements) + %block0_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1024] + : !pto.tensor_view<1x1x1x1x1280xf16> -> !pto.partition_tensor_view<1x1x1x1x1024xf16> + pto.tload ins(%block0_part : !pto.partition_tensor_view<1x1x1x1x1024xf16>) + outs(%block0_tile : !pto.tile_buf) + + // Load block1 (1024-1279) into block1_tile (64 structures = 256 f16 elements) + %block1_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c1024], + sizes = [%c1, %c1, %c1, %c1, %c256_idx] + : !pto.tensor_view<1x1x1x1x1280xf16> -> !pto.partition_tensor_view<1x1x1x1x256xf16> + pto.tload ins(%block1_part : !pto.partition_tensor_view<1x1x1x1x256xf16>) + outs(%block1_tile : !pto.tile_buf) + + // Format2 merge: merge block0 (256 structures) and block1 (64 structures) + // Output: 320 sorted structures, dst takes topk=128 structures (512 f16 elems) + pto.tmrgsort ins(%block0_tile, %block1_tile, %merge_tmp_tile {exhausted = false} : + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%merge_dst_tile, %ex_vec : + !pto.tile_buf, + vector<4xi16>) + + // Take top-k from merged result (merge_dst_tile already has topk=128 structures) + pto.tmov ins(%merge_dst_tile : !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + // Store top-k result + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x512xf16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmul/CMakeLists.txt new file mode 100644 index 000000000..4134fa993 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py new file mode 100644 index 000000000..2d3a70ce8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmul ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py new file mode 100644 index 000000000..0cf58f73b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] * input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp new file mode 100644 index 000000000..1debfe140 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TMUL_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMUL_f32_16x64(float *a, float *b, float *c, void *stream) { + TMUL_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TMUL_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTMUL_f32_32x32(float *a, float *b, float *c, void *stream) { + TMUL_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp new file mode 100644 index 000000000..6e294af40 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmul ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMUL_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTMUL_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTMUL_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTMUL_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmul [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto new file mode 100644 index 000000000..7ee883cc9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmul/tmul.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmul: tload(a) + tload(b) + tmul(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TMUL_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TMUL_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/CMakeLists.txt new file mode 100644 index 000000000..49ba8cd84 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tmuls) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/cases.py new file mode 100644 index 000000000..d12724e5f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/cases.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tmuls ST test cases. + +Shapes and dtype match testcase/tadds (C++ GTest suite): + case1: float, 32x64, valid 32x64 + case2: float16, 63x64, valid 63x64 + case3: int32, 31x128, valid 31x128 + case4: int16, 15x192, valid 15x192 + case5: float, 7x448, valid 7x448 + case6: float, 256x16, valid 256x16 + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/gen_data.py new file mode 100644 index 000000000..a98114643 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value multiplied into every element (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] * scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/launch.cpp new file mode 100644 index 000000000..fdd67d596 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value multiplied into every element (must match gen_data.py SCALAR) +static constexpr float TMULS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TMULS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMULS_f32_32x64(float *src, float *dst, void *stream) { + TMULS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMULS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TMULS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTMULS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TMULS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TMULS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTMULS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TMULS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TMULS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTMULS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TMULS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TMULS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMULS_f32_7x448(float *src, float *dst, void *stream) { + TMULS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMULS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TMULS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTMULS_f32_256x16(float *src, float *dst, void *stream) { + TMULS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TMULS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/main.cpp new file mode 100644 index 000000000..a5372cccc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tmuls ST — case-table driven. +// tmuls: dst = src * scalar (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTMULS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTMULS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTMULS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTMULS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTMULS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTMULS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTMULS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTMULS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTMULS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTMULS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTMULS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTMULS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tmuls [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tmuls/tmuls.pto b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/tmuls.pto new file mode 100644 index 000000000..a6fcc5cd9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tmuls/tmuls.pto @@ -0,0 +1,256 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tmuls: tload(src) + tmuls(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 (2048 elements) + func.func @TMULS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TMULS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TMULS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TMULS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TMULS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TMULS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tmuls ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tneg/CMakeLists.txt new file mode 100644 index 000000000..02a068e9e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tneg) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tneg/cases.py new file mode 100644 index 000000000..f5251d28a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/cases.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tneg ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, + { + "name": "i16_64x16", + "dtype": np.int16, + "shape": (64, 16), + "valid_shape": (64, 16), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tneg/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tneg/gen_data.py new file mode 100644 index 000000000..0c88055b7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/gen_data.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Random values (no constraints for neg) + input = np.random.randn(*shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.negative(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tneg/launch.cpp new file mode 100644 index 000000000..ef6121f48 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TNEG_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTNEG_f32_16x64(void *a, void *b, void *stream) { + TNEG_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TNEG_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTNEG_f32_32x32(void *a, void *b, void *stream) { + TNEG_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TNEG_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTNEG_f16_16x64(void *a, void *b, void *stream) { + TNEG_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TNEG_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTNEG_f16_32x32(void *a, void *b, void *stream) { + TNEG_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: i32 32x32 +extern "C" __global__ AICORE void TNEG_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b); + +void LaunchTNEG_i32_32x32(void *a, void *b, void *stream) { + TNEG_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b); +} + +// Case 5: i16 64x16 +extern "C" __global__ AICORE void TNEG_i16_64x16(__gm__ int16_t *a, __gm__ int16_t *b); + +void LaunchTNEG_i16_64x16(void *a, void *b, void *stream) { + TNEG_i16_64x16<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tneg/main.cpp new file mode 100644 index 000000000..6a86073fc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tneg ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTNEG_f32_16x64(void *a, void *b, void *stream); +void LaunchTNEG_f32_32x32(void *a, void *b, void *stream); +void LaunchTNEG_f16_16x64(void *a, void *b, void *stream); +void LaunchTNEG_f16_32x32(void *a, void *b, void *stream); +void LaunchTNEG_i32_32x32(void *a, void *b, void *stream); +void LaunchTNEG_i16_64x16(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTNEG_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTNEG_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTNEG_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTNEG_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"i32_32x32", LaunchTNEG_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, + {"i16_64x16", LaunchTNEG_i16_64x16, 64, 16, 64, 16, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tneg [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tneg/tneg.pto b/test/tilelang_st/npu/a5/src/st/testcase/tneg/tneg.pto new file mode 100644 index 000000000..f67ec4f19 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tneg/tneg.pto @@ -0,0 +1,263 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tneg: tload(a) + tneg(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TNEG_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TNEG_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TNEG_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TNEG_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: i32 32x32 (1024 elements) + func.func @TNEG_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } + + // Case 5: i16 64x16 (1024 elements) + func.func @TNEG_i16_64x16(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c16], + strides = [%c1024, %c1024, %c1024, %c16, %c1] + : !pto.tensor_view<1x1x1x64x16xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c16], + strides = [%c1024, %c1024, %c1024, %c16, %c1] + : !pto.tensor_view<1x1x1x64x16xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c16] + : !pto.tensor_view<1x1x1x64x16xi16> -> !pto.partition_tensor_view<1x1x1x64x16xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c16] + : !pto.tensor_view<1x1x1x64x16xi16> -> !pto.partition_tensor_view<1x1x1x64x16xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x16xi16>) + outs(%a : !pto.tile_buf) + + pto.tneg ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x16xi16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tnot/CMakeLists.txt new file mode 100644 index 000000000..ee5525ac2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tnot) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tnot/cases.py new file mode 100644 index 000000000..b6612d63f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/cases.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tnot ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int8, np.int16, np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol), 0 for exact match. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "int8_64x64", + "dtype": np.int8, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0 + }, + { + "name": "uint8_60x60", + "dtype": np.uint8, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0 + }, + { + "name": "int16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0 + }, + { + "name": "uint16_60x60", + "dtype": np.uint16, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0 + }, + { + "name": "int32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0 + }, + { + "name": "uint32_60x60", + "dtype": np.uint32, + "shape": (64, 64), + "valid_shape": (60, 60), + "eps": 0 + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tnot/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tnot/gen_data.py new file mode 100644 index 000000000..62de58386 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/gen_data.py @@ -0,0 +1,30 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + dtype_info = np.iinfo(dtype) + input = np.random.randint(dtype_info.min, dtype_info.max, size=shape, dtype=dtype) + golden = np.bitwise_not(input).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tnot/launch.cpp new file mode 100644 index 000000000..858f6d181 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: int8 64x64 +extern "C" __global__ AICORE void TNOT_int8_64x64(__gm__ int8_t *a, __gm__ int8_t *b); + +void LaunchTNOT_int8_64x64(void *a, void *b, void *stream) { + TNOT_int8_64x64<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b); +} + +// Case 1: uint8 60x60 +extern "C" __global__ AICORE void TNOT_uint8_60x60(__gm__ uint8_t *a, __gm__ uint8_t *b); + +void LaunchTNOT_uint8_60x60(void *a, void *b, void *stream) { + TNOT_uint8_60x60<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b); +} + +// Case 2: int16 64x64 +extern "C" __global__ AICORE void TNOT_int16_64x64(__gm__ int16_t *a, __gm__ int16_t *b); + +void LaunchTNOT_int16_64x64(void *a, void *b, void *stream) { + TNOT_int16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b); +} + +// Case 3: uint16 60x60 +extern "C" __global__ AICORE void TNOT_uint16_60x60(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTNOT_uint16_60x60(void *a, void *b, void *stream) { + TNOT_uint16_60x60<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: int32 64x64 +extern "C" __global__ AICORE void TNOT_int32_64x64(__gm__ int32_t *a, __gm__ int32_t *b); + +void LaunchTNOT_int32_64x64(void *a, void *b, void *stream) { + TNOT_int32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b); +} + +// Case 5: uint32 60x60 +extern "C" __global__ AICORE void TNOT_uint32_60x60(__gm__ uint32_t *a, __gm__ uint32_t *b); + +void LaunchTNOT_uint32_60x60(void *a, void *b, void *stream) { + TNOT_uint32_60x60<<<1, nullptr, stream>>>((__gm__ uint32_t *)a, (__gm__ uint32_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tnot/main.cpp new file mode 100644 index 000000000..55a823be7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tnot ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTNOT_int8_64x64(void *a, void *b, void *stream); +void LaunchTNOT_uint8_60x60(void *a, void *b, void *stream); +void LaunchTNOT_int16_64x64(void *a, void *b, void *stream); +void LaunchTNOT_uint16_60x60(void *a, void *b, void *stream); +void LaunchTNOT_int32_64x64(void *a, void *b, void *stream); +void LaunchTNOT_uint32_60x60(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"int8_64x64", LaunchTNOT_int8_64x64, 64, 64, 64, 64, sizeof(int8_t)}, + {"uint8_60x60", LaunchTNOT_uint8_60x60, 64, 64, 60, 60, sizeof(uint8_t)}, + {"int16_64x64", LaunchTNOT_int16_64x64, 64, 64, 64, 64, sizeof(int16_t)}, + {"uint16_60x60", LaunchTNOT_uint16_60x60, 64, 64, 60, 60, sizeof(uint16_t)}, + {"int32_64x64", LaunchTNOT_int32_64x64, 64, 64, 64, 64, sizeof(int32_t)}, + {"uint32_60x60", LaunchTNOT_uint32_60x60, 64, 64, 60, 60, sizeof(uint32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tnot [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tnot/tnot.pto b/test/tilelang_st/npu/a5/src/st/testcase/tnot/tnot.pto new file mode 100644 index 000000000..bbd028835 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tnot/tnot.pto @@ -0,0 +1,263 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tnot: tload(a) + tnot(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: int8 64x64 (valid 64x64) + func.func @TNOT_int8_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x64x64xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi8> -> !pto.partition_tensor_view<1x1x1x64x64xi8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi8>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi8>) + return + } + + // Case 1: uint8 64x64 (valid 60x60) - partition_view sizes = valid_shape + func.func @TNOT_uint8_60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui8> -> !pto.partition_tensor_view<1x1x1x60x60xui8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui8> -> !pto.partition_tensor_view<1x1x1x60x60xui8> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xui8>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x60x60xui8>) + return + } + + // Case 2: int16 64x64 (valid 64x64) + func.func @TNOT_int16_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + return + } + + // Case 3: uint16 64x64 (valid 60x60) - partition_view sizes = valid_shape + func.func @TNOT_uint16_60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui16> -> !pto.partition_tensor_view<1x1x1x60x60xui16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui16> -> !pto.partition_tensor_view<1x1x1x60x60xui16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xui16>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x60x60xui16>) + return + } + + // Case 4: int32 64x64 (valid 64x64) + func.func @TNOT_int32_64x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 5: uint32 64x64 (valid 60x60) - partition_view sizes = valid_shape + func.func @TNOT_uint32_60x60(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xui32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui32> -> !pto.partition_tensor_view<1x1x1x60x60xui32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x64x64xui32> -> !pto.partition_tensor_view<1x1x1x60x60xui32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x60x60xui32>) + outs(%a : !pto.tile_buf) + + pto.tnot ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x60x60xui32>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tor/CMakeLists.txt new file mode 100644 index 000000000..4d7414cdb --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tor) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py new file mode 100644 index 000000000..736a5ff8f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tor ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py new file mode 100644 index 000000000..c822c0be3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(0, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 100, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] | input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tor/launch.cpp new file mode 100644 index 000000000..1cb9c1454 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TOR_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TOR_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TOR_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TOR_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tor/main.cpp new file mode 100644 index 000000000..21d82eeea --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tor ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTOR_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTOR_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tor [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto b/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto new file mode 100644 index 000000000..0d6f7d9b3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tor/tor.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tor: tload(a) + tload(b) + tor(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: i32 16x64 (1024 elements) + func.func @TOR_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tor ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TOR_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tor ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tors/CMakeLists.txt new file mode 100644 index 000000000..5decd02d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tors) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tors/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tors/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tors/gen_data.py new file mode 100644 index 000000000..c4c879dcd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for bitwise OR (must match launch.cpp) +SCALAR = 3 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] | scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tors/launch.cpp new file mode 100644 index 000000000..4495ff38c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for bitwise OR (must match gen_data.py SCALAR) +static constexpr int32_t TORS_SCALAR_I32 = 3; +static constexpr int16_t TORS_SCALAR_I16 = 3; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TORS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTORS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TORS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TORS_SCALAR_I32); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TORS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTORS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TORS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TORS_SCALAR_I16); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TORS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTORS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TORS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TORS_SCALAR_I32); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TORS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTORS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TORS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TORS_SCALAR_I16); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tors/main.cpp new file mode 100644 index 000000000..b67da6f06 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tors ST — case-table driven. +// tors: dst = src | scalar (single input + scalar, bitwise OR). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTORS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTORS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTORS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTORS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTORS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTORS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTORS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTORS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tors [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tors/tors.pto b/test/tilelang_st/npu/a5/src/st/testcase/tors/tors.pto new file mode 100644 index 000000000..36124ff38 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tors/tors.pto @@ -0,0 +1,176 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tors: tload(src) + tors(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: i32 32x64 (2048 elements) + func.func @TORS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TORS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TORS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TORS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tors ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/CMakeLists.txt new file mode 100644 index 000000000..4eb1affcf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartadd) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/cases.py new file mode 100644 index 000000000..6ec74d95d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/cases.py @@ -0,0 +1,122 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartadd ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial add cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartadd semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = src0[:src1_rows,:src1_cols] + src1[:src1_rows,:src1_cols] + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = src0[:,:src1_cols] + src1[:,:src1_cols] + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 8), # src0 valid region (col_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (8, 64), # src1 valid region (row_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (64, 8), # src1 valid region (col_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + # float16 cases + { + "name": "f16_8x48_src0_col_less", + "dtype": np.float16, + "shape": (8, 48), + "valid_shape": (8, 16), # src0 valid region (col_less) + "src1_vshape": (8, 48), # src1 valid region (equals dst) + "dst_vshape": (8, 48), # dst valid region + "eps": 1e-3, + }, + { + "name": "f16_8x768_src0_col_less", + "dtype": np.float16, + "shape": (8, 768), + "valid_shape": (8, 512), # src0 valid region (col_less) + "src1_vshape": (8, 768), # src1 valid region (equals dst) + "dst_vshape": (8, 768), # dst valid region + "eps": 1e-3, + }, + # int16 cases + { + "name": "i16_8x48_src1_col_less", + "dtype": np.int16, + "shape": (8, 48), + "valid_shape": (8, 48), # src0 valid region (equals dst) + "src1_vshape": (8, 16), # src1 valid region (col_less) + "dst_vshape": (8, 48), # dst valid region + "eps": 0, # exact match for int + }, + # int32 cases + { + "name": "i32_64x64_src0_row_less", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 0, # exact match for int + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/gen_data.py new file mode 100644 index 000000000..9ecaf30fa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/gen_data.py @@ -0,0 +1,96 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # Compute golden according to tpartadd semantics from template: + # If src0_valid == dst_valid: use tpart_op with src0 as full operand + # - If src1 row_less: add for src1 region, copy src0 for remaining rows + # - If src1 col_less: copy src0 full, then add for overlapping region + # If src1_valid == dst_valid: use tpart_op with src1 as full operand (swap src0/src1) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst: + # src0 is the full operand matching dst + src1_row_lt_dst = (src1_vr < dst_vr and src1_vc == dst_vc) + src1_col_lt_dst = (src1_vr <= dst_vr and src1_vc < dst_vc) + + if src1_eq_dst: + # Full add: dst[:] = src0[:] + src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] + input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src1_col_lt_dst: + # Col_less: first copy src0, then add in overlapping region + golden[:dst_vr, :dst_vc] = input1[:dst_vr, :dst_vc].copy() + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] + input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + elif src1_row_lt_dst: + # Row_less: add for src1 region, copy src0 for remaining rows + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] + input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 is the full operand matching dst, swap src0/src1 in the logic + src0_row_lt_dst = (src0_vr < dst_vr and src0_vc == dst_vc) + src0_col_lt_dst = (src0_vr <= dst_vr and src0_vc < dst_vc) + + if src0_eq_dst: + # Full add: dst[:] = src0[:] + src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] + input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_col_lt_dst: + # Col_less: first copy src1, then add in overlapping region + golden[:dst_vr, :dst_vc] = input2[:dst_vr, :dst_vc].copy() + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] + input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + elif src0_row_lt_dst: + # Row_less: add for src0 region, copy src1 for remaining rows + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] + input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/launch.cpp new file mode 100644 index 000000000..02d725199 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/launch.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 64x64 full +extern "C" __global__ AICORE void TPARTADD_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src0_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 src0 col less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src0_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src0_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: f32 64x64 src1 row less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 4: f32 64x64 src1 col less +extern "C" __global__ AICORE void TPARTADD_f32_64x64_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTADD_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTADD_f32_64x64_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 5: f16 8x48 src0 col less +extern "C" __global__ AICORE void TPARTADD_f16_8x48_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTADD_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTADD_f16_8x48_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 6: f16 8x768 src0 col less +extern "C" __global__ AICORE void TPARTADD_f16_8x768_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTADD_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTADD_f16_8x768_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 7: i16 8x48 src1 col less +extern "C" __global__ AICORE void TPARTADD_i16_8x48_src1_col_less(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTADD_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTADD_i16_8x48_src1_col_less<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 8: i32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTADD_i32_64x64_src0_row_less(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTADD_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTADD_i32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/main.cpp new file mode 100644 index 000000000..34e013b77 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/main.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartadd ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTADD_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTADD_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTADD_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTADD_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTADD_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTADD_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_col_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src0_col_less), 64, 64, 64, 8, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_row_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src1_row_less), 64, 64, 64, 64, 8, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_col_less", reinterpret_cast(LaunchTPARTADD_f32_64x64_src1_col_less), 64, 64, 64, 64, 64, 8, 64, 64, sizeof(float)}, + {"f16_8x48_src0_col_less", reinterpret_cast(LaunchTPARTADD_f16_8x48_src0_col_less), 8, 48, 8, 16, 8, 48, 8, 48, sizeof(uint16_t)}, + {"f16_8x768_src0_col_less", reinterpret_cast(LaunchTPARTADD_f16_8x768_src0_col_less), 8,768, 8,512, 8,768, 8,768, sizeof(uint16_t)}, + {"i16_8x48_src1_col_less", reinterpret_cast(LaunchTPARTADD_i16_8x48_src1_col_less), 8, 48, 8, 48, 8, 16, 8, 48, sizeof(int16_t)}, + {"i32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTADD_i32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartadd [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/tpartadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/tpartadd.pto new file mode 100644 index 000000000..f7a196284 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartadd/tpartadd.pto @@ -0,0 +1,535 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use the file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartadd: partial elementwise add with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 64x64 full (src0/src1/dst all have same valid_shape 64x64) + func.func @TPARTADD_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 1: f32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTADD_f32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 2: f32 64x64 src0 col less (src0 valid 64x8, src1/dst valid 64x64) + func.func @TPARTADD_f32_64x64_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (64,8) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: f32 64x64 src1 row less (src0/dst valid 64x64, src1 valid 8x64) + func.func @TPARTADD_f32_64x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 4: f32 64x64 src1 col less (src0/dst valid 64x64, src1 valid 64x8) + func.func @TPARTADD_f32_64x64_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (64,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f16 8x48 src0 col less (src0 valid 8x16, src1/dst valid 8x48) + func.func @TPARTADD_f16_8x48_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + + // src0: partial valid region (8,16) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,48) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + return + } + + // Case 6: f16 8x768 src0 col less (src0 valid 8x512, src1/dst valid 8x768) + func.func @TPARTADD_f16_8x768_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c512 = arith.constant 512 : index + %c768 = arith.constant 768 : index + %c6144 = arith.constant 6144 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + + // src0: partial valid region (8,512) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,768) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,768) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + return + } + + // Case 7: i16 8x48 src1 col less (src0/dst valid 8x48, src1 valid 8x16) + func.func @TPARTADD_i16_8x48_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + + // src0: full valid region (8,48) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,16) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + return + } + + // Case 8: i32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTADD_i32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartadd ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/CMakeLists.txt new file mode 100644 index 000000000..04f947e55 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/cases.py new file mode 100644 index 000000000..e6e6dbfe5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/cases.py @@ -0,0 +1,153 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartmax ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial max cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartmax semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = max(src0[:src1_rows,:src1_cols], src1[:src1_rows,:src1_cols]) + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = max(src0[:,:src1_cols], src1[:,:src1_cols]) + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases from pto-isa + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_2x24_src1_col_less", + "dtype": np.float32, + "shape": (2, 24), + "valid_shape": (2, 24), # src0 valid region (equals dst) + "src1_vshape": (2, 8), # src1 valid region (col_less) + "dst_vshape": (2, 24), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_128x64_src1_row_less", + "dtype": np.float32, + "shape": (128, 64), + "valid_shape": (128, 64), # src0 valid region (equals dst) + "src1_vshape": (96, 64), # src1 valid region (row_less) + "dst_vshape": (128, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_95x95_full", + "dtype": np.float32, + "shape": (95, 95), + "valid_shape": (95, 95), # src0 valid region + "src1_vshape": (95, 95), # src1 valid region (same as dst) + "dst_vshape": (95, 95), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_122x123_complex", + "dtype": np.float32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region (src1 rows, src0 cols) + "eps": 1e-6, + }, + # float16 cases from pto-isa + { + "name": "f16_122x123_complex", + "dtype": np.float16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 1e-3, + }, + # int16 cases from pto-isa + { + "name": "i16_122x123_complex", + "dtype": np.int16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int32 cases from pto-isa + { + "name": "i32_122x123_complex", + "dtype": np.int32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint16 cases from pto-isa + { + "name": "u16_122x123_complex", + "dtype": np.uint16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint32 cases from pto-isa + { + "name": "u32_122x123_complex", + "dtype": np.uint32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int8 cases from pto-isa + { + "name": "i8_122x123_complex", + "dtype": np.int8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint8 cases from pto-isa + { + "name": "u8_122x123_complex", + "dtype": np.uint8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/gen_data.py new file mode 100644 index 000000000..700de5895 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/gen_data.py @@ -0,0 +1,127 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # tpartmax semantics (based on pto-isa TPartBinOps.hpp TCopyPadOp): + # Algorithm: + # 1. dst[:] = Min (padding for max operation) + # 2. dst[0:src0_vr, 0:src0_vc] = src0[0:src0_vr, 0:src0_vc] (copy src0 to dst) + # 3. dst[0:src1_vr, 0:src1_vc] = max(dst[0:src1_vr, 0:src1_vc], src1[0:src1_vr, 0:src1_vc]) + # (apply max in src1 valid region) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst and src1_eq_dst: + # Full max: both src0 and src1 cover entire dst + golden[:dst_vr, :dst_vc] = np.maximum(input1[:dst_vr, :dst_vc], input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_eq_dst: + # src0 covers dst, src1 is partial + # dst = src0 (copy), then max(dst, src1) in src1 region = max(src0, src1) in src1 region, src0 in rest + golden[:src1_vr, :src1_vc] = np.maximum(input1[:src1_vr, :src1_vc], input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + if src1_vc < dst_vc: + golden[:src1_vr, src1_vc:dst_vc] = input1[:src1_vr, src1_vc:dst_vc].copy() + if src1_vr < dst_vr: + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 covers dst, src0 is partial + # dst = Min, then copy src0 in src0 region, then max(dst, src1) in src1 region + golden[:src0_vr, :src0_vc] = np.maximum(input1[:src0_vr, :src0_vc], input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + if src0_vc < dst_vc: + golden[:src0_vr, src0_vc:dst_vc] = input2[:src0_vr, src0_vc:dst_vc].copy() + if src0_vr < dst_vr: + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + else: + min_vr = min(src0_vr, src1_vr) + min_vc = min(src0_vc, src1_vc) + + # Region 1: [0:min_vr, 0:min_vc] - overlapping region (both src0 and src1 valid) + golden[:min_vr, :min_vc] = np.maximum(input1[:min_vr, :min_vc], input2[:min_vr, :min_vc]).astype(dtype, copy=False) + + # Region 2: [0:src0_vr, min_vc:src0_vc] if src0_vc > min_vc + if src0_vc > min_vc: + golden[:src0_vr, min_vc:src0_vc] = input1[:src0_vr, min_vc:src0_vc].copy() + + # Region 3: [min_vr:src1_vr, 0:min_vc] if src1_vr > min_vr + if src1_vr > min_vr: + golden[min_vr:src1_vr, :min_vc] = input2[min_vr:src1_vr, :min_vc].copy() + + # Region 4: [min_vr:src1_vr, min_vc:src1_vc] if src1_vr > min_vr AND src1_vc > min_vc + if src1_vr > min_vr and src1_vc > min_vc: + golden[min_vr:src1_vr, min_vc:src1_vc] = input2[min_vr:src1_vr, min_vc:src1_vc].copy() + + # Region 5: [0:min_vr, src1_vc:src0_vc] if src0_vc > src1_vc + if src0_vc > src1_vc and min_vr > 0: + # Already handled in Region 2 if rows are [0:src0_vr] + pass # Region 2 covers this + + if src1_vr > src0_vr and src0_vc > src1_vc: + # Region [src0_vr:src1_vr, src1_vc:src0_vc] = Min (neither covers) + # This is correct for tpartmax - padding value is Min + # For floats, we use -np.inf. For integers, use dtype min. + if dtype == np.float32: + min_val = np.finfo(np.float32).min + elif dtype == np.float16: + min_val = np.finfo(np.float16).min + elif dtype == np.int8: + min_val = np.iinfo(np.int8).min + elif dtype == np.uint8: + min_val = np.iinfo(np.uint8).min + elif dtype == np.int16: + min_val = np.iinfo(np.int16).min + elif dtype == np.uint16: + min_val = np.iinfo(np.uint16).min + elif dtype == np.int32: + min_val = np.iinfo(np.int32).min + elif dtype == np.uint32: + min_val = np.iinfo(np.uint32).min + else: + min_val = np.iinfo(dtype).min + golden[src0_vr:src1_vr, src1_vc:src0_vc] = min_val + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/launch.cpp new file mode 100644 index 000000000..98a1a76d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/launch.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case: f32 64x64 full +extern "C" __global__ AICORE void TPARTMAX_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 2x24 src1 col less +extern "C" __global__ AICORE void TPARTMAX_f32_2x24_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_2x24_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 128x64 src1 row less +extern "C" __global__ AICORE void TPARTMAX_f32_128x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_128x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 95x95 full +extern "C" __global__ AICORE void TPARTMAX_f32_95x95_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_95x95_full(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_95x95_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_f32_122x123_complex(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMAX_f32_122x123_complex(float *a, float *b, float *c, void *stream) { + TPARTMAX_f32_122x123_complex<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f16 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_f16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMAX_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMAX_f16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: i16 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_i16_122x123_complex(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTMAX_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTMAX_i16_122x123_complex<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case: i32 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_i32_122x123_complex(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTMAX_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTMAX_i32_122x123_complex<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case: u16 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_u16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMAX_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMAX_u16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: u32 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_u32_122x123_complex(__gm__ uint32_t *a, __gm__ uint32_t *b, __gm__ uint32_t *c); + +void LaunchTPARTMAX_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream) { + TPARTMAX_u32_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint32_t *)a, (__gm__ uint32_t *)b, (__gm__ uint32_t *)c); +} + +// Case: i8 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_i8_122x123_complex(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); + +void LaunchTPARTMAX_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream) { + TPARTMAX_i8_122x123_complex<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +} + +// Case: u8 122x123 complex +extern "C" __global__ AICORE void TPARTMAX_u8_122x123_complex(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ uint8_t *c); + +void LaunchTPARTMAX_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream) { + TPARTMAX_u8_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ uint8_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/main.cpp new file mode 100644 index 000000000..c81ab0e62 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/main.cpp @@ -0,0 +1,230 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartmax ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTMAX_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_95x95_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f32_122x123_complex(float *a, float *b, float *c, void *stream); +void LaunchTPARTMAX_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMAX_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTMAX_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTPARTMAX_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMAX_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream); +void LaunchTPARTMAX_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream); +void LaunchTPARTMAX_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols (valid cols) + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTMAX_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_2x24_src1_col_less", reinterpret_cast(LaunchTPARTMAX_f32_2x24_src1_col_less), 2, 24, 2, 24, 2, 8, 2, 24, sizeof(float)}, + {"f32_128x64_src1_row_less", reinterpret_cast(LaunchTPARTMAX_f32_128x64_src1_row_less), 128, 64,128, 64, 96, 64,128, 64, sizeof(float)}, + {"f32_95x95_full", reinterpret_cast(LaunchTPARTMAX_f32_95x95_full), 95, 95, 95, 95, 95, 95, 95, 95, sizeof(float)}, + {"f32_122x123_complex", reinterpret_cast(LaunchTPARTMAX_f32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(float)}, + {"f16_122x123_complex", reinterpret_cast(LaunchTPARTMAX_f16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"i16_122x123_complex", reinterpret_cast(LaunchTPARTMAX_i16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int16_t)}, + {"i32_122x123_complex", reinterpret_cast(LaunchTPARTMAX_i32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int32_t)}, + {"u16_122x123_complex", reinterpret_cast(LaunchTPARTMAX_u16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"u32_122x123_complex", reinterpret_cast(LaunchTPARTMAX_u32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint32_t)}, + {"i8_122x123_complex", reinterpret_cast(LaunchTPARTMAX_i8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int8_t)}, + {"u8_122x123_complex", reinterpret_cast(LaunchTPARTMAX_u8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +// Calculate aligned cols for 32-byte alignment +static size_t CalcAlignedCols(size_t cols, size_t elemSize) { + size_t totalBytes = cols * elemSize; + size_t alignedBytes = ((totalBytes + 31) / 32) * 32; + return alignedBytes / elemSize; +} + +// Helper to pad data with stride +static void PadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * alignedCols * elemSize, + srcPtr + r * cols * elemSize, + cols * elemSize); + // Zero-fill padding region (optional, data will be overwritten by kernel) + memset(dstPtr + r * alignedCols * elemSize + cols * elemSize, + 0, + (alignedCols - cols) * elemSize); + } +} + +// Helper to unpad data (extract valid cols) +static void UnpadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * cols * elemSize, + srcPtr + r * alignedCols * elemSize, + cols * elemSize); + } +} + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + const size_t alignedCols = CalcAlignedCols(tc.cols, tc.elemSize); + const size_t paddedSize = tc.rows * alignedCols * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu, alignedCols=%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols, alignedCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *src0HostOrig = nullptr, *src1HostOrig = nullptr, *dstHostOrig = nullptr; + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + // Allocate host buffers for original data (contiguous) + aclrtMallocHost((void **)(&src0HostOrig), fileSize); + aclrtMallocHost((void **)(&src1HostOrig), fileSize); + aclrtMallocHost((void **)(&dstHostOrig), fileSize); + + // Allocate host buffers for padded data + aclrtMallocHost((void **)(&src0Host), paddedSize); + aclrtMallocHost((void **)(&src1Host), paddedSize); + aclrtMallocHost((void **)(&dstHost), paddedSize); + + // Allocate device buffers with padded size + aclrtMalloc((void **)&src0Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (rc == 0) { + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (rc == 0) { + // Pad input data with stride + PadDataWithStride(src0HostOrig, src0Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + PadDataWithStride(src1HostOrig, src1Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + + aclrtMemcpy(src0Device, paddedSize, src0Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, paddedSize, src1Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, paddedSize, dstDevice, paddedSize, ACL_MEMCPY_DEVICE_TO_HOST); + + // Unpad output data + UnpadDataWithStride(dstHost, dstHostOrig, tc.rows, tc.cols, alignedCols, tc.elemSize); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + if (src0HostOrig != nullptr) + aclrtFreeHost(src0HostOrig); + if (src1HostOrig != nullptr) + aclrtFreeHost(src1HostOrig); + if (dstHostOrig != nullptr) + aclrtFreeHost(dstHostOrig); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/tpartmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/tpartmax.pto new file mode 100644 index 000000000..9dfeab2fe --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmax/tpartmax.pto @@ -0,0 +1,717 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartmax: partial elementwise max with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case: f32_64x64_full (src0 valid 64x64, src1 valid 64x64, dst valid 64x64) + func.func @TPARTMAX_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f32_2x24_src1_col_less (src0 valid 2x24, src1 valid 2x8, dst valid 2x24) + func.func @TPARTMAX_f32_2x24_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c24 = arith.constant 24 : index + %c48 = arith.constant 48 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + + // src0: valid region (2,24) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (2,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (2,24) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + return + } + + // Case: f32_128x64_src1_row_less (src0 valid 128x64, src1 valid 96x64, dst valid 128x64) + func.func @TPARTMAX_f32_128x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + + // src0: valid region (128,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (96,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (128,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + return + } + + // Case: f32_95x95_full (src0 valid 95x95, src1 valid 95x95, dst valid 95x95) + func.func @TPARTMAX_f32_95x95_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c95 = arith.constant 95 : index + %c96 = arith.constant 96 : index + %c9120 = arith.constant 9120 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c95] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x95xf32> + + // src0: valid region (95,95) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (95,95) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (95,95) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x95x95xf32>) + return + } + + // Case: f32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_f32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x123xf32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf32>) + return + } + + // Case: f16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_f16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x123xf16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf16>) + return + } + + // Case: i16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_i16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x123xi16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi16>) + return + } + + // Case: i32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_i32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x123xi32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi32>) + return + } + + // Case: u16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_u16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x123xui16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui16>) + return + } + + // Case: u32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_u32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x123xui32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui32>) + return + } + + // Case: i8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_i8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x123xi8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi8>) + return + } + + // Case: u8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMAX_u8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x123xui8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%b : !pto.tile_buf) + + pto.tpartmax ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/CMakeLists.txt new file mode 100644 index 000000000..cfb480147 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/cases.py new file mode 100644 index 000000000..50976fbd1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/cases.py @@ -0,0 +1,153 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartmin ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial min cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartmin semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = min(src0[:src1_rows,:src1_cols], src1[:src1_rows,:src1_cols]) + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = min(src0[:,:src1_cols], src1[:,:src1_cols]) + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases from pto-isa + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_2x24_src1_col_less", + "dtype": np.float32, + "shape": (2, 24), + "valid_shape": (2, 24), # src0 valid region (equals dst) + "src1_vshape": (2, 8), # src1 valid region (col_less) + "dst_vshape": (2, 24), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_128x64_src1_row_less", + "dtype": np.float32, + "shape": (128, 64), + "valid_shape": (128, 64), # src0 valid region (equals dst) + "src1_vshape": (96, 64), # src1 valid region (row_less) + "dst_vshape": (128, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_95x95_full", + "dtype": np.float32, + "shape": (95, 95), + "valid_shape": (95, 95), # src0 valid region + "src1_vshape": (95, 95), # src1 valid region (same as dst) + "dst_vshape": (95, 95), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_122x123_complex", + "dtype": np.float32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region (src1 rows, src0 cols) + "eps": 1e-6, + }, + # float16 cases from pto-isa + { + "name": "f16_122x123_complex", + "dtype": np.float16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 1e-3, + }, + # int16 cases from pto-isa + { + "name": "i16_122x123_complex", + "dtype": np.int16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int32 cases from pto-isa + { + "name": "i32_122x123_complex", + "dtype": np.int32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint16 cases from pto-isa + { + "name": "u16_122x123_complex", + "dtype": np.uint16, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint32 cases from pto-isa + { + "name": "u32_122x123_complex", + "dtype": np.uint32, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # int8 cases from pto-isa + { + "name": "i8_122x123_complex", + "dtype": np.int8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, + # uint8 cases from pto-isa + { + "name": "u8_122x123_complex", + "dtype": np.uint8, + "shape": (122, 123), + "valid_shape": (104, 123), # src0 valid region + "src1_vshape": (122, 110), # src1 valid region + "dst_vshape": (122, 123), # dst valid region + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/gen_data.py new file mode 100644 index 000000000..fb3766f42 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/gen_data.py @@ -0,0 +1,127 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # tpartmin semantics (based on pto-isa TPartBinOps.hpp TCopyPadOp): + # Algorithm: + # 1. dst[:] = Max (padding for min operation) + # 2. dst[0:src0_vr, 0:src0_vc] = src0[0:src0_vr, 0:src0_vc] (copy src0 to dst) + # 3. dst[0:src1_vr, 0:src1_vc] = min(dst[0:src1_vr, 0:src1_vc], src1[0:src1_vr, 0:src1_vc]) + # (apply min in src1 valid region) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst and src1_eq_dst: + # Full min: both src0 and src1 cover entire dst + golden[:dst_vr, :dst_vc] = np.minimum(input1[:dst_vr, :dst_vc], input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_eq_dst: + # src0 covers dst, src1 is partial + # dst = src0 (copy), then min(dst, src1) in src1 region = min(src0, src1) in src1 region, src0 in rest + golden[:src1_vr, :src1_vc] = np.minimum(input1[:src1_vr, :src1_vc], input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + if src1_vc < dst_vc: + golden[:src1_vr, src1_vc:dst_vc] = input1[:src1_vr, src1_vc:dst_vc].copy() + if src1_vr < dst_vr: + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 covers dst, src0 is partial + # dst = Max, then copy src0 in src0 region, then min(dst, src1) in src1 region + golden[:src0_vr, :src0_vc] = np.minimum(input1[:src0_vr, :src0_vc], input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + if src0_vc < dst_vc: + golden[:src0_vr, src0_vc:dst_vc] = input2[:src0_vr, src0_vc:dst_vc].copy() + if src0_vr < dst_vr: + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + else: + min_vr = min(src0_vr, src1_vr) + min_vc = min(src0_vc, src1_vc) + + # Region 1: [0:min_vr, 0:min_vc] - overlapping region (both src0 and src1 valid) + golden[:min_vr, :min_vc] = np.minimum(input1[:min_vr, :min_vc], input2[:min_vr, :min_vc]).astype(dtype, copy=False) + + # Region 2: [0:src0_vr, min_vc:src0_vc] if src0_vc > min_vc + if src0_vc > min_vc: + golden[:src0_vr, min_vc:src0_vc] = input1[:src0_vr, min_vc:src0_vc].copy() + + # Region 3: [min_vr:src1_vr, 0:min_vc] if src1_vr > min_vr + if src1_vr > min_vr: + golden[min_vr:src1_vr, :min_vc] = input2[min_vr:src1_vr, :min_vc].copy() + + # Region 4: [min_vr:src1_vr, min_vc:src1_vc] if src1_vr > min_vr AND src1_vc > min_vc + if src1_vr > min_vr and src1_vc > min_vc: + golden[min_vr:src1_vr, min_vc:src1_vc] = input2[min_vr:src1_vr, min_vc:src1_vc].copy() + + # Region 5: [0:min_vr, src1_vc:src0_vc] if src0_vc > src1_vc + if src0_vc > src1_vc and min_vr > 0: + # Already handled in Region 2 if rows are [0:src0_vr] + pass # Region 2 covers this + + if src1_vr > src0_vr and src0_vc > src1_vc: + # Region [src0_vr:src1_vr, src1_vc:src0_vc] = Max (neither covers) + # This is correct for tpartmin - padding value is Max + # For floats, we use np.inf. For integers, use dtype max. + if dtype == np.float32: + max_val = np.finfo(np.float32).max + elif dtype == np.float16: + max_val = np.finfo(np.float16).max + elif dtype == np.int8: + max_val = np.iinfo(np.int8).max + elif dtype == np.uint8: + max_val = np.iinfo(np.uint8).max + elif dtype == np.int16: + max_val = np.iinfo(np.int16).max + elif dtype == np.uint16: + max_val = np.iinfo(np.uint16).max + elif dtype == np.int32: + max_val = np.iinfo(np.int32).max + elif dtype == np.uint32: + max_val = np.iinfo(np.uint32).max + else: + max_val = np.iinfo(dtype).max + golden[src0_vr:src1_vr, src1_vc:src0_vc] = max_val + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/launch.cpp new file mode 100644 index 000000000..4fdee00b6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/launch.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case: f32 64x64 full +extern "C" __global__ AICORE void TPARTMIN_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 2x24 src1 col less +extern "C" __global__ AICORE void TPARTMIN_f32_2x24_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_2x24_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 128x64 src1 row less +extern "C" __global__ AICORE void TPARTMIN_f32_128x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_128x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 95x95 full +extern "C" __global__ AICORE void TPARTMIN_f32_95x95_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_95x95_full(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_95x95_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f32 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_f32_122x123_complex(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMIN_f32_122x123_complex(float *a, float *b, float *c, void *stream) { + TPARTMIN_f32_122x123_complex<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case: f16 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_f16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMIN_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMIN_f16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: i16 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_i16_122x123_complex(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTMIN_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTMIN_i16_122x123_complex<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case: i32 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_i32_122x123_complex(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTMIN_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTMIN_i32_122x123_complex<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case: u16 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_u16_122x123_complex(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMIN_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMIN_u16_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case: u32 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_u32_122x123_complex(__gm__ uint32_t *a, __gm__ uint32_t *b, __gm__ uint32_t *c); + +void LaunchTPARTMIN_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream) { + TPARTMIN_u32_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint32_t *)a, (__gm__ uint32_t *)b, (__gm__ uint32_t *)c); +} + +// Case: i8 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_i8_122x123_complex(__gm__ int8_t *a, __gm__ int8_t *b, __gm__ int8_t *c); + +void LaunchTPARTMIN_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream) { + TPARTMIN_i8_122x123_complex<<<1, nullptr, stream>>>((__gm__ int8_t *)a, (__gm__ int8_t *)b, (__gm__ int8_t *)c); +} + +// Case: u8 122x123 complex +extern "C" __global__ AICORE void TPARTMIN_u8_122x123_complex(__gm__ uint8_t *a, __gm__ uint8_t *b, __gm__ uint8_t *c); + +void LaunchTPARTMIN_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream) { + TPARTMIN_u8_122x123_complex<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, (__gm__ uint8_t *)b, (__gm__ uint8_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/main.cpp new file mode 100644 index 000000000..5251f0149 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/main.cpp @@ -0,0 +1,230 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartmin ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTMIN_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_2x24_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_128x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_95x95_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f32_122x123_complex(float *a, float *b, float *c, void *stream); +void LaunchTPARTMIN_f16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMIN_i16_122x123_complex(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTMIN_i32_122x123_complex(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTPARTMIN_u16_122x123_complex(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMIN_u32_122x123_complex(uint32_t *a, uint32_t *b, uint32_t *c, void *stream); +void LaunchTPARTMIN_i8_122x123_complex(int8_t *a, int8_t *b, int8_t *c, void *stream); +void LaunchTPARTMIN_u8_122x123_complex(uint8_t *a, uint8_t *b, uint8_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols (valid cols) + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTMIN_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_2x24_src1_col_less", reinterpret_cast(LaunchTPARTMIN_f32_2x24_src1_col_less), 2, 24, 2, 24, 2, 8, 2, 24, sizeof(float)}, + {"f32_128x64_src1_row_less", reinterpret_cast(LaunchTPARTMIN_f32_128x64_src1_row_less), 128, 64,128, 64, 96, 64,128, 64, sizeof(float)}, + {"f32_95x95_full", reinterpret_cast(LaunchTPARTMIN_f32_95x95_full), 95, 95, 95, 95, 95, 95, 95, 95, sizeof(float)}, + {"f32_122x123_complex", reinterpret_cast(LaunchTPARTMIN_f32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(float)}, + {"f16_122x123_complex", reinterpret_cast(LaunchTPARTMIN_f16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"i16_122x123_complex", reinterpret_cast(LaunchTPARTMIN_i16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int16_t)}, + {"i32_122x123_complex", reinterpret_cast(LaunchTPARTMIN_i32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int32_t)}, + {"u16_122x123_complex", reinterpret_cast(LaunchTPARTMIN_u16_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint16_t)}, + {"u32_122x123_complex", reinterpret_cast(LaunchTPARTMIN_u32_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint32_t)}, + {"i8_122x123_complex", reinterpret_cast(LaunchTPARTMIN_i8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(int8_t)}, + {"u8_122x123_complex", reinterpret_cast(LaunchTPARTMIN_u8_122x123_complex), 122,123,104,123,122,110,122,123, sizeof(uint8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +// Calculate aligned cols for 32-byte alignment +static size_t CalcAlignedCols(size_t cols, size_t elemSize) { + size_t totalBytes = cols * elemSize; + size_t alignedBytes = ((totalBytes + 31) / 32) * 32; + return alignedBytes / elemSize; +} + +// Helper to pad data with stride +static void PadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * alignedCols * elemSize, + srcPtr + r * cols * elemSize, + cols * elemSize); + // Zero-fill padding region (optional, data will be overwritten by kernel) + memset(dstPtr + r * alignedCols * elemSize + cols * elemSize, + 0, + (alignedCols - cols) * elemSize); + } +} + +// Helper to unpad data (extract valid cols) +static void UnpadDataWithStride(const void *src, void *dst, size_t rows, size_t cols, + size_t alignedCols, size_t elemSize) { + const char *srcPtr = static_cast(src); + char *dstPtr = static_cast(dst); + for (size_t r = 0; r < rows; ++r) { + memcpy(dstPtr + r * cols * elemSize, + srcPtr + r * alignedCols * elemSize, + cols * elemSize); + } +} + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + const size_t alignedCols = CalcAlignedCols(tc.cols, tc.elemSize); + const size_t paddedSize = tc.rows * alignedCols * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu, alignedCols=%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols, alignedCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + + void *src0HostOrig = nullptr, *src1HostOrig = nullptr, *dstHostOrig = nullptr; + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + // Allocate host buffers for original data (contiguous) + aclrtMallocHost((void **)(&src0HostOrig), fileSize); + aclrtMallocHost((void **)(&src1HostOrig), fileSize); + aclrtMallocHost((void **)(&dstHostOrig), fileSize); + + // Allocate host buffers for padded data + aclrtMallocHost((void **)(&src0Host), paddedSize); + aclrtMallocHost((void **)(&src1Host), paddedSize); + aclrtMallocHost((void **)(&dstHost), paddedSize); + + // Allocate device buffers with padded size + aclrtMalloc((void **)&src0Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, paddedSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (rc == 0) { + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1HostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (rc == 0) { + // Pad input data with stride + PadDataWithStride(src0HostOrig, src0Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + PadDataWithStride(src1HostOrig, src1Host, tc.rows, tc.cols, alignedCols, tc.elemSize); + + aclrtMemcpy(src0Device, paddedSize, src0Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, paddedSize, src1Host, paddedSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, paddedSize, dstDevice, paddedSize, ACL_MEMCPY_DEVICE_TO_HOST); + + // Unpad output data + UnpadDataWithStride(dstHost, dstHostOrig, tc.rows, tc.cols, alignedCols, tc.elemSize); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHostOrig, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + if (src0HostOrig != nullptr) + aclrtFreeHost(src0HostOrig); + if (src1HostOrig != nullptr) + aclrtFreeHost(src1HostOrig); + if (dstHostOrig != nullptr) + aclrtFreeHost(dstHostOrig); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/tpartmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/tpartmin.pto new file mode 100644 index 000000000..3dd1d34cd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmin/tpartmin.pto @@ -0,0 +1,718 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use the file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartmin: partial elementwise min with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case: f32_64x64_full (src0 valid 64x64, src1 valid 64x64, dst valid 64x64) + func.func @TPARTMIN_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case: f32_2x24_src1_col_less (src0 valid 2x24, src1 valid 2x8, dst valid 2x24) + func.func @TPARTMIN_f32_2x24_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c24 = arith.constant 24 : index + %c48 = arith.constant 48 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c2, %c24], + strides = [%c48, %c48, %c48, %c24, %c1] + : !pto.tensor_view<1x1x1x2x24xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c24] + : !pto.tensor_view<1x1x1x2x24xf32> -> !pto.partition_tensor_view<1x1x1x2x24xf32> + + // src0: valid region (2,24) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (2,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (2,24) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x2x24xf32>) + return + } + + // Case: f32_128x64_src1_row_less (src0 valid 128x64, src1 valid 96x64, dst valid 128x64) + func.func @TPARTMIN_f32_128x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c8192 = arith.constant 8192 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xf32> -> !pto.partition_tensor_view<1x1x1x128x64xf32> + + // src0: valid region (128,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (96,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (128,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x128x64xf32>) + return + } + + // Case: f32_95x95_full (src0 valid 95x95, src1 valid 95x95, dst valid 95x95) + func.func @TPARTMIN_f32_95x95_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c95 = arith.constant 95 : index + %c96 = arith.constant 96 : index + %c9120 = arith.constant 9120 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c95, %c96], + strides = [%c9120, %c9120, %c9120, %c96, %c1] + : !pto.tensor_view<1x1x1x95x96xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c96] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x96xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c95, %c95] + : !pto.tensor_view<1x1x1x95x96xf32> -> !pto.partition_tensor_view<1x1x1x95x95xf32> + + // src0: valid region (95,95) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (95,95) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (95,95) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x95x96xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x95x95xf32>) + return + } + + // Case: f32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_f32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x128xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf32> -> !pto.partition_tensor_view<1x1x1x122x123xf32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf32>) + return + } + + // Case: f16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_f16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x128xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xf16> -> !pto.partition_tensor_view<1x1x1x122x123xf16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xf16>) + return + } + + // Case: i16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_i16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x128xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi16> -> !pto.partition_tensor_view<1x1x1x122x123xi16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi16>) + return + } + + // Case: i32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_i32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x128xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi32> -> !pto.partition_tensor_view<1x1x1x122x123xi32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi32>) + return + } + + // Case: u16_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_u16_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x128xui16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui16> -> !pto.partition_tensor_view<1x1x1x122x123xui16> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui16>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui16>) + return + } + + // Case: u32_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_u32_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x128xui32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui32> -> !pto.partition_tensor_view<1x1x1x122x123xui32> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui32>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui32>) + return + } + + // Case: i8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_i8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xi8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x128xi8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xi8> -> !pto.partition_tensor_view<1x1x1x122x123xi8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xi8>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xi8>) + return + } + + // Case: u8_122x123_complex (src0 valid 104x123, src1 valid 122x110, dst valid 122x123) + func.func @TPARTMIN_u8_122x123_complex(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c122 = arith.constant 122 : index + %c128 = arith.constant 128 : index + %c15616 = arith.constant 15616 : index + %c123 = arith.constant 123 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c122, %c128], + strides = [%c15616, %c15616, %c15616, %c128, %c1] + : !pto.tensor_view<1x1x1x122x128xui8> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c128] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x128xui8> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c122, %c123] + : !pto.tensor_view<1x1x1x122x128xui8> -> !pto.partition_tensor_view<1x1x1x122x123xui8> + + // src0: valid region (104,123) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: valid region (122,110) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: valid region (122,123) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x122x128xui8>) + outs(%b : !pto.tile_buf) + + pto.tpartmin ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x122x123xui8>) + return + } +} + diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/CMakeLists.txt new file mode 100644 index 000000000..190439e25 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tpartmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/cases.py new file mode 100644 index 000000000..ad892fa2b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/cases.py @@ -0,0 +1,122 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tpartmul ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions (same for src0/src1/dst). + - valid_shape: (valid_rows, valid_cols) — src0 valid region (src0_eq_dst scenario). + - src1_vshape: (src1_valid_rows, src1_valid_cols) — src1 valid region. + May be smaller than dst valid region for partial mul cases. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + - eps: tolerance for numpy.allclose (atol and rtol). + +tpartmul semantics: + - If src0_valid == dst_valid: dst[:src1_rows,:src1_cols] = src0[:src1_rows,:src1_cols] * src1[:src1_rows,:src1_cols] + dst[src1_rows:,:] = src0[src1_rows:,:] (copy remaining rows) + OR (for col_less) dst[:,:src1_cols] = src0[:,:src1_cols] * src1[:,:src1_cols] + dst[:,src1_cols:] = src0[:,src1_cols:] (copy remaining cols) + - If src1_valid == dst_valid: similar logic with src1 as the full operand. + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # float32 cases + { + "name": "f32_64x64_full", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region + "src1_vshape": (64, 64), # src1 valid region (same as dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src0_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 8), # src0 valid region (col_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_row_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (8, 64), # src1 valid region (row_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + { + "name": "f32_64x64_src1_col_less", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), # src0 valid region (equals dst) + "src1_vshape": (64, 8), # src1 valid region (col_less) + "dst_vshape": (64, 64), # dst valid region + "eps": 1e-6, + }, + # float16 cases + { + "name": "f16_8x48_src0_col_less", + "dtype": np.float16, + "shape": (8, 48), + "valid_shape": (8, 16), # src0 valid region (col_less) + "src1_vshape": (8, 48), # src1 valid region (equals dst) + "dst_vshape": (8, 48), # dst valid region + "eps": 1e-3, + }, + { + "name": "f16_8x768_src0_col_less", + "dtype": np.float16, + "shape": (8, 768), + "valid_shape": (8, 512), # src0 valid region (col_less) + "src1_vshape": (8, 768), # src1 valid region (equals dst) + "dst_vshape": (8, 768), # dst valid region + "eps": 1e-3, + }, + # int16 cases + { + "name": "i16_8x48_src1_col_less", + "dtype": np.int16, + "shape": (8, 48), + "valid_shape": (8, 48), # src0 valid region (equals dst) + "src1_vshape": (8, 16), # src1 valid region (col_less) + "dst_vshape": (8, 48), # dst valid region + "eps": 0, # exact match for int + }, + # int32 cases + { + "name": "i32_64x64_src0_row_less", + "dtype": np.int32, + "shape": (64, 64), + "valid_shape": (8, 64), # src0 valid region (row_less) + "src1_vshape": (64, 64), # src1 valid region (equals dst) + "dst_vshape": (64, 64), # dst valid region + "eps": 0, # exact match for int + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/compare.py new file mode 100644 index 000000000..283ee788a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + dtype = case["dtype"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/gen_data.py new file mode 100644 index 000000000..5ca965d0e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/gen_data.py @@ -0,0 +1,96 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = _to_tuple(case["shape"]) + src0_valid = _to_tuple(case["valid_shape"]) + src1_valid = _to_tuple(case["src1_vshape"]) + dst_valid = _to_tuple(case["dst_vshape"]) + + rows, cols = shape + src0_vr, src0_vc = src0_valid + src1_vr, src1_vc = src1_valid + dst_vr, dst_vc = dst_valid + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + + # Compute golden according to tpartmul semantics from template: + # If src0_valid == dst_valid: use tpart_op with src0 as full operand + # - If src1 row_less: mul for src1 region, copy src0 for remaining rows + # - If src1 col_less: copy src0 full, then mul for overlapping region + # If src1_valid == dst_valid: use tpart_op with src1 as full operand (swap src0/src1) + + src0_eq_dst = (src0_vr == dst_vr and src0_vc == dst_vc) + src1_eq_dst = (src1_vr == dst_vr and src1_vc == dst_vc) + + if src0_eq_dst: + # src0 is the full operand matching dst + src1_row_lt_dst = (src1_vr < dst_vr and src1_vc == dst_vc) + src1_col_lt_dst = (src1_vr <= dst_vr and src1_vc < dst_vc) + + if src1_eq_dst: + # Full mul: dst[:] = src0[:] * src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] * input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src1_col_lt_dst: + # Col_less: first copy src0, then mul in overlapping region + golden[:dst_vr, :dst_vc] = input1[:dst_vr, :dst_vc].copy() + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] * input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + elif src1_row_lt_dst: + # Row_less: mul for src1 region, copy src0 for remaining rows + if src1_vc > 0: + golden[:src1_vr, :src1_vc] = (input1[:src1_vr, :src1_vc] * input2[:src1_vr, :src1_vc]).astype(dtype, copy=False) + golden[src1_vr:dst_vr, :dst_vc] = input1[src1_vr:dst_vr, :dst_vc].copy() + elif src1_eq_dst: + # src1 is the full operand matching dst, swap src0/src1 in the logic + src0_row_lt_dst = (src0_vr < dst_vr and src0_vc == dst_vc) + src0_col_lt_dst = (src0_vr <= dst_vr and src0_vc < dst_vc) + + if src0_eq_dst: + # Full mul: dst[:] = src0[:] * src1[:] + golden[:dst_vr, :dst_vc] = (input1[:dst_vr, :dst_vc] * input2[:dst_vr, :dst_vc]).astype(dtype, copy=False) + elif src0_col_lt_dst: + # Col_less: first copy src1, then mul in overlapping region + golden[:dst_vr, :dst_vc] = input2[:dst_vr, :dst_vc].copy() + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] * input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + elif src0_row_lt_dst: + # Row_less: mul for src0 region, copy src1 for remaining rows + if src0_vc > 0: + golden[:src0_vr, :src0_vc] = (input1[:src0_vr, :src0_vc] * input2[:src0_vr, :src0_vc]).astype(dtype, copy=False) + golden[src0_vr:dst_vr, :dst_vc] = input2[src0_vr:dst_vr, :dst_vc].copy() + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} src0_valid={src0_valid} src1_valid={src1_valid} dst_valid={dst_valid} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/launch.cpp new file mode 100644 index 000000000..fb00bb99f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/launch.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 64x64 full +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_full(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_full(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_full<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src0_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 2: f32 64x64 src0 col less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src0_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src0_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 3: f32 64x64 src1 row less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src1_row_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src1_row_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 4: f32 64x64 src1 col less +extern "C" __global__ AICORE void TPARTMUL_f32_64x64_src1_col_less(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTPARTMUL_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream) { + TPARTMUL_f32_64x64_src1_col_less<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 5: f16 8x48 src0 col less +extern "C" __global__ AICORE void TPARTMUL_f16_8x48_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMUL_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMUL_f16_8x48_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 6: f16 8x768 src0 col less +extern "C" __global__ AICORE void TPARTMUL_f16_8x768_src0_col_less(__gm__ uint16_t *a, __gm__ uint16_t *b, __gm__ uint16_t *c); + +void LaunchTPARTMUL_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream) { + TPARTMUL_f16_8x768_src0_col_less<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b, (__gm__ uint16_t *)c); +} + +// Case 7: i16 8x48 src1 col less +extern "C" __global__ AICORE void TPARTMUL_i16_8x48_src1_col_less(__gm__ int16_t *a, __gm__ int16_t *b, __gm__ int16_t *c); + +void LaunchTPARTMUL_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream) { + TPARTMUL_i16_8x48_src1_col_less<<<1, nullptr, stream>>>((__gm__ int16_t *)a, (__gm__ int16_t *)b, (__gm__ int16_t *)c); +} + +// Case 8: i32 64x64 src0 row less +extern "C" __global__ AICORE void TPARTMUL_i32_64x64_src0_row_less(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTPARTMUL_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TPARTMUL_i32_64x64_src0_row_less<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/main.cpp new file mode 100644 index 000000000..d281d8710 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/main.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tpartmul ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPARTMUL_f32_64x64_full(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src0_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src0_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src1_row_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f32_64x64_src1_col_less(float *a, float *b, float *c, void *stream); +void LaunchTPARTMUL_f16_8x48_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMUL_f16_8x768_src0_col_less(uint16_t *a, uint16_t *b, uint16_t *c, void *stream); +void LaunchTPARTMUL_i16_8x48_src1_col_less(int16_t *a, int16_t *b, int16_t *c, void *stream); +void LaunchTPARTMUL_i32_64x64_src0_row_less(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t src0ValidRows; // src0 effective rows + size_t src0ValidCols; // src0 effective cols + size_t src1ValidRows; // src1 effective rows + size_t src1ValidCols; // src1 effective cols + size_t dstValidRows; // dst effective rows + size_t dstValidCols; // dst effective cols + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_64x64_full", reinterpret_cast(LaunchTPARTMUL_f32_64x64_full), 64, 64, 64, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src0_col_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src0_col_less), 64, 64, 64, 8, 64, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_row_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src1_row_less), 64, 64, 64, 64, 8, 64, 64, 64, sizeof(float)}, + {"f32_64x64_src1_col_less", reinterpret_cast(LaunchTPARTMUL_f32_64x64_src1_col_less), 64, 64, 64, 64, 64, 8, 64, 64, sizeof(float)}, + {"f16_8x48_src0_col_less", reinterpret_cast(LaunchTPARTMUL_f16_8x48_src0_col_less), 8, 48, 8, 16, 8, 48, 8, 48, sizeof(uint16_t)}, + {"f16_8x768_src0_col_less", reinterpret_cast(LaunchTPARTMUL_f16_8x768_src0_col_less), 8,768, 8,512, 8,768, 8,768, sizeof(uint16_t)}, + {"i16_8x48_src1_col_less", reinterpret_cast(LaunchTPARTMUL_i16_8x48_src1_col_less), 8, 48, 8, 48, 8, 16, 8, 48, sizeof(int16_t)}, + {"i32_64x64_src0_row_less", reinterpret_cast(LaunchTPARTMUL_i32_64x64_src0_row_less), 64, 64, 8, 64, 64, 64, 64, 64, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, src0_valid=%zux%zu, src1_valid=%zux%zu, dst_valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.src0ValidRows, tc.src0ValidCols, + tc.src1ValidRows, tc.src1ValidCols, tc.dstValidRows, tc.dstValidCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tpartmul [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/tpartmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/tpartmul.pto new file mode 100644 index 000000000..d537030a5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tpartmul/tpartmul.pto @@ -0,0 +1,535 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tpartmul: partial elementwise mul with valid region handling. +// Multiple cases with different valid_shape combinations in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 64x64 full (src0/src1/dst all have same valid_shape 64x64) + func.func @TPARTMUL_f32_64x64_full(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 1: f32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTMUL_f32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 2: f32 64x64 src0 col less (src0 valid 64x8, src1/dst valid 64x64) + func.func @TPARTMUL_f32_64x64_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: partial valid region (64,8) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 3: f32 64x64 src1 row less (src0/dst valid 64x64, src1 valid 8x64) + func.func @TPARTMUL_f32_64x64_src1_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 4: f32 64x64 src1 col less (src0/dst valid 64x64, src1 valid 64x8) + func.func @TPARTMUL_f32_64x64_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + // src0: full valid region (64,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (64,8) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f16 8x48 src0 col less (src0 valid 8x16, src1/dst valid 8x48) + func.func @TPARTMUL_f16_8x48_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xf16> -> !pto.partition_tensor_view<1x1x1x8x48xf16> + + // src0: partial valid region (8,16) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,48) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xf16>) + return + } + + // Case 6: f16 8x768 src0 col less (src0 valid 8x512, src1/dst valid 8x768) + func.func @TPARTMUL_f16_8x768_src0_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c512 = arith.constant 512 : index + %c768 = arith.constant 768 : index + %c6144 = arith.constant 6144 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c768], + strides = [%c6144, %c6144, %c6144, %c768, %c1] + : !pto.tensor_view<1x1x1x8x768xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c768] + : !pto.tensor_view<1x1x1x8x768xf16> -> !pto.partition_tensor_view<1x1x1x8x768xf16> + + // src0: partial valid region (8,512) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (8,768) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,768) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x768xf16>) + return + } + + // Case 7: i16 8x48 src1 col less (src0/dst valid 8x48, src1 valid 8x16) + func.func @TPARTMUL_i16_8x48_src1_col_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c384 = arith.constant 384 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c8, %c48], + strides = [%c384, %c384, %c384, %c48, %c1] + : !pto.tensor_view<1x1x1x8x48xi16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c48] + : !pto.tensor_view<1x1x1x8x48xi16> -> !pto.partition_tensor_view<1x1x1x8x48xi16> + + // src0: full valid region (8,48) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: partial valid region (8,16) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (8,48) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x8x48xi16>) + return + } + + // Case 8: i32 64x64 src0 row less (src0 valid 8x64, src1/dst valid 64x64) + func.func @TPARTMUL_i32_64x64_src0_row_less(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + // src0: partial valid region (8,64) + %a = pto.alloc_tile + : !pto.tile_buf + // src1: full valid region (64,64) + %b = pto.alloc_tile + : !pto.tile_buf + // dst: full valid region (64,64) + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tpartmul ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tprelu/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/CMakeLists.txt new file mode 100644 index 000000000..6d73cccfa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tprelu) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tprelu/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/cases.py new file mode 100644 index 000000000..bafcc28ef --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/cases.py @@ -0,0 +1,83 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tprelu ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float16, np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f16_64x64", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-3, + }, + { + "name": "f16_63x63", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (63, 63), + "eps": 1e-3, + }, + { + "name": "f16_1x16384", + "dtype": np.float16, + "shape": (1, 16384), + "valid_shape": (1, 16384), + "eps": 1e-3, + }, + { + "name": "f16_2048x16", + "dtype": np.float16, + "shape": (2048, 16), + "valid_shape": (2048, 16), + "eps": 1e-3, + }, + { + "name": "f32_64x64", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "f32_63x63", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (63, 63), + "eps": 1e-6, + }, + { + "name": "f32_1x16384", + "dtype": np.float32, + "shape": (1, 16384), + "valid_shape": (1, 16384), + "eps": 1e-6, + }, + { + "name": "f32_2048x8", + "dtype": np.float32, + "shape": (2048, 8), + "valid_shape": (2048, 8), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tprelu/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tprelu/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/gen_data.py new file mode 100644 index 000000000..b8ec70593 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/gen_data.py @@ -0,0 +1,40 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + rows, cols = shape + vr, vc = valid_shape + + input0 = np.random.uniform(-8, high=8, size=(rows, cols)).astype(dtype) + input1 = np.random.uniform(-8, high=8, size=(rows, cols)).astype(dtype) + + golden = np.zeros((rows, cols), dtype=dtype) + for i in range(vr): + for j in range(vc): + if input0[i, j] > 0: + golden[i, j] = input0[i, j] + else: + golden[i, j] = dtype(input0[i, j] * input1[i, j]) + + save_case_data(case["name"], {"input0": input0, "input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tprelu/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/launch.cpp new file mode 100644 index 000000000..4e15c68ab --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/launch.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f16 64x64 +extern "C" __global__ AICORE void TPRELU_f16_64x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTPRELU_f16_64x64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TPRELU_f16_64x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 1: f16 63x63 +extern "C" __global__ AICORE void TPRELU_f16_63x63(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTPRELU_f16_63x63(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TPRELU_f16_63x63<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 2: f16 1x16384 +extern "C" __global__ AICORE void TPRELU_f16_1x16384(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTPRELU_f16_1x16384(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TPRELU_f16_1x16384<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 3: f16 2048x16 +extern "C" __global__ AICORE void TPRELU_f16_2048x16(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTPRELU_f16_2048x16(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TPRELU_f16_2048x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 4: f32 64x64 +extern "C" __global__ AICORE void TPRELU_f32_64x64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTPRELU_f32_64x64(float *src0, float *src1, float *dst, void *stream) { + TPRELU_f32_64x64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 5: f32 63x63 +extern "C" __global__ AICORE void TPRELU_f32_63x63(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTPRELU_f32_63x63(float *src0, float *src1, float *dst, void *stream) { + TPRELU_f32_63x63<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 6: f32 1x16384 +extern "C" __global__ AICORE void TPRELU_f32_1x16384(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTPRELU_f32_1x16384(float *src0, float *src1, float *dst, void *stream) { + TPRELU_f32_1x16384<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 7: f32 2048x8 +extern "C" __global__ AICORE void TPRELU_f32_2048x8(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTPRELU_f32_2048x8(float *src0, float *src1, float *dst, void *stream) { + TPRELU_f32_2048x8<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tprelu/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/main.cpp new file mode 100644 index 000000000..a55be7945 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/main.cpp @@ -0,0 +1,198 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tprelu ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTPRELU_f16_64x64(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTPRELU_f16_63x63(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTPRELU_f16_1x16384(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTPRELU_f16_2048x16(uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTPRELU_f32_64x64(float *src0, float *src1, float *dst, void *stream); +void LaunchTPRELU_f32_63x63(float *src0, float *src1, float *dst, void *stream); +void LaunchTPRELU_f32_1x16384(float *src0, float *src1, float *dst, void *stream); +void LaunchTPRELU_f32_2048x8(float *src0, float *src1, float *dst, void *stream); + +enum DataType { F16, F32 }; + +struct TestCase { + const char *name; + DataType dtype; + void * launch; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; +}; + +static const TestCase kCases[] = { + {"f16_64x64", F16, (void*)LaunchTPRELU_f16_64x64, 64, 64, 64, 64}, + {"f16_63x63", F16, (void*)LaunchTPRELU_f16_63x63, 64, 64, 63, 63}, + {"f16_1x16384", F16, (void*)LaunchTPRELU_f16_1x16384, 1, 16384, 1, 16384}, + {"f16_2048x16", F16, (void*)LaunchTPRELU_f16_2048x16, 2048, 16, 2048, 16}, + {"f32_64x64", F32, (void*)LaunchTPRELU_f32_64x64, 64, 64, 64, 64}, + {"f32_63x63", F32, (void*)LaunchTPRELU_f32_63x63, 64, 64, 63, 63}, + {"f32_1x16384", F32, (void*)LaunchTPRELU_f32_1x16384, 1, 16384, 1, 16384}, + {"f32_2048x8", F32, (void*)LaunchTPRELU_f32_2048x8, 2048, 8, 2048, 8}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +template +using LaunchFn = void (*)(T *, T *, T *, void *); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t elemSize = (tc.dtype == F16) ? sizeof(uint16_t) : sizeof(float); + size_t fileSize = elemCount * elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu, dtype=%s) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols, + (tc.dtype == F16) ? "f16" : "f32"); + + std::string caseDir = std::string("./") + tc.name; + + if (tc.dtype == F16) { + uint16_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + uint16_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), fileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + LaunchFn launch = (LaunchFn)tc.launch; + launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) aclrtFree(src0Device); + if (src1Device != nullptr) aclrtFree(src1Device); + if (dstDevice != nullptr) aclrtFree(dstDevice); + if (src0Host != nullptr) aclrtFreeHost(src0Host); + if (src1Host != nullptr) aclrtFreeHost(src1Host); + if (dstHost != nullptr) aclrtFreeHost(dstHost); + } else { + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input0.bin").c_str(), fileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input0.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + LaunchFn launch = (LaunchFn)tc.launch; + launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) aclrtFree(src0Device); + if (src1Device != nullptr) aclrtFree(src1Device); + if (dstDevice != nullptr) aclrtFree(dstDevice); + if (src0Host != nullptr) aclrtFreeHost(src0Host); + if (src1Host != nullptr) aclrtFreeHost(src1Host); + if (dstHost != nullptr) aclrtFreeHost(dstHost); + } + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tprelu/tprelu.pto b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/tprelu.pto new file mode 100644 index 000000000..d72ed41a4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tprelu/tprelu.pto @@ -0,0 +1,494 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tprelu: tload(src0) + tload(src1) + tprelu(src0,src1,tmp)->dst + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f16 64x64 (4096 elements) + func.func @TPRELU_f16_64x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } + + // Case 1: f16 63x63 (partial valid_shape) + func.func @TPRELU_f16_63x63(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x63x63xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x63x63xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x63x63xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x63x63xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x63x63xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x63xf16>) + return + } + + // Case 2: f16 1x16384 + func.func @TPRELU_f16_1x16384(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16384 = arith.constant 16384 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c16384], + strides = [%c16384, %c16384, %c16384, %c16384, %c1] + : !pto.tensor_view<1x1x1x1x16384xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c16384], + strides = [%c16384, %c16384, %c16384, %c16384, %c1] + : !pto.tensor_view<1x1x1x1x16384xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c16384], + strides = [%c16384, %c16384, %c16384, %c16384, %c1] + : !pto.tensor_view<1x1x1x1x16384xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16384] + : !pto.tensor_view<1x1x1x1x16384xf16> -> !pto.partition_tensor_view<1x1x1x1x16384xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16384] + : !pto.tensor_view<1x1x1x1x16384xf16> -> !pto.partition_tensor_view<1x1x1x1x16384xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16384] + : !pto.tensor_view<1x1x1x1x16384xf16> -> !pto.partition_tensor_view<1x1x1x1x16384xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x16384xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x16384xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x16384xf16>) + return + } + + // Case 3: f16 2048x16 + func.func @TPRELU_f16_2048x16(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c2048 = arith.constant 2048 : index + %c32768 = arith.constant 32768 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2048, %c16], + strides = [%c32768, %c32768, %c32768, %c16, %c1] + : !pto.tensor_view<1x1x1x2048x16xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2048, %c16], + strides = [%c32768, %c32768, %c32768, %c16, %c1] + : !pto.tensor_view<1x1x1x2048x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2048, %c16], + strides = [%c32768, %c32768, %c32768, %c16, %c1] + : !pto.tensor_view<1x1x1x2048x16xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2048, %c16] + : !pto.tensor_view<1x1x1x2048x16xf16> -> !pto.partition_tensor_view<1x1x1x2048x16xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2048, %c16] + : !pto.tensor_view<1x1x1x2048x16xf16> -> !pto.partition_tensor_view<1x1x1x2048x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2048, %c16] + : !pto.tensor_view<1x1x1x2048x16xf16> -> !pto.partition_tensor_view<1x1x1x2048x16xf16> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2048x16xf16>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2048x16xf16>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2048x16xf16>) + return + } + + // Case 4: f32 64x64 + func.func @TPRELU_f32_64x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f32 63x63 (partial valid_shape) + func.func @TPRELU_f32_63x63(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x63x63xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x63x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c63] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x63x63xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x63x63xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x63x63xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x63xf32>) + return + } + + // Case 6: f32 1x16384 + func.func @TPRELU_f32_1x16384(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16384 = arith.constant 16384 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c1, %c16384], + strides = [%c16384, %c16384, %c16384, %c16384, %c1] + : !pto.tensor_view<1x1x1x1x16384xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c1, %c16384], + strides = [%c16384, %c16384, %c16384, %c16384, %c1] + : !pto.tensor_view<1x1x1x1x16384xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c16384], + strides = [%c16384, %c16384, %c16384, %c16384, %c1] + : !pto.tensor_view<1x1x1x1x16384xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16384] + : !pto.tensor_view<1x1x1x1x16384xf32> -> !pto.partition_tensor_view<1x1x1x1x16384xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16384] + : !pto.tensor_view<1x1x1x1x16384xf32> -> !pto.partition_tensor_view<1x1x1x1x16384xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c16384] + : !pto.tensor_view<1x1x1x1x16384xf32> -> !pto.partition_tensor_view<1x1x1x1x16384xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x1x16384xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x1x16384xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x16384xf32>) + return + } + + // Case 7: f32 2048x8 + func.func @TPRELU_f32_2048x8(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c2048 = arith.constant 2048 : index + %c16384 = arith.constant 16384 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2048, %c8], + strides = [%c16384, %c16384, %c16384, %c8, %c1] + : !pto.tensor_view<1x1x1x2048x8xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2048, %c8], + strides = [%c16384, %c16384, %c16384, %c8, %c1] + : !pto.tensor_view<1x1x1x2048x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2048, %c8], + strides = [%c16384, %c16384, %c16384, %c8, %c1] + : !pto.tensor_view<1x1x1x2048x8xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2048, %c8] + : !pto.tensor_view<1x1x1x2048x8xf32> -> !pto.partition_tensor_view<1x1x1x2048x8xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2048, %c8] + : !pto.tensor_view<1x1x1x2048x8xf32> -> !pto.partition_tensor_view<1x1x1x2048x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2048, %c8] + : !pto.tensor_view<1x1x1x2048x8xf32> -> !pto.partition_tensor_view<1x1x1x2048x8xf32> + + %src0_tile = pto.alloc_tile + : !pto.tile_buf + %src1_tile = pto.alloc_tile + : !pto.tile_buf + %tmp_tile = pto.alloc_tile + : !pto.tile_buf + %dst_tile = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2048x8xf32>) + outs(%src0_tile : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2048x8xf32>) + outs(%src1_tile : !pto.tile_buf) + + pto.tprelu ins(%src0_tile, %src1_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2048x8xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trandom/CMakeLists.txt new file mode 100644 index 000000000..5d7e644a2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trandom) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trandom/cases.py new file mode 100644 index 000000000..d3cf0b681 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/cases.py @@ -0,0 +1,36 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trandom ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (np.int32 or np.uint32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - rounds: Philox rounds (7 or 10). + - eps: tolerance for comparison (0 for exact match). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "int32_4x256", + "dtype": np.int32, + "shape": (4, 256), + "valid_shape": (4, 256), + "rounds": 10, + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trandom/compare.py new file mode 100644 index 000000000..89acf6221 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/compare.py @@ -0,0 +1,82 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden reference with NPU output for trandom test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases +from gen_data import trandom_generate + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape + dtype = case["dtype"] + eps = case["eps"] + + golden_file = os.path.join(case_dir, "golden.bin") + output_file = os.path.join(case_dir, "output.bin") + key_file = os.path.join(case_dir, "key.bin") + counter_file = os.path.join(case_dir, "counter.bin") + + if not os.path.exists(golden_file): + if os.path.exists(key_file) and os.path.exists(counter_file) and os.path.exists(output_file): + key = np.fromfile(key_file, dtype=dtype) + counter = np.fromfile(counter_file, dtype=dtype) + rounds = case.get("rounds", 10) + golden = trandom_generate(key.view(np.uint32), counter.view(np.uint32), + vr, vc, dtype=dtype, rounds=rounds) + golden.astype(dtype).tofile(golden_file) + print(f"[INFO] {case['name']}: generated golden.bin") + else: + print(style_fail(f"[ERROR] {case['name']}: golden.bin not found and cannot generate")) + all_passed = False + continue + + if not os.path.exists(output_file): + print(style_fail(f"[ERROR] {case['name']}: output.bin not found")) + all_passed = False + continue + + golden = np.fromfile(golden_file, dtype=dtype).reshape(shape) + output = np.fromfile(output_file, dtype=dtype).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], eps) + if ok: + unique_count = len(np.unique(output[:vr, :vc])) + total_count = vr * vc + print(style_pass(f"[INFO] {case['name']}: compare passed " + f"(unique={unique_count}/{total_count})")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/debug_trandom.py b/test/tilelang_st/npu/a5/src/st/testcase/trandom/debug_trandom.py new file mode 100644 index 000000000..ec063b076 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/debug_trandom.py @@ -0,0 +1,112 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Debug script to trace trandom computation step by step.""" + +import numpy as np + +TRANDOM_CONST_0 = 0xD2511F53 +TRANDOM_CONST_1 = 0xCD9E8D57 +TRANDOM_CONST_KEY_ADD_0 = 0x9E3779B9 +TRANDOM_CONST_KEY_ADD_1 = 0xBB67AE85 + +def add_with_128bits_debug(ctr0, ctr1, ctr2, ctr3, value): + """Simulate 128-bit addition with carry propagation.""" + ctr0_new = ctr0.astype(np.uint64) + value.astype(np.uint64) + carry0 = (ctr0_new > 0xFFFFFFFF).astype(np.uint32) + ctr0_new = ctr0_new.astype(np.uint32) + + ctr1_new = ctr1.astype(np.uint64) + carry0.astype(np.uint64) + carry1 = (ctr1_new > 0xFFFFFFFF).astype(np.uint32) + ctr1_new = ctr1_new.astype(np.uint32) + + ctr2_new = ctr2.astype(np.uint64) + carry1.astype(np.uint64) + carry2 = (ctr2_new > 0xFFFFFFFF).astype(np.uint32) + ctr2_new = ctr2_new.astype(np.uint32) + + ctr3_new = ctr3.astype(np.uint64) + carry2.astype(np.uint64) + ctr3_new = ctr3_new.astype(np.uint32) + + return ctr0_new, ctr1_new, ctr2_new, ctr3_new + +def trandom_kernel_debug(ctr0, ctr1, ctr2, ctr3, key0_val, key1_val, rounds=10): + """Philox kernel with detailed logging.""" + lanes = len(ctr0) + key0 = np.full(lanes, np.uint32(key0_val), dtype=np.uint32) + key1 = np.full(lanes, np.uint32(key1_val), dtype=np.uint32) + + print(f"Initial counters: ctr0[0:5]={ctr0[0:5]}, ctr1[0:5]={ctr1[0:5]}") + print(f"Initial keys: key0={key0[0]}, key1={key1[0]}") + + for round_idx in range(rounds): + print(f"\n=== Round {round_idx} ===") + print(f"Before: ctr0[0]={ctr0[0]}, ctr1[0]={ctr1[0]}, ctr2[0]={ctr2[0]}, ctr3[0]={ctr3[0]}") + print(f"Before: key0={key0[0]}, key1={key1[0]}") + + prod0 = ctr0.astype(np.uint64) * np.uint64(TRANDOM_CONST_0) + prod1 = ctr2.astype(np.uint64) * np.uint64(TRANDOM_CONST_1) + + L0 = prod0.astype(np.uint32) + H0 = (prod0 >> 32).astype(np.uint32) + L1 = prod1.astype(np.uint32) + H1 = (prod1 >> 32).astype(np.uint32) + + print(f"prod0[0]={prod0[0]}, L0[0]={L0[0]}, H0[0]={H0[0]}") + print(f"prod1[0]={prod1[0]}, L1[0]={L1[0]}, H1[0]={H1[0]}") + + ctr0 = (H1 ^ ctr1) ^ key0 + ctr2 = (H0 ^ ctr3) ^ key1 + + print(f"ctr0[0] = (H1[0] ^ ctr1[0]) ^ key0[0] = ({H1[0]} ^ {ctr1[0]}) ^ {key0[0]} = {ctr0[0]}") + print(f"ctr2[0] = (H0[0] ^ ctr3[0]) ^ key1[0] = ({H0[0]} ^ {ctr3[0]}) ^ {key1[0]} = {ctr2[0]}") + + key0 = (key0.astype(np.uint32) + np.uint32(TRANDOM_CONST_KEY_ADD_0)) & np.uint32(0xFFFFFFFF) + key1 = (key1.astype(np.uint32) + np.uint32(TRANDOM_CONST_KEY_ADD_1)) & np.uint32(0xFFFFFFFF) + + print(f"key0={key0[0]}, key1={key1[0]} (after update)") + + ctr1 = L1 + ctr3 = L0 + + print(f"After: ctr0[0]={ctr0[0]}, ctr1[0]={ctr1[0]}, ctr2[0]={ctr2[0]}, ctr3[0]={ctr3[0]}") + + return ctr0, ctr1, ctr2, ctr3 + +key = np.array([-792737938, 2139558336], dtype=np.int32) +counter = np.array([-1759534764, -1881674653, 640338625, 1381573024], dtype=np.int32) + +key_uint = key.view(np.uint32) +counter_uint = counter.view(np.uint32) + +lanes = 64 +ctr0 = np.full(lanes, counter_uint[0], dtype=np.uint32) +ctr1 = np.full(lanes, counter_uint[1], dtype=np.uint32) +ctr2 = np.full(lanes, counter_uint[2], dtype=np.uint32) +ctr3 = np.full(lanes, counter_uint[3], dtype=np.uint32) + +print("=== Initial counter values ===") +print(f"ctr0[0]={ctr0[0]}, ctr1[0]={ctr1[0]}, ctr2[0]={ctr2[0]}, ctr3[0]={ctr3[0]}") + +inc_idx = np.arange(lanes, dtype=np.uint32) +ctr0, ctr1, ctr2, ctr3 = add_with_128bits_debug(ctr0, ctr1, ctr2, ctr3, inc_idx) + +print("\n=== After adding index ===") +print(f"ctr0[0:5]={ctr0[0:5]}") +print(f"ctr1[0:5]={ctr1[0:5]}") +print(f"ctr2[0:5]={ctr2[0:5]}") +print(f"ctr3[0:5]={ctr3[0:5]}") + +result = trandom_kernel_debug(ctr0.copy(), ctr1.copy(), ctr2.copy(), ctr3.copy(), + key_uint[0], key_uint[1], rounds=10) + +print("\n=== Final result ===") +print(f"ctr0[0:5]={result[0][0:5]}") +print(f"ctr1[0:5]={result[1][0:5]}") +print(f"ctr2[0:5]={result[2][0:5]}") +print(f"ctr3[0:5]={result[3][0:5]}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trandom/gen_data.py new file mode 100644 index 000000000..727bf8f34 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/gen_data.py @@ -0,0 +1,235 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input data and golden output for trandom test cases. + +Implements the Philox-based TRandom algorithm in pure Python/NumPy +to generate reference golden data for comparison with NPU output. + +Flow: + - First run (no output.bin): generate key/counter inputs only + - Second run (with output.bin): read saved key/counter, compute golden +""" + +import os +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +TRANDOM_ONCE_REPEAT = 4 +TRANDOM_CONST_0 = 0xD2511F53 +TRANDOM_CONST_1 = 0xCD9E8D57 +TRANDOM_CONST_KEY_ADD_0 = 0x9E3779B9 +TRANDOM_CONST_KEY_ADD_1 = 0xBB67AE85 + + +def add_with_128bits(ctr0, ctr1, ctr2, ctr3, value): + """Simulate 128-bit addition with carry propagation.""" + ctr0_new = ctr0.astype(np.uint64) + value.astype(np.uint64) + carry0 = (ctr0_new > 0xFFFFFFFF).astype(np.uint32) + ctr0_new = ctr0_new.astype(np.uint32) + + ctr1_new = ctr1.astype(np.uint64) + carry0.astype(np.uint64) + carry1 = (ctr1_new > 0xFFFFFFFF).astype(np.uint32) + ctr1_new = ctr1_new.astype(np.uint32) + + ctr2_new = ctr2.astype(np.uint64) + carry1.astype(np.uint64) + carry2 = (ctr2_new > 0xFFFFFFFF).astype(np.uint32) + ctr2_new = ctr2_new.astype(np.uint32) + + ctr3_new = ctr3.astype(np.uint64) + carry2.astype(np.uint64) + ctr3_new = ctr3_new.astype(np.uint32) + + return ctr0_new, ctr1_new, ctr2_new, ctr3_new + + +def trandom_kernel(ctr0, ctr1, ctr2, ctr3, key0_val, key1_val, rounds=10): + """Philox-based random number generation kernel. + + Uses unsigned multiply to match C++ TRandomKernel (RegTensor, vmull.v64u32). + """ + key0 = np.full(len(ctr0), key0_val, dtype=np.uint32) + key1 = np.full(len(ctr0), key1_val, dtype=np.uint32) + + for _ in range(rounds): + prod0 = ctr0.astype(np.uint64) * np.uint64(TRANDOM_CONST_0) + prod1 = ctr2.astype(np.uint64) * np.uint64(TRANDOM_CONST_1) + + L0 = prod0.astype(np.uint32) + H0 = (prod0 >> 32).astype(np.uint32) + L1 = prod1.astype(np.uint32) + H1 = (prod1 >> 32).astype(np.uint32) + + ctr0 = (H1 ^ ctr1) ^ key0 + ctr2 = (H0 ^ ctr3) ^ key1 + ctr1 = L1 + ctr3 = L0 + + key0 = (key0 + TRANDOM_CONST_KEY_ADD_0) & np.uint32(0xFFFFFFFF) + key1 = (key1 + TRANDOM_CONST_KEY_ADD_1) & np.uint32(0xFFFFFFFF) + + return ctr0, ctr1, ctr2, ctr3 + + +def interleave_values(ctr0, ctr1, ctr2, ctr3): + """Simulate vintlv: interleave values to reorder random numbers. + + vintlv semantics (N=64, half=32): + - low[2*i] = src0[i], low[2*i+1] = src1[i] for i in 0..31 (interleave first half) + - high[2*i] = src0[i+32], high[2*i+1] = src1[i+32] for i in 0..31 (interleave second half) + + TRandom uses: + 1. vintlv(tmpL0, tmpH0, ctr0, ctr2) + 2. vintlv(tmpL1, tmpH1, ctr1, ctr3) + 3. vintlv(ctr0, ctr1, tmpL0, tmpL1) + 4. vintlv(ctr2, ctr3, tmpH0, tmpH1) + """ + n = len(ctr0) + half = n // 2 + + tmpL0 = np.empty(n, dtype=np.uint32) + tmpH0 = np.empty(n, dtype=np.uint32) + tmpL1 = np.empty(n, dtype=np.uint32) + tmpH1 = np.empty(n, dtype=np.uint32) + + for i in range(half): + tmpL0[2*i] = ctr0[i] + tmpL0[2*i+1] = ctr2[i] + tmpH0[2*i] = ctr0[i + half] + tmpH0[2*i+1] = ctr2[i + half] + + tmpL1[2*i] = ctr1[i] + tmpL1[2*i+1] = ctr3[i] + tmpH1[2*i] = ctr1[i + half] + tmpH1[2*i+1] = ctr3[i + half] + + result0 = np.empty(n, dtype=np.uint32) + result1 = np.empty(n, dtype=np.uint32) + result2 = np.empty(n, dtype=np.uint32) + result3 = np.empty(n, dtype=np.uint32) + + for i in range(half): + result0[2*i] = tmpL0[i] + result0[2*i+1] = tmpL1[i] + result1[2*i] = tmpL0[i + half] + result1[2*i+1] = tmpL1[i + half] + + result2[2*i] = tmpH0[i] + result2[2*i+1] = tmpH1[i] + result3[2*i] = tmpH0[i + half] + result3[2*i+1] = tmpH1[i + half] + + return result0, result1, result2, result3 + + +def trandom_generate(key, counter, valid_rows, valid_cols, dtype=np.int32, rounds=10): + """Generate random numbers using TRandom algorithm. + + Args: + key: 2-element array (key0, key1) - scalar values, broadcast to all lanes + counter: 4-element array (counter0-3) - 128-bit counter base value + valid_rows: number of rows to generate + valid_cols: number of columns to generate + dtype: output dtype (int32 or uint32) + rounds: number of Philox rounds (7 or 10) + + Returns: + output: (valid_rows, valid_cols) array of random numbers + """ + lanes = 64 + n_loop = (valid_cols + TRANDOM_ONCE_REPEAT * lanes - 1) // (TRANDOM_ONCE_REPEAT * lanes) + + output = np.zeros((valid_rows, valid_cols), dtype=np.uint32) + + key0_val = np.uint32(key[0]) + key1_val = np.uint32(key[1]) + + ctr0 = np.full(lanes, np.uint32(counter[0]), dtype=np.uint32) + ctr1 = np.full(lanes, np.uint32(counter[1]), dtype=np.uint32) + ctr2 = np.full(lanes, np.uint32(counter[2]), dtype=np.uint32) + ctr3 = np.full(lanes, np.uint32(counter[3]), dtype=np.uint32) + + inc_idx = np.arange(lanes, dtype=np.uint32) + ctr0, ctr1, ctr2, ctr3 = add_with_128bits(ctr0, ctr1, ctr2, ctr3, inc_idx) + + for i in range(valid_rows): + s_reg = valid_cols + counter_add_val = lanes + + for j in range(n_loop): + tmp_ctr0 = ctr0.copy() + tmp_ctr1 = ctr1.copy() + tmp_ctr2 = ctr2.copy() + tmp_ctr3 = ctr3.copy() + + tmp_ctr0, tmp_ctr1, tmp_ctr2, tmp_ctr3 = trandom_kernel( + tmp_ctr0, tmp_ctr1, tmp_ctr2, tmp_ctr3, key0_val, key1_val, rounds=rounds + ) + + # Apply interleave to match vintlv semantics in trandom_template.py + # This produces element-wise interleaved order: [ctr0[0], ctr1[0], ctr2[0], ctr3[0], ...] + tmp_ctr0, tmp_ctr1, tmp_ctr2, tmp_ctr3 = interleave_values( + tmp_ctr0, tmp_ctr1, tmp_ctr2, tmp_ctr3 + ) + + for k in range(TRANDOM_ONCE_REPEAT): + start_col = TRANDOM_ONCE_REPEAT * j * lanes + k * lanes + end_col = min(start_col + lanes, valid_cols) + num_valid = end_col - start_col + + if num_valid > 0: + vals = [tmp_ctr0, tmp_ctr1, tmp_ctr2, tmp_ctr3][k] + output[i, start_col:end_col] = vals[:num_valid] + + if s_reg >= TRANDOM_ONCE_REPEAT * lanes: + s_reg = s_reg - TRANDOM_ONCE_REPEAT * lanes + else: + s_reg = 0 + + counter_add_val = lanes if j != n_loop - 1 else ((valid_cols - 1) % lanes + 1) + v_ele_stride = np.full(lanes, np.uint32(counter_add_val), dtype=np.uint32) + ctr0, ctr1, ctr2, ctr3 = add_with_128bits(ctr0, ctr1, ctr2, ctr3, v_ele_stride) + + return output.view(dtype) + + +validate_cases(CASES) + +for case in CASES: + case_dir = case["name"] + key_file = os.path.join(case_dir, "key.bin") + counter_file = os.path.join(case_dir, "counter.bin") + output_file = os.path.join(case_dir, "output.bin") + + dtype = case["dtype"] + valid_rows, valid_cols = case["valid_shape"] + rounds = case.get("rounds", 10) + + if os.path.exists(key_file) and os.path.exists(counter_file): + key = np.fromfile(key_file, dtype=dtype) + counter = np.fromfile(counter_file, dtype=dtype) + print(f"[INFO] gen_data: {case['name']} loaded existing key/counter") + else: + setup_case_rng(case) + value_max = np.iinfo(dtype).max + value_min = np.iinfo(dtype).min + key = np.random.randint(value_min, value_max + 1, size=2, dtype=dtype) + counter = np.random.randint(value_min, value_max + 1, size=4, dtype=dtype) + print(f"[INFO] gen_data: {case['name']} generated new key={key.tolist()} counter={counter.tolist()}") + + if os.path.exists(output_file): + golden = trandom_generate(key.view(np.uint32), counter.view(np.uint32), + valid_rows, valid_cols, dtype=dtype, rounds=rounds) + save_case_data(case["name"], {"key": key, "counter": counter, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} generated golden shape={case['shape']}") + else: + save_case_data(case["name"], {"key": key, "counter": counter}) + print(f"[INFO] gen_data: {case['name']} saved inputs (waiting for output)") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trandom/launch.cpp new file mode 100644 index 000000000..2b5cff3ed --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/launch.cpp @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: ui32 4x256 +extern "C" __global__ AICORE void TRANDOM_int32_4x256(__gm__ uint32_t *key, __gm__ uint32_t *counter, __gm__ uint32_t *output); + +void LaunchTRANDOM_int32_4x256(uint32_t *key, uint32_t *counter, uint32_t *output, void *stream) { + TRANDOM_int32_4x256<<<1, nullptr, stream>>>((__gm__ uint32_t *)key, (__gm__ uint32_t *)counter, (__gm__ uint32_t *)output); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trandom/main.cpp new file mode 100644 index 000000000..9197e55f7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trandom ST — case-table driven. +// Each case launches a different kernel variant, reads key/counter and writes output. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTRANDOM_int32_4x256(uint32_t *key, uint32_t *counter, uint32_t *output, void *stream); + +struct TestCase { + const char *name; + void (*launch)(uint32_t *, uint32_t *, uint32_t *, void *); + size_t rows; + size_t cols; +}; + +static const TestCase kCases[] = { + {"int32_4x256", LaunchTRANDOM_int32_4x256, 4, 256}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t outputSize = elemCount * sizeof(uint32_t); + size_t keySize = 2 * sizeof(uint32_t); + size_t counterSize = 4 * sizeof(uint32_t); + + std::printf("[INFO] === case: %s (shape=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols); + + std::string caseDir = std::string("./") + tc.name; + + void *keyHost = nullptr, *counterHost = nullptr, *outputHost = nullptr; + void *keyDevice = nullptr, *counterDevice = nullptr, *outputDevice = nullptr; + + aclrtMallocHost(&keyHost, keySize); + aclrtMallocHost(&counterHost, counterSize); + aclrtMallocHost(&outputHost, outputSize); + + aclrtMalloc(&keyDevice, keySize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&counterDevice, counterSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&outputDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/key.bin").c_str(), keySize, keyHost, keySize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/key.bin\n", caseDir.c_str()); + rc = 1; + } + + if (!ReadFile((caseDir + "/counter.bin").c_str(), counterSize, counterHost, counterSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/counter.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(keyDevice, keySize, keyHost, keySize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(counterDevice, counterSize, counterHost, counterSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch((uint32_t *)keyDevice, (uint32_t *)counterDevice, (uint32_t *)outputDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(outputHost, outputSize, outputDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), outputHost, outputSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (keyDevice != nullptr) + aclrtFree(keyDevice); + if (counterDevice != nullptr) + aclrtFree(counterDevice); + if (outputDevice != nullptr) + aclrtFree(outputDevice); + if (keyHost != nullptr) + aclrtFreeHost(keyHost); + if (counterHost != nullptr) + aclrtFreeHost(counterHost); + if (outputHost != nullptr) + aclrtFreeHost(outputHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trandom/trandom.pto b/test/tilelang_st/npu/a5/src/st/testcase/trandom/trandom.pto new file mode 100644 index 000000000..85d8355bf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trandom/trandom.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trandom: generate random numbers using key and counter. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: ui32 4x256 (1024 elements, valid=4x256) + // Key and counter passed as ui32 arrays, converted to i32 for pto.trandom (which requires signless) + func.func @TRANDOM_int32_4x256(%key_ptr: !pto.ptr, %counter_ptr: !pto.ptr, %output_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + // Load key and counter values as ui32 + %key0_ui32 = pto.load_scalar %key_ptr[%c0] : !pto.ptr -> ui32 + %key1_ui32 = pto.load_scalar %key_ptr[%c1] : !pto.ptr -> ui32 + %counter0_ui32 = pto.load_scalar %counter_ptr[%c0] : !pto.ptr -> ui32 + %counter1_ui32 = pto.load_scalar %counter_ptr[%c1] : !pto.ptr -> ui32 + %counter2_ui32 = pto.load_scalar %counter_ptr[%c2] : !pto.ptr -> ui32 + %counter3_ui32 = pto.load_scalar %counter_ptr[%c3] : !pto.ptr -> ui32 + + // Convert ui32 to i32 (signless) before passing to pto.trandom + %key0 = builtin.unrealized_conversion_cast %key0_ui32 : ui32 to i32 + %key1 = builtin.unrealized_conversion_cast %key1_ui32 : ui32 to i32 + %counter0 = builtin.unrealized_conversion_cast %counter0_ui32 : ui32 to i32 + %counter1 = builtin.unrealized_conversion_cast %counter1_ui32 : ui32 to i32 + %counter2 = builtin.unrealized_conversion_cast %counter2_ui32 : ui32 to i32 + %counter3 = builtin.unrealized_conversion_cast %counter3_ui32 : ui32 to i32 + + %output_view = pto.make_tensor_view %output_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xui32> + + %output_part = pto.partition_view %output_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xui32> -> !pto.partition_tensor_view<1x1x1x4x256xui32> + + %dst = pto.alloc_tile + : !pto.tile_buf + + // Input 6 scalars are i32 (signless), output tile is ui32 + pto.trandom ins(%key0, %key1, %counter0, %counter1, %counter2, %counter3 : i32, i32, i32, i32, i32, i32) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%output_part : !pto.partition_tensor_view<1x1x1x4x256xui32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trecip/CMakeLists.txt new file mode 100644 index 000000000..9ec69bc60 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trecip) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trecip/cases.py new file mode 100644 index 000000000..b1c2012e2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/cases.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trecip ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, + { + "name": "f32_64x64_pad", + "dtype": np.float32, + "shape": (66, 72), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "f32_58x70", + "dtype": np.float32, + "shape": (66, 72), + "valid_shape": (58, 70), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trecip/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trecip/gen_data.py new file mode 100644 index 000000000..81e052958 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/gen_data.py @@ -0,0 +1,31 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Avoid 0 for reciprocal, use range [0.1, 10.0] + input = np.random.uniform(0.1, 10.0, size=shape).astype(dtype) + + # reciprocal = 1/x + golden = np.reciprocal(input).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trecip/launch.cpp new file mode 100644 index 000000000..3cf95d119 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TRECIP_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_16x64(void *a, void *b, void *stream) { + TRECIP_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TRECIP_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_32x32(void *a, void *b, void *stream) { + TRECIP_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TRECIP_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRECIP_f16_16x64(void *a, void *b, void *stream) { + TRECIP_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TRECIP_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRECIP_f16_32x32(void *a, void *b, void *stream) { + TRECIP_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: f32 66x72, valid 64x64 (pad) +extern "C" __global__ AICORE void TRECIP_f32_64x64_pad(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_64x64_pad(void *a, void *b, void *stream) { + TRECIP_f32_64x64_pad<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 5: f32 66x72, valid 58x70 (non-square valid) +extern "C" __global__ AICORE void TRECIP_f32_58x70(__gm__ float *a, __gm__ float *b); + +void LaunchTRECIP_f32_58x70(void *a, void *b, void *stream) { + TRECIP_f32_58x70<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trecip/main.cpp new file mode 100644 index 000000000..4a8400e6a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trecip ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTRECIP_f32_16x64(void *a, void *b, void *stream); +void LaunchTRECIP_f32_32x32(void *a, void *b, void *stream); +void LaunchTRECIP_f16_16x64(void *a, void *b, void *stream); +void LaunchTRECIP_f16_32x32(void *a, void *b, void *stream); +void LaunchTRECIP_f32_64x64_pad(void *a, void *b, void *stream); +void LaunchTRECIP_f32_58x70(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTRECIP_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTRECIP_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTRECIP_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTRECIP_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"f32_64x64_pad", LaunchTRECIP_f32_64x64_pad, 66, 72, 64, 64, sizeof(float)}, + {"f32_58x70", LaunchTRECIP_f32_58x70, 66, 72, 58, 70, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trecip [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trecip/trecip.pto b/test/tilelang_st/npu/a5/src/st/testcase/trecip/trecip.pto new file mode 100644 index 000000000..a2854aefa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trecip/trecip.pto @@ -0,0 +1,268 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trecip: 1/x (reciprocal) +// trecip = vdiv(1.0, x) +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 + func.func @TRECIP_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 + func.func @TRECIP_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 + func.func @TRECIP_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 + func.func @TRECIP_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: f32 66x72, valid 64x64 (pad case) + func.func @TRECIP_f32_64x64_pad(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c66 = arith.constant 66 : index + %c72 = arith.constant 72 : index + %c4752 = arith.constant 4752 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f32 66x72, valid 58x70 (non-square valid) + func.func @TRECIP_f32_58x70(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c58 = arith.constant 58 : index + %c70 = arith.constant 70 : index + %c66 = arith.constant 66 : index + %c72 = arith.constant 72 : index + %c4752 = arith.constant 4752 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c66, %c72], + strides = [%c4752, %c4752, %c4752, %c72, %c1] + : !pto.tensor_view<1x1x1x66x72xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c58, %c70] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x58x70xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c58, %c70] + : !pto.tensor_view<1x1x1x66x72xf32> -> !pto.partition_tensor_view<1x1x1x58x70xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x58x70xf32>) + outs(%a : !pto.tile_buf) + + pto.trecip ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x58x70xf32>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trelu/CMakeLists.txt new file mode 100644 index 000000000..5b01f92c5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trelu) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trelu/cases.py new file mode 100644 index 000000000..85823525c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/cases.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trelu ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — global data dimensions (input/output size). + - tile_shape: (tile_rows, tile_cols) — allocated tile buffer dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "int32_64x64", + "dtype": np.int32, + "shape": (64, 64), + "tile_shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-6, + }, + { + "name": "f16_64x64_valid_60x60", + "dtype": np.float16, + "shape": (60, 60), + "tile_shape": (64, 64), + "valid_shape": (60, 60), + "eps": 1e-3, + }, + { + "name": "f32_64x64_valid_60x60", + "dtype": np.float32, + "shape": (60, 60), + "tile_shape": (64, 64), + "valid_shape": (60, 60), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trelu/compare.py new file mode 100644 index 000000000..4409f6261 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/compare.py @@ -0,0 +1,48 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trelu/gen_data.py new file mode 100644 index 000000000..8911a7cb4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/gen_data.py @@ -0,0 +1,32 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + + if dtype == np.int32: + input1 = np.random.randint(-3_000_000, 3_000_000, size=shape).astype(dtype) + else: + input1 = np.random.uniform(-10, 10, size=shape).astype(dtype) + + golden = np.maximum(input1, 0) + + save_case_data(case["name"], {"input": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={case['valid_shape']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trelu/launch.cpp new file mode 100644 index 000000000..94e256ff4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: int32 64x64 +extern "C" __global__ AICORE void TRELU_int32_64x64(__gm__ int32_t *input, __gm__ int32_t *output); + +void LaunchTRELU_int32_64x64(int32_t *input, int32_t *output, void *stream) { + TRELU_int32_64x64<<<1, nullptr, stream>>>((__gm__ int32_t *)input, (__gm__ int32_t *)output); +} + +// Case 1: f16 64x64 valid 60x60 +extern "C" __global__ AICORE void TRELU_f16_64x64_v60x60(__gm__ uint16_t *input, __gm__ uint16_t *output); + +void LaunchTRELU_f16_64x64_v60x60(uint16_t *input, uint16_t *output, void *stream) { + TRELU_f16_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ uint16_t *)input, (__gm__ uint16_t *)output); +} + +// Case 2: f32 64x64 valid 60x60 +extern "C" __global__ AICORE void TRELU_f32_64x64_v60x60(__gm__ float *input, __gm__ float *output); + +void LaunchTRELU_f32_64x64_v60x60(float *input, float *output, void *stream) { + TRELU_f32_64x64_v60x60<<<1, nullptr, stream>>>((__gm__ float *)input, (__gm__ float *)output); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trelu/main.cpp new file mode 100644 index 000000000..b19a7fd95 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trelu ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTRELU_int32_64x64(int32_t *input, int32_t *output, void *stream); +void LaunchTRELU_f16_64x64_v60x60(uint16_t *input, uint16_t *output, void *stream); +void LaunchTRELU_f32_64x64_v60x60(float *input, float *output, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); + size_t rows; + size_t cols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"int32_64x64", (void (*)(void*, void*, void*))LaunchTRELU_int32_64x64, 64, 64, sizeof(int32_t)}, + {"f16_64x64_valid_60x60", (void (*)(void*, void*, void*))LaunchTRELU_f16_64x64_v60x60, 60, 60, sizeof(uint16_t)}, + {"f32_64x64_valid_60x60", (void (*)(void*, void*, void*))LaunchTRELU_f32_64x64_v60x60, 60, 60, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols); + + std::string caseDir = std::string("./") + tc.name; + + void *inputHost = nullptr, *outputHost = nullptr; + void *inputDevice = nullptr, *outputDevice = nullptr; + + aclrtMallocHost(&inputHost, fileSize); + aclrtMallocHost(&outputHost, fileSize); + + aclrtMalloc(&inputDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&outputDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), fileSize, inputHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(inputDevice, fileSize, inputHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(inputDevice, outputDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(outputHost, fileSize, outputDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), outputHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (inputDevice != nullptr) + aclrtFree(inputDevice); + if (outputDevice != nullptr) + aclrtFree(outputDevice); + if (inputHost != nullptr) + aclrtFreeHost(inputHost); + if (outputHost != nullptr) + aclrtFreeHost(outputHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trelu/trelu.pto b/test/tilelang_st/npu/a5/src/st/testcase/trelu/trelu.pto new file mode 100644 index 000000000..99c550b1d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trelu/trelu.pto @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trelu: tload(input) + trelu(input)->output + tstore(output). +// Multiple cases with different shapes and dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: int32 64x64 (4096 elements, valid=64x64) + func.func @TRELU_int32_64x64(%input_ptr: !pto.ptr, %output_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %input_view = pto.make_tensor_view %input_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + %output_view = pto.make_tensor_view %output_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi32> + + %input_part = pto.partition_view %input_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + %output_part = pto.partition_view %output_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi32> -> !pto.partition_tensor_view<1x1x1x64x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%output_part : !pto.partition_tensor_view<1x1x1x64x64xi32>) + return + } + + // Case 1: f16 64x64 (4096 elements, valid=60x60) + func.func @TRELU_f16_64x64_v60x60(%input_ptr: !pto.ptr, %output_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %input_view = pto.make_tensor_view %input_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf16> + %output_view = pto.make_tensor_view %output_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf16> + + %input_part = pto.partition_view %input_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf16> -> !pto.partition_tensor_view<1x1x1x60x60xf16> + %output_part = pto.partition_view %output_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf16> -> !pto.partition_tensor_view<1x1x1x60x60xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<1x1x1x60x60xf16>) + outs(%src : !pto.tile_buf) + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%output_part : !pto.partition_tensor_view<1x1x1x60x60xf16>) + return + } + + // Case 2: f32 64x64 (4096 elements, valid=60x60) + func.func @TRELU_f32_64x64_v60x60(%input_ptr: !pto.ptr, %output_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c60 = arith.constant 60 : index + %c64 = arith.constant 64 : index + %c3600 = arith.constant 3600 : index + + %input_view = pto.make_tensor_view %input_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + %output_view = pto.make_tensor_view %output_ptr, + shape = [%c1, %c1, %c1, %c60, %c60], + strides = [%c3600, %c3600, %c3600, %c60, %c1] + : !pto.tensor_view<1x1x1x60x60xf32> + + %input_part = pto.partition_view %input_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + %output_part = pto.partition_view %output_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c60, %c60] + : !pto.tensor_view<1x1x1x60x60xf32> -> !pto.partition_tensor_view<1x1x1x60x60xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + outs(%src : !pto.tile_buf) + + pto.trelu ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%output_part : !pto.partition_tensor_view<1x1x1x60x60xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt new file mode 100644 index 000000000..eb913e626 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trem) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py new file mode 100644 index 000000000..544675c0b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trem ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py new file mode 100644 index 000000000..4cd825fa9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.remainder(input1[:vr, :vc], input2[:vr, :vc]) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp new file mode 100644 index 000000000..6dbbc1685 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TREM_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTREM_f32_16x64(float *a, float *b, float *c, void *stream) { + TREM_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TREM_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTREM_f32_32x32(float *a, float *b, float *c, void *stream) { + TREM_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp new file mode 100644 index 000000000..e616ef131 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tadd ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTREM_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTREM_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTREM_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTREM_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trem [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto b/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto new file mode 100644 index 000000000..3551aa9e0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trem/trem.pto @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tadd: tload(a) + tload(b) + tadd(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TREM_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TREM_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.trem ins(%a, %b, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trems/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trems/CMakeLists.txt new file mode 100644 index 000000000..3a21d2c4c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trems/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trems) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trems/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trems/cases.py new file mode 100644 index 000000000..b87b0b9b0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trems/cases.py @@ -0,0 +1,56 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trems ST test cases. + +trems: integer remainder via vdiv, dst = src - trunc(src/scalar) * scalar. +All types: f32, f16, i32, i16. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_32x64", + "dtype": np.float32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 1e-6, + }, + { + "name": "f16_63x64", + "dtype": np.float16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_7x448", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 448), + "eps": 1e-6, + }, + { + "name": "f32_256x16", + "dtype": np.float32, + "shape": (256, 16), + "valid_shape": (256, 16), + "eps": 1e-6, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trems/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trems/compare.py new file mode 100644 index 000000000..18835ae9f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trems/compare.py @@ -0,0 +1,56 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trems/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trems/gen_data.py new file mode 100644 index 000000000..02fcf6165 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trems/gen_data.py @@ -0,0 +1,46 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for remainder (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + if np.issubdtype(dtype, np.floating): + golden[:vr, :vc] = (input1[:vr, :vc] - np.trunc(input1[:vr, :vc] / scalar_val) * scalar_val).astype(dtype, copy=False) + else: + golden[:vr, :vc] = (input1[:vr, :vc] % scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trems/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trems/launch.cpp new file mode 100644 index 000000000..95f4d0aea --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trems/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for remainder (must match gen_data.py SCALAR) +static constexpr float TREMS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TREMS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTREMS_f32_32x64(float *src, float *dst, void *stream) { + TREMS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TREMS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TREMS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTREMS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TREMS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: f32 7x448 +extern "C" __global__ AICORE void TREMS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTREMS_f32_7x448(float *src, float *dst, void *stream) { + TREMS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TREMS_SCALAR_F32); +} + +// Case 3: f32 256x16 +extern "C" __global__ AICORE void TREMS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTREMS_f32_256x16(float *src, float *dst, void *stream) { + TREMS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TREMS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trems/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trems/main.cpp new file mode 100644 index 000000000..810ab7c53 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trems/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trems ST — case-table driven. +// trems: dst = src - trunc(src/scalar) * scalar (integer remainder via vdiv). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTREMS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTREMS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTREMS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTREMS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTREMS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTREMS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTREMS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTREMS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trems [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trems/trems.pto b/test/tilelang_st/npu/a5/src/st/testcase/trems/trems.pto new file mode 100644 index 000000000..bc686b938 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trems/trems.pto @@ -0,0 +1,192 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trems: tload(src) + trems(src, scalar, tmp)->dst + tstore(dst). +// Integer remainder via vdiv: dst = src - trunc(src/scalar) * scalar. +// All types: f32, f16, i32, i16. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 (2048 elements) + func.func @TREMS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.trems ins(%src, %scalar, %tmp : !pto.tile_buf, f32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TREMS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.trems ins(%src, %scalar, %tmp : !pto.tile_buf, f16, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) - SKIPPED: vdiv does not support integer types on A5 hardware + // Case 3: i16 15x192 (2880 elements) - SKIPPED: vdiv does not support integer types on A5 hardware + + // Case 4: f32 7x448 (3136 elements) + func.func @TREMS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.trems ins(%src, %scalar, %tmp : !pto.tile_buf, f32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TREMS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.trems ins(%src, %scalar, %tmp : !pto.tile_buf, f32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/CMakeLists.txt new file mode 100644 index 000000000..42aec9129 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowargmax) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/cases.py new file mode 100644 index 000000000..c67114409 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/cases.py @@ -0,0 +1,215 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowargmax ST test cases — aligned with pto-isa.""" + +import numpy as np + +CASES = [ + # uint32_dst + float32_src + { + "name": "uint32_float_8x1_8x8_8x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 8), + "valid_shape": (8, 8), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1024x8_1024x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1024, 8), + "valid_shape": (1024, 8), + "eps": 0, + }, + { + "name": "uint32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + { + "name": "uint32_float_8x1_8x64_8x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 64), + "valid_shape": (8, 64), + "eps": 0, + }, + { + "name": "uint32_float_264x1_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_8x1_1x128_1x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1, 128), + "valid_shape": (1, 128), + "eps": 0, + }, + { + "name": "uint32_float_64x1_32x128_32x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "uint32_float_8x1_3x4096_3x4095", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 4096), + "valid_shape": (3, 4095), + "eps": 0, + }, + { + "name": "uint32_float_8x1_2x16384_2x16381", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (2, 16384), + "valid_shape": (2, 16381), + "eps": 0, + }, + # uint32_dst + float16_src + { + "name": "uint32_half_16x1_2x16_2x16", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (2, 16), + "valid_shape": (2, 16), + "eps": 0, + }, + { + "name": "uint32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_half_272x1_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_16x1_3x8192_3x8191", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 8192), + "valid_shape": (3, 8191), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x16384_1x16381", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 16384), + "valid_shape": (1, 16381), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x32768_1x32761", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 32768), + "valid_shape": (1, 32761), + "eps": 0, + }, + # int32_dst + float32_src + { + "name": "int32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # int32_dst + float16_src + { + "name": "int32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # uint32_dst + float32_src (dst col > 1) + { + "name": "uint32_float_3x8_3x3480_3x3473", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 3480), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_float_260x8_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_1023x8_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + # uint32_dst + float16_src (dst col > 1) + { + "name": "uint32_half_3x16_3x3488_3x3473", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 3488), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_half_260x16_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_1023x16_1023x32_1023x17", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1023, 32), + "valid_shape": (1023, 17), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/compare.py new file mode 100644 index 000000000..4cd015fd3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr, 1) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dst_dtype"], count=np.prod(out_shape)).reshape(out_shape) + + output_full = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dst_dtype"]) + dst_cols = len(output_full) // vr + output = output_full.reshape(vr, dst_cols)[:, 0:1] + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/gen_data.py new file mode 100644 index 000000000..3016b948f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_dtype = case["dst_dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if dtype in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32): + dtype_info = np.iinfo(dtype) + input1 = np.random.randint(dtype_info.min, dtype_info.max, size=shape).astype(dtype) + else: + dtype_info = np.finfo(dtype) + input1 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=shape).astype(dtype) + + out_shape = (valid_shape[0], 1) + golden = np.zeros(out_shape, dtype=dst_dtype) + golden[:, 0:1] = np.argmax(input1[:, :valid_shape[1]], axis=1, keepdims=True).astype(dst_dtype) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/launch.cpp new file mode 100644 index 000000000..1da9eee23 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/launch.cpp @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_8x8_8x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_8x8_8x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_1024x1_1024x8_1024x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_1024x1_1024x8_1024x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_1024x1_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_1024x1_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_8x64_8x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_8x64_8x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_264x1_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_264x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_1x128_1x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_1x128_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_64x1_32x128_32x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_64x1_32x128_32x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_3x4096_3x4095(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_3x4096_3x4095<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_8x1_2x16384_2x16381(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_8x1_2x16384_2x16381<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_2x16_2x16(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_272x1_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_272x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_3x8192_3x8191(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_3x8192_3x8191<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_1x16384_1x16381(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_1x16384_1x16381<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_16x1_1x32768_1x32761(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_16x1_1x32768_1x32761<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_int32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ int32_t *dst); +void LaunchTROWARGMAX_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream) { + TROWARGMAX_int32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_int32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ int32_t *dst); +void LaunchTROWARGMAX_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream) { + TROWARGMAX_int32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_3x8_3x3480_3x3473(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_3x8_3x3480_3x3473<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_260x8_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_260x8_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_float_1023x8_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_float_1023x8_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_3x16_3x3488_3x3473(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_3x16_3x3488_3x3473<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_260x16_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_260x16_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMAX_uint32_half_1023x16_1023x32_1023x17(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMAX_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMAX_uint32_half_1023x16_1023x32_1023x17<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/main.cpp new file mode 100644 index 000000000..908e57820 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/main.cpp @@ -0,0 +1,210 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowargmax ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWARGMAX_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream); +void LaunchTROWARGMAX_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMAX_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream); + +using LaunchFnF32U32 = void (*)(float *, uint32_t *, void *); +using LaunchFnF16U32 = void (*)(uint16_t *, uint32_t *, void *); +using LaunchFnF32S32 = void (*)(float *, int32_t *, void *); +using LaunchFnF16S32 = void (*)(uint16_t *, int32_t *, void *); + +enum class DType { F32U32, F16U32, F32S32, F16S32 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32U32 launchF32U32; + LaunchFnF16U32 launchF16U32; + LaunchFnF32S32 launchF32S32; + LaunchFnF16S32 launchF16S32; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t srcElemSize; // bytes per src element + size_t dstElemSize; // bytes per dst element + size_t dstCols; // dst tile cols +}; + +static const TestCase kCases[] = { + {"uint32_float_8x1_8x8_8x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_8x8_8x8, 8, 8, 8, 8, 4, 4, 1}, + {"uint32_float_1024x1_1024x8_1024x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_1024x1_1024x8_1024x8, 1024, 8, 1024, 8, 4, 4, 1}, + {"uint32_float_16x1_13x16_13x13", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"uint32_float_1024x1_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_1024x1_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 1}, + {"uint32_float_8x1_8x64_8x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_8x64_8x64, 8, 64, 8, 64, 4, 4, 1}, + {"uint32_float_264x1_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_264x1_260x64_260x64, 260, 64, 260, 64, 4, 4, 1}, + {"uint32_float_8x1_1x128_1x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_1x128_1x128, 1, 128, 1, 128, 4, 4, 1}, + {"uint32_float_64x1_32x128_32x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_64x1_32x128_32x128, 32, 128, 32, 128, 4, 4, 1}, + {"uint32_float_8x1_3x4096_3x4095", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_3x4096_3x4095, 3, 4096, 3, 4095, 4, 4, 1}, + {"uint32_float_8x1_2x16384_2x16381", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_8x1_2x16384_2x16381, 2, 16384, 2, 16381, 4, 4, 1}, + {"uint32_half_16x1_2x16_2x16", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_2x16_2x16, 2, 16, 2, 16, 2, 4, 1}, + {"uint32_half_16x1_13x16_13x13", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_half_272x1_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_272x1_260x64_260x64, 260, 64, 260, 64, 2, 4, 1}, + {"uint32_half_16x1_3x8192_3x8191", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_3x8192_3x8191, 3, 8192, 3, 8191, 2, 4, 1}, + {"uint32_half_16x1_1x16384_1x16381", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_1x16384_1x16381, 1, 16384, 1, 16381, 2, 4, 1}, + {"uint32_half_16x1_1x32768_1x32761", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_16x1_1x32768_1x32761, 1, 32768, 1, 32761, 2, 4, 1}, + {"int32_float_16x1_13x16_13x13", DType::F32S32, .launchF32S32 = LaunchTROWARGMAX_int32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"int32_half_16x1_13x16_13x13", DType::F16S32, .launchF16S32 = LaunchTROWARGMAX_int32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_float_3x8_3x3480_3x3473", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_3x8_3x3480_3x3473, 3, 3480, 3, 3473, 4, 4, 8}, + {"uint32_float_260x8_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_260x8_260x64_260x64, 260, 64, 260, 64, 4, 4, 8}, + {"uint32_float_1023x8_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMAX_uint32_float_1023x8_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 8}, + {"uint32_half_3x16_3x3488_3x3473", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_3x16_3x3488_3x3473, 3, 3488, 3, 3473, 2, 4, 16}, + {"uint32_half_260x16_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_260x16_260x64_260x64, 260, 64, 260, 64, 2, 4, 16}, + {"uint32_half_1023x16_1023x32_1023x17", DType::F16U32, .launchF16U32 = LaunchTROWARGMAX_uint32_half_1023x16_1023x32_1023x17, 1023, 32, 1023, 17, 2, 4, 16}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.srcElemSize; + const size_t dstElemCount = tc.validRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (rc == 0) { + aclrtMemset(dstDevice, dstFileSize, 0, dstFileSize); + } + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32U32: + tc.launchF32U32((float *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F16U32: + tc.launchF16U32((uint16_t *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F32S32: + tc.launchF32S32((float *)src0Device, (int32_t *)dstDevice, stream); + break; + case DType::F16S32: + tc.launchF16S32((uint16_t *)src0Device, (int32_t *)dstDevice, stream); + break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0) { + mkdir(caseDir.c_str(), 0755); + if (!WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowargmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/trowargmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/trowargmax.pto new file mode 100644 index 000000000..d6459ad07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmax/trowargmax.pto @@ -0,0 +1,1013 @@ +// Auto-generated trowargmax ST testcases + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + func.func @TROWARGMAX_uint32_float_8x1_8x8_8x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c8_c = arith.constant 8 : index + %c64_se = arith.constant 64 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c8_c], + strides = [%c64_se, %c64_se, %c64_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c8_vc] + : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_1024x1_1024x8_1024x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024_r = arith.constant 1024 : index + %c8_c = arith.constant 8 : index + %c8192_se = arith.constant 8192 : index + %c1024_de = arith.constant 1024 : index + %c1_dc = arith.constant 1 : index + %c1024_vr = arith.constant 1024 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c8_c], + strides = [%c8192_se, %c8192_se, %c8192_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x1024x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c1_dc], + strides = [%c1024_de, %c1024_de, %c1024_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c8_vc] + : !pto.tensor_view<1x1x1x1024x8xf32> -> !pto.partition_tensor_view<1x1x1x1024x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> -> !pto.partition_tensor_view<1x1x1x1024x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1024x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1024x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_1024x1_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c1023_de = arith.constant 1023 : index + %c1_dc = arith.constant 1 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c1_dc], + strides = [%c1023_de, %c1023_de, %c1023_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_8x64_8x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c64_c = arith.constant 64 : index + %c512_se = arith.constant 512 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c64_c], + strides = [%c512_se, %c512_se, %c512_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c64_vc] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_264x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_1x128_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c128_c = arith.constant 128 : index + %c128_se = arith.constant 128 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c128_c], + strides = [%c128_se, %c128_se, %c128_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c128_vc] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_64x1_32x128_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32_r = arith.constant 32 : index + %c128_c = arith.constant 128 : index + %c4096_se = arith.constant 4096 : index + %c32_de = arith.constant 32 : index + %c1_dc = arith.constant 1 : index + %c32_vr = arith.constant 32 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c128_c], + strides = [%c4096_se, %c4096_se, %c4096_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x32x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c1_dc], + strides = [%c32_de, %c32_de, %c32_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c128_vc] + : !pto.tensor_view<1x1x1x32x128xf32> -> !pto.partition_tensor_view<1x1x1x32x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> -> !pto.partition_tensor_view<1x1x1x32x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_3x4096_3x4095(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c4096_c = arith.constant 4096 : index + %c12288_se = arith.constant 12288 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c4095_vc = arith.constant 4095 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c4096_c], + strides = [%c12288_se, %c12288_se, %c12288_se, %c4096_c, %c1] + : !pto.tensor_view<1x1x1x3x4096xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c4095_vc] + : !pto.tensor_view<1x1x1x3x4096xf32> -> !pto.partition_tensor_view<1x1x1x3x4095xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x4095xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_8x1_2x16384_2x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16384_c = arith.constant 16384 : index + %c32768_se = arith.constant 32768 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16384_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x2x16384xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x2x16384xf32> -> !pto.partition_tensor_view<1x1x1x2x16381xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16381xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_2x16_2x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16_c = arith.constant 16 : index + %c32_se = arith.constant 32 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16_vc = arith.constant 16 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16_c], + strides = [%c32_se, %c32_se, %c32_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x2x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16_vc] + : !pto.tensor_view<1x1x1x2x16xf16> -> !pto.partition_tensor_view<1x1x1x2x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_272x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_3x8192_3x8191(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c8192_c = arith.constant 8192 : index + %c24576_se = arith.constant 24576 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c8191_vc = arith.constant 8191 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8192_c], + strides = [%c24576_se, %c24576_se, %c24576_se, %c8192_c, %c1] + : !pto.tensor_view<1x1x1x3x8192xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c8191_vc] + : !pto.tensor_view<1x1x1x3x8192xf16> -> !pto.partition_tensor_view<1x1x1x3x8191xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x8191xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_1x16384_1x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c16384_c = arith.constant 16384 : index + %c16384_se = arith.constant 16384 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c16384_c], + strides = [%c16384_se, %c16384_se, %c16384_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x1x16384xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x1x16384xf16> -> !pto.partition_tensor_view<1x1x1x1x16381xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x16381xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_16x1_1x32768_1x32761(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c32768_c = arith.constant 32768 : index + %c32768_se = arith.constant 32768 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c32761_vc = arith.constant 32761 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c32768_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c32768_c, %c1] + : !pto.tensor_view<1x1x1x1x32768xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c32761_vc] + : !pto.tensor_view<1x1x1x1x32768xf16> -> !pto.partition_tensor_view<1x1x1x1x32761xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x32761xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMAX_int32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMAX_int32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMAX_uint32_float_3x8_3x3480_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3480_c = arith.constant 3480 : index + %c10440_se = arith.constant 10440 : index + %c24_de = arith.constant 24 : index + %c8_dc = arith.constant 8 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3480_c], + strides = [%c10440_se, %c10440_se, %c10440_se, %c3480_c, %c1] + : !pto.tensor_view<1x1x1x3x3480xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8_dc], + strides = [%c24_de, %c24_de, %c24_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3480xf32> -> !pto.partition_tensor_view<1x1x1x3x3473xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_260x8_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c2080_de = arith.constant 2080 : index + %c8_dc = arith.constant 8 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c8_dc], + strides = [%c2080_de, %c2080_de, %c2080_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_float_1023x8_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c8184_de = arith.constant 8184 : index + %c8_dc = arith.constant 8 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c8_dc], + strides = [%c8184_de, %c8184_de, %c8184_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_3x16_3x3488_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3488_c = arith.constant 3488 : index + %c10464_se = arith.constant 10464 : index + %c48_de = arith.constant 48 : index + %c16_dc = arith.constant 16 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3488_c], + strides = [%c10464_se, %c10464_se, %c10464_se, %c3488_c, %c1] + : !pto.tensor_view<1x1x1x3x3488xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c16_dc], + strides = [%c48_de, %c48_de, %c48_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3488xf16> -> !pto.partition_tensor_view<1x1x1x3x3473xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_260x16_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c4160_de = arith.constant 4160 : index + %c16_dc = arith.constant 16 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c16_dc], + strides = [%c4160_de, %c4160_de, %c4160_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMAX_uint32_half_1023x16_1023x32_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c32_c = arith.constant 32 : index + %c32736_se = arith.constant 32736 : index + %c16368_de = arith.constant 16368 : index + %c16_dc = arith.constant 16 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c32_c], + strides = [%c32736_se, %c32736_se, %c32736_se, %c32_c, %c1] + : !pto.tensor_view<1x1x1x1023x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c16_dc], + strides = [%c16368_de, %c16368_de, %c16368_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x32xf16> -> !pto.partition_tensor_view<1x1x1x1023x17xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/CMakeLists.txt new file mode 100644 index 000000000..a6a8925b5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowargmin) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/cases.py new file mode 100644 index 000000000..2614b130e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/cases.py @@ -0,0 +1,215 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowargmin ST test cases — aligned with pto-isa.""" + +import numpy as np + +CASES = [ + # uint32_dst + float32_src + { + "name": "uint32_float_8x1_8x8_8x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 8), + "valid_shape": (8, 8), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1024x8_1024x8", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1024, 8), + "valid_shape": (1024, 8), + "eps": 0, + }, + { + "name": "uint32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_float_1024x1_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + { + "name": "uint32_float_8x1_8x64_8x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (8, 64), + "valid_shape": (8, 64), + "eps": 0, + }, + { + "name": "uint32_float_264x1_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_8x1_1x128_1x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1, 128), + "valid_shape": (1, 128), + "eps": 0, + }, + { + "name": "uint32_float_64x1_32x128_32x128", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "uint32_float_8x1_3x4096_3x4095", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 4096), + "valid_shape": (3, 4095), + "eps": 0, + }, + { + "name": "uint32_float_8x1_2x16384_2x16381", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (2, 16384), + "valid_shape": (2, 16381), + "eps": 0, + }, + # uint32_dst + float16_src + { + "name": "uint32_half_16x1_2x16_2x16", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (2, 16), + "valid_shape": (2, 16), + "eps": 0, + }, + { + "name": "uint32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + { + "name": "uint32_half_272x1_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_16x1_3x8192_3x8191", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 8192), + "valid_shape": (3, 8191), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x16384_1x16381", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 16384), + "valid_shape": (1, 16381), + "eps": 0, + }, + { + "name": "uint32_half_16x1_1x32768_1x32761", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1, 32768), + "valid_shape": (1, 32761), + "eps": 0, + }, + # int32_dst + float32_src + { + "name": "int32_float_16x1_13x16_13x13", + "dtype": np.float32, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # int32_dst + float16_src + { + "name": "int32_half_16x1_13x16_13x13", + "dtype": np.float16, + "dst_dtype": np.int32, + "shape": (13, 16), + "valid_shape": (13, 13), + "eps": 0, + }, + # uint32_dst + float32_src (dst col > 1) + { + "name": "uint32_float_3x8_3x3480_3x3473", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (3, 3480), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_float_260x8_260x64_260x64", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_float_1023x8_1023x24_1023x17", + "dtype": np.float32, + "dst_dtype": np.uint32, + "shape": (1023, 24), + "valid_shape": (1023, 17), + "eps": 0, + }, + # uint32_dst + float16_src (dst col > 1) + { + "name": "uint32_half_3x16_3x3488_3x3473", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (3, 3488), + "valid_shape": (3, 3473), + "eps": 0, + }, + { + "name": "uint32_half_260x16_260x64_260x64", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (260, 64), + "valid_shape": (260, 64), + "eps": 0, + }, + { + "name": "uint32_half_1023x16_1023x32_1023x17", + "dtype": np.float16, + "dst_dtype": np.uint32, + "shape": (1023, 32), + "valid_shape": (1023, 17), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/compare.py new file mode 100644 index 000000000..4cd015fd3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr, 1) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dst_dtype"], count=np.prod(out_shape)).reshape(out_shape) + + output_full = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dst_dtype"]) + dst_cols = len(output_full) // vr + output = output_full.reshape(vr, dst_cols)[:, 0:1] + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/gen_data.py new file mode 100644 index 000000000..6c103094c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/gen_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dst_dtype = case["dst_dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if dtype in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32): + dtype_info = np.iinfo(dtype) + input1 = np.random.randint(dtype_info.min, dtype_info.max, size=shape).astype(dtype) + else: + dtype_info = np.finfo(dtype) + input1 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=shape).astype(dtype) + + out_shape = (valid_shape[0], 1) + golden = np.zeros(out_shape, dtype=dst_dtype) + golden[:, 0:1] = np.argmin(input1[:, :valid_shape[1]], axis=1, keepdims=True).astype(dst_dtype) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/launch.cpp new file mode 100644 index 000000000..d87134237 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/launch.cpp @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_8x8_8x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_8x8_8x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_1024x1_1024x8_1024x8(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_1024x1_1024x8_1024x8<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_1024x1_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_1024x1_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_8x64_8x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_8x64_8x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_264x1_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_264x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_1x128_1x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_1x128_1x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_64x1_32x128_32x128(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_64x1_32x128_32x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_3x4096_3x4095(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_3x4096_3x4095<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_8x1_2x16384_2x16381(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_8x1_2x16384_2x16381<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_2x16_2x16(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_272x1_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_272x1_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_3x8192_3x8191(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_3x8192_3x8191<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_1x16384_1x16381(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_1x16384_1x16381<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_16x1_1x32768_1x32761(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_16x1_1x32768_1x32761<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_int32_float_16x1_13x16_13x13(__gm__ float *src, __gm__ int32_t *dst); +void LaunchTROWARGMIN_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream) { + TROWARGMIN_int32_float_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_int32_half_16x1_13x16_13x13(__gm__ uint16_t *src, __gm__ int32_t *dst); +void LaunchTROWARGMIN_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream) { + TROWARGMIN_int32_half_16x1_13x16_13x13<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_3x8_3x3480_3x3473(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_3x8_3x3480_3x3473<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_260x8_260x64_260x64(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_260x8_260x64_260x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_float_1023x8_1023x24_1023x17(__gm__ float *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_float_1023x8_1023x24_1023x17<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_3x16_3x3488_3x3473(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_3x16_3x3488_3x3473<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_260x16_260x64_260x64(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_260x16_260x64_260x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} + +extern "C" __global__ AICORE void TROWARGMIN_uint32_half_1023x16_1023x32_1023x17(__gm__ uint16_t *src, __gm__ uint32_t *dst); +void LaunchTROWARGMIN_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream) { + TROWARGMIN_uint32_half_1023x16_1023x32_1023x17<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/main.cpp new file mode 100644 index 000000000..997db22a2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/main.cpp @@ -0,0 +1,206 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowargmin ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWARGMIN_uint32_float_8x1_8x8_8x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_1024x1_1024x8_1024x8(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_16x1_13x16_13x13(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_1024x1_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_8x64_8x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_264x1_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_1x128_1x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_64x1_32x128_32x128(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_3x4096_3x4095(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_8x1_2x16384_2x16381(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_2x16_2x16(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_13x16_13x13(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_272x1_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_3x8192_3x8191(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_1x16384_1x16381(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_16x1_1x32768_1x32761(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_int32_float_16x1_13x16_13x13(float *src, int32_t *dst, void *stream); +void LaunchTROWARGMIN_int32_half_16x1_13x16_13x13(uint16_t *src, int32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_3x8_3x3480_3x3473(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_260x8_260x64_260x64(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_float_1023x8_1023x24_1023x17(float *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_3x16_3x3488_3x3473(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_260x16_260x64_260x64(uint16_t *src, uint32_t *dst, void *stream); +void LaunchTROWARGMIN_uint32_half_1023x16_1023x32_1023x17(uint16_t *src, uint32_t *dst, void *stream); + +using LaunchFnF32U32 = void (*)(float *, uint32_t *, void *); +using LaunchFnF16U32 = void (*)(uint16_t *, uint32_t *, void *); +using LaunchFnF32S32 = void (*)(float *, int32_t *, void *); +using LaunchFnF16S32 = void (*)(uint16_t *, int32_t *, void *); + +enum class DType { F32U32, F16U32, F32S32, F16S32 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32U32 launchF32U32; + LaunchFnF16U32 launchF16U32; + LaunchFnF32S32 launchF32S32; + LaunchFnF16S32 launchF16S32; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t srcElemSize; // bytes per src element + size_t dstElemSize; // bytes per dst element + size_t dstCols; // dst tile cols +}; + +static const TestCase kCases[] = { + {"uint32_float_8x1_8x8_8x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_8x8_8x8, 8, 8, 8, 8, 4, 4, 1}, + {"uint32_float_1024x1_1024x8_1024x8", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_1024x1_1024x8_1024x8, 1024, 8, 1024, 8, 4, 4, 1}, + {"uint32_float_16x1_13x16_13x13", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"uint32_float_1024x1_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_1024x1_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 1}, + {"uint32_float_8x1_8x64_8x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_8x64_8x64, 8, 64, 8, 64, 4, 4, 1}, + {"uint32_float_264x1_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_264x1_260x64_260x64, 260, 64, 260, 64, 4, 4, 1}, + {"uint32_float_8x1_1x128_1x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_1x128_1x128, 1, 128, 1, 128, 4, 4, 1}, + {"uint32_float_64x1_32x128_32x128", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_64x1_32x128_32x128, 32, 128, 32, 128, 4, 4, 1}, + {"uint32_float_8x1_3x4096_3x4095", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_3x4096_3x4095, 3, 4096, 3, 4095, 4, 4, 1}, + {"uint32_float_8x1_2x16384_2x16381", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_8x1_2x16384_2x16381, 2, 16384, 2, 16381, 4, 4, 1}, + {"uint32_half_16x1_2x16_2x16", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_2x16_2x16, 2, 16, 2, 16, 2, 4, 1}, + {"uint32_half_16x1_13x16_13x13", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_half_272x1_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_272x1_260x64_260x64, 260, 64, 260, 64, 2, 4, 1}, + {"uint32_half_16x1_3x8192_3x8191", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_3x8192_3x8191, 3, 8192, 3, 8191, 2, 4, 1}, + {"uint32_half_16x1_1x16384_1x16381", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_1x16384_1x16381, 1, 16384, 1, 16381, 2, 4, 1}, + {"uint32_half_16x1_1x32768_1x32761", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_16x1_1x32768_1x32761, 1, 32768, 1, 32761, 2, 4, 1}, + {"int32_float_16x1_13x16_13x13", DType::F32S32, .launchF32S32 = LaunchTROWARGMIN_int32_float_16x1_13x16_13x13, 13, 16, 13, 13, 4, 4, 1}, + {"int32_half_16x1_13x16_13x13", DType::F16S32, .launchF16S32 = LaunchTROWARGMIN_int32_half_16x1_13x16_13x13, 13, 16, 13, 13, 2, 4, 1}, + {"uint32_float_3x8_3x3480_3x3473", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_3x8_3x3480_3x3473, 3, 3480, 3, 3473, 4, 4, 8}, + {"uint32_float_260x8_260x64_260x64", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_260x8_260x64_260x64, 260, 64, 260, 64, 4, 4, 8}, + {"uint32_float_1023x8_1023x24_1023x17", DType::F32U32, .launchF32U32 = LaunchTROWARGMIN_uint32_float_1023x8_1023x24_1023x17, 1023, 24, 1023, 17, 4, 4, 8}, + {"uint32_half_3x16_3x3488_3x3473", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_3x16_3x3488_3x3473, 3, 3488, 3, 3473, 2, 4, 16}, + {"uint32_half_260x16_260x64_260x64", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_260x16_260x64_260x64, 260, 64, 260, 64, 2, 4, 16}, + {"uint32_half_1023x16_1023x32_1023x17", DType::F16U32, .launchF16U32 = LaunchTROWARGMIN_uint32_half_1023x16_1023x32_1023x17, 1023, 32, 1023, 17, 2, 4, 16}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.srcElemSize; + const size_t dstElemCount = tc.validRows * tc.dstCols; + const size_t dstFileSize = dstElemCount * tc.dstElemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32U32: + tc.launchF32U32((float *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F16U32: + tc.launchF16U32((uint16_t *)src0Device, (uint32_t *)dstDevice, stream); + break; + case DType::F32S32: + tc.launchF32S32((float *)src0Device, (int32_t *)dstDevice, stream); + break; + case DType::F16S32: + tc.launchF16S32((uint16_t *)src0Device, (int32_t *)dstDevice, stream); + break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0) { + mkdir(caseDir.c_str(), 0755); + if (!WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowargmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/trowargmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/trowargmin.pto new file mode 100644 index 000000000..97e9432d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowargmin/trowargmin.pto @@ -0,0 +1,1013 @@ +// Auto-generated trowargmin ST testcases + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + func.func @TROWARGMIN_uint32_float_8x1_8x8_8x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c8_c = arith.constant 8 : index + %c64_se = arith.constant 64 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c8_c], + strides = [%c64_se, %c64_se, %c64_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c8_vc] + : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_1024x1_1024x8_1024x8(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024_r = arith.constant 1024 : index + %c8_c = arith.constant 8 : index + %c8192_se = arith.constant 8192 : index + %c1024_de = arith.constant 1024 : index + %c1_dc = arith.constant 1 : index + %c1024_vr = arith.constant 1024 : index + %c8_vc = arith.constant 8 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c8_c], + strides = [%c8192_se, %c8192_se, %c8192_se, %c8_c, %c1] + : !pto.tensor_view<1x1x1x1024x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1024_r, %c1_dc], + strides = [%c1024_de, %c1024_de, %c1024_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c8_vc] + : !pto.tensor_view<1x1x1x1024x8xf32> -> !pto.partition_tensor_view<1x1x1x1024x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1024_vr, %c1] + : !pto.tensor_view<1x1x1x1024x1xui32> -> !pto.partition_tensor_view<1x1x1x1024x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1024x8xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1024x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_1024x1_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c1023_de = arith.constant 1023 : index + %c1_dc = arith.constant 1 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c1_dc], + strides = [%c1023_de, %c1023_de, %c1023_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x1xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_8x64_8x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8_r = arith.constant 8 : index + %c64_c = arith.constant 64 : index + %c512_se = arith.constant 512 : index + %c8_de = arith.constant 8 : index + %c1_dc = arith.constant 1 : index + %c8_vr = arith.constant 8 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c64_c], + strides = [%c512_se, %c512_se, %c512_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x8x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8_r, %c1_dc], + strides = [%c8_de, %c8_de, %c8_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c64_vc] + : !pto.tensor_view<1x1x1x8x64xf32> -> !pto.partition_tensor_view<1x1x1x8x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8_vr, %c1] + : !pto.tensor_view<1x1x1x8x1xui32> -> !pto.partition_tensor_view<1x1x1x8x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_264x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_1x128_1x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c128_c = arith.constant 128 : index + %c128_se = arith.constant 128 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c128_c], + strides = [%c128_se, %c128_se, %c128_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c128_vc] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_64x1_32x128_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32_r = arith.constant 32 : index + %c128_c = arith.constant 128 : index + %c4096_se = arith.constant 4096 : index + %c32_de = arith.constant 32 : index + %c1_dc = arith.constant 1 : index + %c32_vr = arith.constant 32 : index + %c128_vc = arith.constant 128 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c128_c], + strides = [%c4096_se, %c4096_se, %c4096_se, %c128_c, %c1] + : !pto.tensor_view<1x1x1x32x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32_r, %c1_dc], + strides = [%c32_de, %c32_de, %c32_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c128_vc] + : !pto.tensor_view<1x1x1x32x128xf32> -> !pto.partition_tensor_view<1x1x1x32x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32_vr, %c1] + : !pto.tensor_view<1x1x1x32x1xui32> -> !pto.partition_tensor_view<1x1x1x32x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_3x4096_3x4095(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c4096_c = arith.constant 4096 : index + %c12288_se = arith.constant 12288 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c4095_vc = arith.constant 4095 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c4096_c], + strides = [%c12288_se, %c12288_se, %c12288_se, %c4096_c, %c1] + : !pto.tensor_view<1x1x1x3x4096xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c4095_vc] + : !pto.tensor_view<1x1x1x3x4096xf32> -> !pto.partition_tensor_view<1x1x1x3x4095xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x4095xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_8x1_2x16384_2x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16384_c = arith.constant 16384 : index + %c32768_se = arith.constant 32768 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16384_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x2x16384xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x2x16384xf32> -> !pto.partition_tensor_view<1x1x1x2x16381xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16381xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_2x16_2x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_r = arith.constant 2 : index + %c16_c = arith.constant 16 : index + %c32_se = arith.constant 32 : index + %c2_de = arith.constant 2 : index + %c1_dc = arith.constant 1 : index + %c2_vr = arith.constant 2 : index + %c16_vc = arith.constant 16 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c16_c], + strides = [%c32_se, %c32_se, %c32_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x2x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2_r, %c1_dc], + strides = [%c2_de, %c2_de, %c2_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c16_vc] + : !pto.tensor_view<1x1x1x2x16xf16> -> !pto.partition_tensor_view<1x1x1x2x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2_vr, %c1] + : !pto.tensor_view<1x1x1x2x1xui32> -> !pto.partition_tensor_view<1x1x1x2x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xui32> -> !pto.partition_tensor_view<1x1x1x13x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_272x1_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c260_de = arith.constant 260 : index + %c1_dc = arith.constant 1 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c1_dc], + strides = [%c260_de, %c260_de, %c260_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x1xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_3x8192_3x8191(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c8192_c = arith.constant 8192 : index + %c24576_se = arith.constant 24576 : index + %c3_de = arith.constant 3 : index + %c1_dc = arith.constant 1 : index + %c3_vr = arith.constant 3 : index + %c8191_vc = arith.constant 8191 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8192_c], + strides = [%c24576_se, %c24576_se, %c24576_se, %c8192_c, %c1] + : !pto.tensor_view<1x1x1x3x8192xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c1_dc], + strides = [%c3_de, %c3_de, %c3_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c8191_vc] + : !pto.tensor_view<1x1x1x3x8192xf16> -> !pto.partition_tensor_view<1x1x1x3x8191xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x1xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x8191xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_1x16384_1x16381(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c16384_c = arith.constant 16384 : index + %c16384_se = arith.constant 16384 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c16381_vc = arith.constant 16381 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c16384_c], + strides = [%c16384_se, %c16384_se, %c16384_se, %c16384_c, %c1] + : !pto.tensor_view<1x1x1x1x16384xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c16381_vc] + : !pto.tensor_view<1x1x1x1x16384xf16> -> !pto.partition_tensor_view<1x1x1x1x16381xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x16381xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_16x1_1x32768_1x32761(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_r = arith.constant 1 : index + %c32768_c = arith.constant 32768 : index + %c32768_se = arith.constant 32768 : index + %c1_de = arith.constant 1 : index + %c1_dc = arith.constant 1 : index + %c1_vr = arith.constant 1 : index + %c32761_vc = arith.constant 32761 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c32768_c], + strides = [%c32768_se, %c32768_se, %c32768_se, %c32768_c, %c1] + : !pto.tensor_view<1x1x1x1x32768xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1_r, %c1_dc], + strides = [%c1_de, %c1_de, %c1_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c32761_vc] + : !pto.tensor_view<1x1x1x1x32768xf16> -> !pto.partition_tensor_view<1x1x1x1x32761xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1_vr, %c1] + : !pto.tensor_view<1x1x1x1x1xui32> -> !pto.partition_tensor_view<1x1x1x1x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x32761xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xui32>) + return + } + + func.func @TROWARGMIN_int32_float_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf32> -> !pto.partition_tensor_view<1x1x1x13x13xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMIN_int32_half_16x1_13x16_13x13(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c13_r = arith.constant 13 : index + %c16_c = arith.constant 16 : index + %c208_se = arith.constant 208 : index + %c13_de = arith.constant 13 : index + %c1_dc = arith.constant 1 : index + %c13_vr = arith.constant 13 : index + %c13_vc = arith.constant 13 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c16_c], + strides = [%c208_se, %c208_se, %c208_se, %c16_c, %c1] + : !pto.tensor_view<1x1x1x13x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c13_r, %c1_dc], + strides = [%c13_de, %c13_de, %c13_de, %c1_dc, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c13_vc] + : !pto.tensor_view<1x1x1x13x16xf16> -> !pto.partition_tensor_view<1x1x1x13x13xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c13_vr, %c1] + : !pto.tensor_view<1x1x1x13x1xi32> -> !pto.partition_tensor_view<1x1x1x13x1xi32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x13x13xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x13x1xi32>) + return + } + + func.func @TROWARGMIN_uint32_float_3x8_3x3480_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3480_c = arith.constant 3480 : index + %c10440_se = arith.constant 10440 : index + %c24_de = arith.constant 24 : index + %c8_dc = arith.constant 8 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3480_c], + strides = [%c10440_se, %c10440_se, %c10440_se, %c3480_c, %c1] + : !pto.tensor_view<1x1x1x3x3480xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c8_dc], + strides = [%c24_de, %c24_de, %c24_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3480xf32> -> !pto.partition_tensor_view<1x1x1x3x3473xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x8xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_260x8_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c2080_de = arith.constant 2080 : index + %c8_dc = arith.constant 8 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c8_dc], + strides = [%c2080_de, %c2080_de, %c2080_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf32> -> !pto.partition_tensor_view<1x1x1x260x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x8xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_float_1023x8_1023x24_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c24_c = arith.constant 24 : index + %c24552_se = arith.constant 24552 : index + %c8184_de = arith.constant 8184 : index + %c8_dc = arith.constant 8 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c24_c], + strides = [%c24552_se, %c24552_se, %c24552_se, %c24_c, %c1] + : !pto.tensor_view<1x1x1x1023x24xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c8_dc], + strides = [%c8184_de, %c8184_de, %c8184_de, %c8_dc, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x24xf32> -> !pto.partition_tensor_view<1x1x1x1023x17xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x8xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf32>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_3x16_3x3488_3x3473(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3_r = arith.constant 3 : index + %c3488_c = arith.constant 3488 : index + %c10464_se = arith.constant 10464 : index + %c48_de = arith.constant 48 : index + %c16_dc = arith.constant 16 : index + %c3_vr = arith.constant 3 : index + %c3473_vc = arith.constant 3473 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c3488_c], + strides = [%c10464_se, %c10464_se, %c10464_se, %c3488_c, %c1] + : !pto.tensor_view<1x1x1x3x3488xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c3_r, %c16_dc], + strides = [%c48_de, %c48_de, %c48_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c3473_vc] + : !pto.tensor_view<1x1x1x3x3488xf16> -> !pto.partition_tensor_view<1x1x1x3x3473xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c3_vr, %c1] + : !pto.tensor_view<1x1x1x3x16xui32> -> !pto.partition_tensor_view<1x1x1x3x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x3x3473xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x3x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_260x16_260x64_260x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c260_r = arith.constant 260 : index + %c64_c = arith.constant 64 : index + %c16640_se = arith.constant 16640 : index + %c4160_de = arith.constant 4160 : index + %c16_dc = arith.constant 16 : index + %c260_vr = arith.constant 260 : index + %c64_vc = arith.constant 64 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c64_c], + strides = [%c16640_se, %c16640_se, %c16640_se, %c64_c, %c1] + : !pto.tensor_view<1x1x1x260x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c260_r, %c16_dc], + strides = [%c4160_de, %c4160_de, %c4160_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c64_vc] + : !pto.tensor_view<1x1x1x260x64xf16> -> !pto.partition_tensor_view<1x1x1x260x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c260_vr, %c1] + : !pto.tensor_view<1x1x1x260x16xui32> -> !pto.partition_tensor_view<1x1x1x260x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x260x64xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x260x1xui32>) + return + } + + func.func @TROWARGMIN_uint32_half_1023x16_1023x32_1023x17(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1023_r = arith.constant 1023 : index + %c32_c = arith.constant 32 : index + %c32736_se = arith.constant 32736 : index + %c16368_de = arith.constant 16368 : index + %c16_dc = arith.constant 16 : index + %c1023_vr = arith.constant 1023 : index + %c17_vc = arith.constant 17 : index + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c32_c], + strides = [%c32736_se, %c32736_se, %c32736_se, %c32_c, %c1] + : !pto.tensor_view<1x1x1x1023x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1023_r, %c16_dc], + strides = [%c16368_de, %c16368_de, %c16368_de, %c16_dc, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c17_vc] + : !pto.tensor_view<1x1x1x1023x32xf16> -> !pto.partition_tensor_view<1x1x1x1023x17xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1023_vr, %c1] + : !pto.tensor_view<1x1x1x1023x16xui32> -> !pto.partition_tensor_view<1x1x1x1023x1xui32> + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1023x17xf16>) + outs(%src : !pto.tile_buf) + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1023x1xui32>) + return + } + +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/CMakeLists.txt new file mode 100644 index 000000000..254ce36e5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpand) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/cases.py new file mode 100644 index 000000000..44effd8a7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/cases.py @@ -0,0 +1,90 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpand ST test cases. + +trowexpand is a row broadcast operation: expands a scalar per row to the entire row. +- Input shape: (rows, srcCols) - physical layout for NPU alignment +- srcCols = 32/sizeof(dtype) for 32-byte alignment +- Output shape: (rows, dstCols) - broadcast each scalar across the row +- dstValidCols may be less than dstCols for partial valid region + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32, np.float16, np.int8). + - src0_shape: (rows, srcCols) — physical input tile dimensions. + - src0_valid_shape: (valid_rows, 1) — effective input region. + - dst_shape: (rows, dstCols) — output tile dimensions. + - dst_valid_shape: (valid_rows, valid_cols) — effective output region. + - eps: tolerance for numpy.allclose (atol and rtol). +""" + +import numpy as np + +CASES = [ + # f32 cases (srcCols=8 for 32-byte alignment) + { + "name": "f32_16x128", + "dtype": np.float32, + "src0_shape": (16, 8), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-6, + }, + { + "name": "f32_16x127", + "dtype": np.float32, + "src0_shape": (16, 8), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 127), # partial valid region + "eps": 1e-6, + }, + # f16 cases (srcCols=16 for 32-byte alignment) + { + "name": "f16_16x512", + "dtype": np.float16, + "src0_shape": (16, 16), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 512), + "dst_valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "f16_16x511", + "dtype": np.float16, + "src0_shape": (16, 16), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 512), + "dst_valid_shape": (16, 511), # partial valid region + "eps": 1e-3, + }, + # i8 cases (srcCols=32 for 32-byte alignment) + { + "name": "i8_16x256", + "dtype": np.int8, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 256), + "eps": 0, # exact match for integers + }, + { + "name": "i8_16x255", + "dtype": np.int8, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 255), # partial valid region + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/compare.py new file mode 100644 index 000000000..bf00d9b08 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpand ST test cases. + +trowexpand: row broadcast operation. +Compare output (rows, cols) against golden (rows, cols). +""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpand uses src0/dst only) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/gen_data.py new file mode 100644 index 000000000..8eec79d9e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/gen_data.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpand ST test cases. + +trowexpand: row broadcast operation. +- Input: (rows, 1) - one scalar per row +- Output: (rows, cols) - broadcast each scalar across the entire row +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpand uses src0/dst only) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] # Physical shape (rows, 8) + src0_valid_shape = case["src0_valid_shape"] # Valid shape (rows, 1) + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + # Generate input: random values for each row's scalar, padded to 8 columns + # Physical layout: (rows, 8), but only column 0 is valid data + input_data = np.zeros(src0_shape, dtype=dtype) + src_vr = src0_valid_shape[0] + input_data[:src_vr, 0] = np.random.randint(1, 10, size=src_vr).astype(dtype) + + # Generate golden: broadcast each row's scalar across columns + # dst[i, :] = src[i, 0] for all columns + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + golden[:dst_vr, :dst_vc] = np.broadcast_to(input_data[:src_vr, 0:1], (dst_vr, dst_vc)).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input_data, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0_shape={src0_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/launch.cpp new file mode 100644 index 000000000..ab55c0382 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPAND_f32_16x128(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPAND_f32_16x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWEXPAND_f32_16x128(float *src, float *dst, void *stream) { + TROWEXPAND_f32_16x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWEXPAND_f32_16x127(float *src, float *dst, void *stream) { + TROWEXPAND_f32_16x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPAND_f16_16x512(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPAND_f16_16x511(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTROWEXPAND_f16_16x512(void *src, void *dst, void *stream) { + TROWEXPAND_f16_16x512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPAND_f16_16x511(void *src, void *dst, void *stream) { + TROWEXPAND_f16_16x511<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +// i8 kernels +extern "C" __global__ AICORE void TROWEXPAND_i8_16x256(__gm__ int8_t *src, __gm__ int8_t *dst); +extern "C" __global__ AICORE void TROWEXPAND_i8_16x255(__gm__ int8_t *src, __gm__ int8_t *dst); + +void LaunchTROWEXPAND_i8_16x256(void *src, void *dst, void *stream) { + TROWEXPAND_i8_16x256<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} +void LaunchTROWEXPAND_i8_16x255(void *src, void *dst, void *stream) { + TROWEXPAND_i8_16x255<<<1, nullptr, stream>>>((__gm__ int8_t *)src, (__gm__ int8_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/main.cpp new file mode 100644 index 000000000..60413d4e4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpand ST — row broadcast operation. +// Supports multiple data types: f32, f16, i8 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +// f32 +void LaunchTROWEXPAND_f32_16x128(float *src, float *dst, void *stream); +void LaunchTROWEXPAND_f32_16x127(float *src, float *dst, void *stream); +// f16 +void LaunchTROWEXPAND_f16_16x512(void *src, void *dst, void *stream); +void LaunchTROWEXPAND_f16_16x511(void *src, void *dst, void *stream); +// i8 +void LaunchTROWEXPAND_i8_16x256(void *src, void *dst, void *stream); +void LaunchTROWEXPAND_i8_16x255(void *src, void *dst, void *stream); + +// Generic launch function type +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; + size_t srcCols; // srcCols = 32/sizeof(dtype) for alignment + size_t dstRows; + size_t dstCols; + size_t dstValidCols; // effective output columns + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32: srcCols=8 (32/4), dstCols=128, dstValidCols=128 or 127 + {"f32_16x128", (LaunchFn)LaunchTROWEXPAND_f32_16x128, 16, 8, 16, 128, 128, sizeof(float)}, + {"f32_16x127", (LaunchFn)LaunchTROWEXPAND_f32_16x127, 16, 8, 16, 128, 127, sizeof(float)}, + // f16: srcCols=16 (32/2), dstCols=512, dstValidCols=512 or 511 + {"f16_16x512", LaunchTROWEXPAND_f16_16x512, 16, 16, 16, 512, 512, sizeof(uint16_t)}, + {"f16_16x511", LaunchTROWEXPAND_f16_16x511, 16, 16, 16, 512, 511, sizeof(uint16_t)}, + // i8: srcCols=32 (32/1), dstCols=256, dstValidCols=256 or 255 + {"i8_16x256", LaunchTROWEXPAND_i8_16x256, 16, 32, 16, 256, 256, sizeof(int8_t)}, + {"i8_16x255", LaunchTROWEXPAND_i8_16x255, 16, 32, 16, 256, 255, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, dst=%zux%zu, valid_cols=%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.dstRows, tc.dstCols, tc.dstValidCols); + + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), srcFileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/trowexpand.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/trowexpand.pto new file mode 100644 index 000000000..993a11e02 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpand/trowexpand.pto @@ -0,0 +1,276 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpand: row broadcast operation. +// dst[row, col] = src[row, 0] (broadcast scalar per row) +// srcCols = 32/sizeof(dtype) for NPU 32-byte alignment + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_16x128: rows=16, srcCols=8, dstValidCols=128, dstCols=128 + func.func @TROWEXPAND_f32_16x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // f32_16x127: rows=16, srcCols=8, dstValidCols=127, dstCols=128 + func.func @TROWEXPAND_f32_16x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c127] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x127xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x127xf32>) + return + } + + // f16_16x512: rows=16, srcCols=16, dstValidCols=512, dstCols=512 + func.func @TROWEXPAND_f16_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf16> -> !pto.partition_tensor_view<1x1x1x16x512xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x512xf16>) + return + } + + // f16_16x511: rows=16, srcCols=16, dstValidCols=511, dstCols=512 + func.func @TROWEXPAND_f16_16x511(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c511 = arith.constant 511 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c511] + : !pto.tensor_view<1x1x1x16x512xf16> -> !pto.partition_tensor_view<1x1x1x16x511xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x511xf16>) + return + } + + // i8_16x256: rows=16, srcCols=32, dstValidCols=256, dstCols=256 + func.func @TROWEXPAND_i8_16x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x16x256xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x256xi8>) + return + } + + // i8_16x255: rows=16, srcCols=32, dstValidCols=255, dstCols=256 + func.func @TROWEXPAND_i8_16x255(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c255 = arith.constant 255 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xi8> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi8> -> !pto.partition_tensor_view<1x1x1x16x32xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c255] + : !pto.tensor_view<1x1x1x16x256xi8> -> !pto.partition_tensor_view<1x1x1x16x255xi8> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x32xi8>) + outs(%src : !pto.tile_buf) + + pto.trowexpand ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x255xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/CMakeLists.txt new file mode 100644 index 000000000..47f7afb3f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandadd) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/cases.py new file mode 100644 index 000000000..cf5e55bb0 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/cases.py @@ -0,0 +1,116 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandadd ST test cases. + +trowexpandadd: dst = src0 + broadcast(src1) across columns. +- src1Col determines how src1 is broadcast: + - src1Col=1: only first column is valid, broadcast to dstCols + - src1Col=8 (for f32): 8 columns are valid, no broadcast needed +- src1Cols (physical) = 32/sizeof(dtype) for NPU alignment + +Template parameters: + - dstRow, dstCol: dst shape + - src1Row, src1Col: src1 shape (src1Col is valid columns, not physical) + - src0eqdst: true means src0 shape equals dst, false means different +""" + +import numpy as np + +CASES = [ + # launchTRowExpandAdd + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), # src0eqdst=true, same as dst + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1, only first column valid + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandAdd + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), # src0eqdst=true + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), # physical: 8 + "src1_valid_shape": (56, 1), # src1Col=1 + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandAdd + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), # src0eqdst=true + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), # src1Col=1 + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandAdd + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), # src0eqdst=true + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), # physical: 16 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + # launchTRowExpandAdd2 - needs investigation + # launchTRowExpandAdd2 - needs investigation + # launchTRowExpandAdd + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false, but src0 shape still matches dst + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), # physical: 16 + "src1_valid_shape": (32, 1), # src1Col=1 + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandAdd + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), # src0eqdst=true + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandAdd + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), # src0eqdst=true + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/compare.py new file mode 100644 index 000000000..c88279ea7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandadd ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandadd uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/gen_data.py new file mode 100644 index 000000000..b13261332 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/gen_data.py @@ -0,0 +1,69 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandadd ST test cases. + +trowexpandadd: dst = src0 + broadcast(src1) across columns. +- src1Col=1: only first column valid, broadcast to all dst columns +- src1Col>1: each src1 column maps to a block of dst columns (dstCol/src1Col) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandadd uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + # Generate inputs + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) # src0 matrix + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) # src1 row vectors + + # Generate golden: dst = src0 + broadcast(src1) + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr, src1_vc = src1_valid_shape + + if src1_vc == 1: + # src1Col=1: broadcast first column to all dst columns + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] + input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + else: + # src1Col>1: each src1 column maps to dstCol/src1_vc columns + # dst[:, block*repeat:(block+1)*repeat] = src0 + src1[:, block:block+1] + repeat = dst_vc // src1_vc + for block in range(src1_vc): + start_col = block * repeat + end_col = min((block + 1) * repeat, dst_vc) + golden[:dst_vr, start_col:end_col] = ( + input1[:src0_vr, start_col:end_col] + input2[:src1_vr, block:block+1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/launch.cpp new file mode 100644 index 000000000..5bde96197 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDADD_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDADD_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDADD_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDADD_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDADD_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDADD_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDADD_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDADD_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDADD_f16_32x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDADD_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDADD_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDADD_f16_32x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_f16_32x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDADD_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDADD_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDADD_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDADD_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDADD_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/main.cpp new file mode 100644 index 000000000..7e6b1dce6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/main.cpp @@ -0,0 +1,166 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandadd ST — row-wise broadcast addition. +// Supports multiple data types: f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +// f32 +void LaunchTROWEXPANDADD_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDADD_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 (use void* for aclFloat16) +void LaunchTROWEXPANDADD_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDADD_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDADD_f16_32x64(void *src0, void *src1, void *dst, void *stream); +// i32 +void LaunchTROWEXPANDADD_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 +void LaunchTROWEXPANDADD_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + +// Generic launch function type +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows; + size_t src0Cols; + size_t src1Rows; + size_t src1Cols; // physical src1 cols = 32/sizeof(dtype) + size_t dstRows; + size_t dstCols; + size_t dstValidCols; // effective dst cols + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDADD_f32_16x32, 16, 32, 16, 8, 16, 32, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDADD_f32_56x128, 56, 128, 56, 8, 56, 128, 128, sizeof(float)}, + // Note: f32_24x64_v2 and f32_20x64_v2_noeq have different semantics - TBD + // f16 cases + {"f16_48x64", LaunchTROWEXPANDADD_f16_48x64, 48, 64, 48, 16, 48, 64, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDADD_f16_16x128, 16, 128, 16, 16, 16, 128, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDADD_f16_32x64, 32, 64, 32, 16, 32, 64, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDADD_i32_16x32, 16, 32, 16, 8, 16, 32, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDADD_i16_16x64, 16, 64, 16, 16, 16, 64, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/trowexpandadd.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/trowexpandadd.pto new file mode 100644 index 000000000..5208aea51 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandadd/trowexpandadd.pto @@ -0,0 +1,413 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandadd: row-wise broadcast addition. +// dst = src0 + broadcast(src1) where src1 is expanded across columns. +// src1 physical cols = 32/sizeof(dtype) for NPU alignment +// src1 v_col = src1Col from template (1 or 8) + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: dstRow=56, dstCol=128, src1Row=56, src1Col=1 + func.func @TROWEXPANDADD_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c56, %c8], + strides = [%c448, %c448, %c448, %c8, %c1] + : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c8] + : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // Note: launchTRowExpandAdd2 with src1Col=8 has different semantics - TBD + + // f16_48x64: dstRow=48, dstCol=64, src1Row=48, src1Col=1 + func.func @TROWEXPANDADD_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: dstRow=16, dstCol=128, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: dstRow=32, dstCol=64, src1Row=32, src1Col=1 (src0eqdst=false) + func.func @TROWEXPANDADD_f16_32x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: dstRow=16, dstCol=64, src1Row=16, src1Col=1 + func.func @TROWEXPANDADD_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/CMakeLists.txt new file mode 100644 index 000000000..2cbab13c9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpanddiv) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/cases.py new file mode 100644 index 000000000..ad20e8203 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/cases.py @@ -0,0 +1,130 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpanddiv ST test cases. + +trowexpanddiv: dst = src0 / broadcast(src1) across columns. +- src1Col determines how src1 is broadcast: + - src1Col=1: only first column is valid, broadcast to dstCols + - src1Col>1: each src1 column maps to a block of dst columns (dstCol/src1Col columns per src1 value) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- highPrecision: use high precision mode for computation +""" + +import numpy as np + +CASES = [ + # launchTRowExpandDiv + { + "name": "f32_40x64", + "dtype": np.float32, + "src0_shape": (40, 64), # src0eqdst=true + "src0_valid_shape": (40, 64), + "src1_shape": (40, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (40, 1), # src1Col=1 + "dst_shape": (40, 64), + "dst_valid_shape": (40, 64), + "eps": 1e-6, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f32_16x256", + "dtype": np.float32, + "src0_shape": (16, 256), + "src0_valid_shape": (16, 256), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 256), + "eps": 1e-6, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f16_16x32", + "dtype": np.float16, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-3, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f16_32x512", + "dtype": np.float16, + "src0_shape": (32, 512), + "src0_valid_shape": (32, 512), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 512), + "dst_valid_shape": (32, 512), + "eps": 1e-3, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f32_16x128_noeq", + "dtype": np.float32, + "src0_shape": (16, 128), # src0eqdst=false + "src0_valid_shape": (16, 128), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-6, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + "high_precision": False, + }, + # launchTRowExpandDiv + { + "name": "f32_40x32_hp", + "dtype": np.float32, + "src0_shape": (40, 32), + "src0_valid_shape": (40, 32), + "src1_shape": (40, 8), + "src1_valid_shape": (40, 1), + "dst_shape": (40, 32), + "dst_valid_shape": (40, 32), + "eps": 1e-6, + "high_precision": True, + }, + # launchTRowExpandDiv + { + "name": "f16_16x128_hp", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + "high_precision": True, + }, + # Note: launchTRowExpandDiv2 with src1Col>1 has different semantics - TBD +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/compare.py new file mode 100644 index 000000000..c6dfc114c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpanddiv ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpanddiv uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/gen_data.py new file mode 100644 index 000000000..a260bf68e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/gen_data.py @@ -0,0 +1,77 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpanddiv ST test cases. + +trowexpanddiv: dst = src0 / broadcast(src1) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpanddiv uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr, src1_vc = src1_valid_shape + + # Compute golden based on src1Col semantics + # src1Col=1: broadcast single column to all dst columns + # src1Col>1: each src1 column broadcasts to dst_vc/src1_vc columns + if dtype in (np.int8, np.int16, np.int32): + if src1_vc == 1: + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] // input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + else: + # src1Col > 1: each src1 column broadcasts to dst_vc/src1_vc dst columns + block_size = dst_vc // src1_vc + for c in range(src1_vc): + golden[:dst_vr, c*block_size:(c+1)*block_size] = ( + input1[:src0_vr, c*block_size:(c+1)*block_size] // input2[:src1_vr, c:c+1] + ).astype(dtype, copy=False) + else: + if src1_vc == 1: + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] / input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + else: + # src1Col > 1: each src1 column broadcasts to dst_vc/src1_vc dst columns + block_size = dst_vc // src1_vc + for c in range(src1_vc): + golden[:dst_vr, c*block_size:(c+1)*block_size] = ( + input1[:src0_vr, c*block_size:(c+1)*block_size] / input2[:src1_vr, c:c+1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/launch.cpp new file mode 100644 index 000000000..e028030ce --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_40x64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_16x256(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_16x128_noeq(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f32_40x32_hp(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDDIV_f32_40x64(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_40x64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDDIV_f32_16x256(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_16x256<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDDIV_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_16x128_noeq<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDDIV_f32_40x32_hp(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDDIV_f32_40x32_hp<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_16x32(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_32x512(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDDIV_f16_16x128_hp(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDDIV_f16_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_16x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDDIV_f16_32x512(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_32x512<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDDIV_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDDIV_f16_16x128_hp(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDDIV_f16_16x128_hp<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/main.cpp new file mode 100644 index 000000000..6d2fcded5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/main.cpp @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpanddiv ST — row-wise broadcast division. +// Supports f32, f16 +// Div variants: src1Col=1 (broadcast single value) or src1Col>1 (block broadcast) + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDDIV_f32_40x64(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDDIV_f32_16x256(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDDIV_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDDIV_f32_40x32_hp(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDDIV_f16_16x32(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDDIV_f16_32x512(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDDIV_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDDIV_f16_16x128_hp(void *src0, void *src1, void *dst, void *stream); + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_40x64", (LaunchFn)LaunchTROWEXPANDDIV_f32_40x64, 40, 64, 40, 8, 40, 64, 40, 64, sizeof(float)}, + {"f32_16x256", (LaunchFn)LaunchTROWEXPANDDIV_f32_16x256, 16, 256, 16, 8, 16, 256, 16, 256, sizeof(float)}, + {"f32_16x128_noeq", (LaunchFn)LaunchTROWEXPANDDIV_f32_16x128_noeq, 16, 128, 16, 8, 16, 128, 16, 128, sizeof(float)}, + {"f32_40x32_hp", (LaunchFn)LaunchTROWEXPANDDIV_f32_40x32_hp, 40, 32, 40, 8, 40, 32, 40, 32, sizeof(float)}, + // f16 cases + {"f16_16x32", LaunchTROWEXPANDDIV_f16_16x32, 16, 32, 16, 16, 16, 32, 16, 32, sizeof(uint16_t)}, + {"f16_32x512", LaunchTROWEXPANDDIV_f16_32x512, 32, 512, 32, 16, 32, 512, 32, 512, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDDIV_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + {"f16_16x128_hp", LaunchTROWEXPANDDIV_f16_16x128_hp, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/trowexpanddiv.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/trowexpanddiv.pto new file mode 100644 index 000000000..940081c2e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpanddiv/trowexpanddiv.pto @@ -0,0 +1,776 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpanddiv: row-wise broadcast division. +// Supports f32, f16 types. +// src1Col=1: broadcast single column value to all dst columns +// src1Col>1: each src1 column broadcasts to dstCol/src1Col columns + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_40x64: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f32_40x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c64 = arith.constant 64 : index + %c320 = arith.constant 320 : index + %c2560 = arith.constant 2560 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c40, %c64], + strides = [%c2560, %c2560, %c2560, %c64, %c1] + : !pto.tensor_view<1x1x1x40x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c40, %c8], + strides = [%c320, %c320, %c320, %c8, %c1] + : !pto.tensor_view<1x1x1x40x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c40, %c64], + strides = [%c2560, %c2560, %c2560, %c64, %c1] + : !pto.tensor_view<1x1x1x40x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c64] + : !pto.tensor_view<1x1x1x40x64xf32> -> !pto.partition_tensor_view<1x1x1x40x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c8] + : !pto.tensor_view<1x1x1x40x8xf32> -> !pto.partition_tensor_view<1x1x1x40x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c64] + : !pto.tensor_view<1x1x1x40x64xf32> -> !pto.partition_tensor_view<1x1x1x40x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x40x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x40x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x40x64xf32>) + return + } + + // f32_16x256: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f32_16x256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x16x256xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x16x256xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x256xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x256xf32>) + return + } + + // f16_16x32: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f16_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + return + } + + // f16_32x512: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f16_32x512(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c512_ = arith.constant 512 : index + %c16384 = arith.constant 16384 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c512], + strides = [%c16384, %c16384, %c16384, %c512, %c1] + : !pto.tensor_view<1x1x1x32x512xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512_, %c512_, %c512_, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c512], + strides = [%c16384, %c16384, %c16384, %c512, %c1] + : !pto.tensor_view<1x1x1x32x512xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c512] + : !pto.tensor_view<1x1x1x32x512xf16> -> !pto.partition_tensor_view<1x1x1x32x512xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c512] + : !pto.tensor_view<1x1x1x32x512xf16> -> !pto.partition_tensor_view<1x1x1x32x512xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x512xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x512xf16>) + return + } + + // f32_16x128_noeq: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f32_16x128_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c128_ = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128_, %c128_, %c128_, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // f16_32x64_noeq: launchTRowExpandDiv + func.func @TROWEXPANDDIV_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // f32_40x32_hp: launchTRowExpandDiv (highPrecision) + func.func @TROWEXPANDDIV_f32_40x32_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c40 = arith.constant 40 : index + %c320 = arith.constant 320 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c40, %c8], + strides = [%c320, %c320, %c320, %c8, %c1] + : !pto.tensor_view<1x1x1x40x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c40, %c32], + strides = [%c1280, %c1280, %c1280, %c32, %c1] + : !pto.tensor_view<1x1x1x40x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c8] + : !pto.tensor_view<1x1x1x40x8xf32> -> !pto.partition_tensor_view<1x1x1x40x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c40, %c32] + : !pto.tensor_view<1x1x1x40x32xf32> -> !pto.partition_tensor_view<1x1x1x40x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x40x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x40x32xf32>) + return + } + + // f16_16x128_hp: launchTRowExpandDiv (highPrecision) + func.func @TROWEXPANDDIV_f16_16x128_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f32_24x64_v2: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f32_24x64_v2(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c64 = arith.constant 64 : index + %c192 = arith.constant 192 : index + %c1536 = arith.constant 1536 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c24, %c64], + strides = [%c1536, %c1536, %c1536, %c64, %c1] + : !pto.tensor_view<1x1x1x24x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c24, %c8], + strides = [%c192, %c192, %c192, %c8, %c1] + : !pto.tensor_view<1x1x1x24x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c24, %c64], + strides = [%c1536, %c1536, %c1536, %c64, %c1] + : !pto.tensor_view<1x1x1x24x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c24, %c64] + : !pto.tensor_view<1x1x1x24x64xf32> -> !pto.partition_tensor_view<1x1x1x24x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c24, %c8] + : !pto.tensor_view<1x1x1x24x8xf32> -> !pto.partition_tensor_view<1x1x1x24x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c24, %c64] + : !pto.tensor_view<1x1x1x24x64xf32> -> !pto.partition_tensor_view<1x1x1x24x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x24x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x24x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 8 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x24x64xf32>) + return + } + + // f16_32x32_v2: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f16_32x32_v2(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 16 : i64, src0eqdst = true, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // f32_20x64_v2_noeq: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f32_20x64_v2_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c20 = arith.constant 20 : index + %c64 = arith.constant 64 : index + %c160 = arith.constant 160 : index + %c1280 = arith.constant 1280 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c20, %c8], + strides = [%c160, %c160, %c160, %c8, %c1] + : !pto.tensor_view<1x1x1x20x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c20, %c64], + strides = [%c1280, %c1280, %c1280, %c64, %c1] + : !pto.tensor_view<1x1x1x20x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c8] + : !pto.tensor_view<1x1x1x20x8xf32> -> !pto.partition_tensor_view<1x1x1x20x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c20, %c64] + : !pto.tensor_view<1x1x1x20x64xf32> -> !pto.partition_tensor_view<1x1x1x20x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x20x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 8 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x20x64xf32>) + return + } + + // f16_16x64_v2_noeq: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f16_16x64_v2_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 16 : i64, src0eqdst = false, highPrecision = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // f32_8x32_v2_hp: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f32_8x32_v2_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c8, %c8], + strides = [%c256, %c256, %c256, %c8, %c1] + : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c32], + strides = [%c256, %c256, %c256, %c32, %c1] + : !pto.tensor_view<1x1x1x8x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c8] + : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c32] + : !pto.tensor_view<1x1x1x8x32xf32> -> !pto.partition_tensor_view<1x1x1x8x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 8 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x32xf32>) + return + } + + // f16_8x128_v2_hp: launchTRowExpandDiv2 + func.func @TROWEXPANDDIV_f16_8x128_v2_hp(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c8, %c128], + strides = [%c1024, %c1024, %c1024, %c128, %c1] + : !pto.tensor_view<1x1x1x8x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c8, %c16], + strides = [%c128, %c128, %c128, %c16, %c1] + : !pto.tensor_view<1x1x1x8x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c128], + strides = [%c1024, %c1024, %c1024, %c128, %c1] + : !pto.tensor_view<1x1x1x8x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c128] + : !pto.tensor_view<1x1x1x8x128xf16> -> !pto.partition_tensor_view<1x1x1x8x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c16] + : !pto.tensor_view<1x1x1x8x16xf16> -> !pto.partition_tensor_view<1x1x1x8x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c128] + : !pto.tensor_view<1x1x1x8x128xf16> -> !pto.partition_tensor_view<1x1x1x8x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x8x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpanddiv ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 16 : i64, src0eqdst = true, highPrecision = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x128xf16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/CMakeLists.txt new file mode 100644 index 000000000..fd0640efc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandexpdif) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/cases.py new file mode 100644 index 000000000..111391868 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/cases.py @@ -0,0 +1,87 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandexpdif ST test cases. + +trowexpandexpdif: dst = exp(src0 - broadcast(src1)) +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandExpdif2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandExpdif + { + "name": "f32_32x64", + "dtype": np.float32, + "src0_shape": (32, 64), + "src0_valid_shape": (32, 64), + "src1_shape": (32, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (32, 1), # src1Col=1 + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-5, + }, + # launchTRowExpandExpdif + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-5, + }, + # launchTRowExpandExpdif + { + "name": "f16_16x32", + "dtype": np.float16, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-3, + }, + # launchTRowExpandExpdif + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandExpdif + { + "name": "f32_16x128_noeq", + "dtype": np.float32, + "src0_shape": (16, 128), # src0eqdst=false + "src0_valid_shape": (16, 128), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-5, + }, + # Note: launchTRowExpandExpdif2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - aclFloat16, 16, 64, 16, 16, false (src1Col=16) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/compare.py new file mode 100644 index 000000000..98aff7854 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandexpdif ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandexpdif uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/gen_data.py new file mode 100644 index 000000000..8b6e09814 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/gen_data.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandexpdif ST test cases. + +trowexpandexpdif: dst = exp(src0 - broadcast(src1)) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandexpdif uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + # Use small values to avoid overflow in exp + input1 = np.random.randint(1, 5, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 5, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + # exp(src0 - src1_scalar) + diff = input1[:src0_vr, :src0_vc] - input2[:src1_vr, 0:1] + golden[:dst_vr, :dst_vc] = np.exp(diff).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/launch.cpp new file mode 100644 index 000000000..6b65978f7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f32_32x64(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f32_16x128_noeq(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDEXPDIF_f32_32x64(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDEXPDIF_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDEXPDIF_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDEXPDIF_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDEXPDIF_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDEXPDIF_f32_16x128_noeq<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f16_16x32(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDEXPDIF_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDEXPDIF_f16_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDEXPDIF_f16_16x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDEXPDIF_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDEXPDIF_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Note: launchTRowExpandExpdif2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/main.cpp new file mode 100644 index 000000000..92c7ba517 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/main.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandexpdif ST — row-wise broadcast exponential difference. +// Supports f32, f16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDEXPDIF_f32_32x64(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDEXPDIF_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDEXPDIF_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDEXPDIF_f16_16x32(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDEXPDIF_f16_48x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandExpdif2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_32x64", (LaunchFn)LaunchTROWEXPANDEXPDIF_f32_32x64, 32, 64, 32, 8, 32, 64, 32, 64, sizeof(float)}, + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDEXPDIF_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_16x128_noeq", (LaunchFn)LaunchTROWEXPANDEXPDIF_f32_16x128_noeq, 16, 128, 16, 8, 16, 128, 16, 128, sizeof(float)}, + // f16 cases + {"f16_16x32", LaunchTROWEXPANDEXPDIF_f16_16x32, 16, 32, 16, 16, 16, 32, 16, 32, sizeof(uint16_t)}, + {"f16_48x64", LaunchTROWEXPANDEXPDIF_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/trowexpandexpdif.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/trowexpandexpdif.pto new file mode 100644 index 000000000..6ef72c7d5 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandexpdif/trowexpandexpdif.pto @@ -0,0 +1,287 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandexpdif: row-wise broadcast exponential difference. +// dst = exp(src0 - broadcast(src1)) +// Supports f32, f16 + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_32x64: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f32_32x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c8], + strides = [%c256, %c256, %c256, %c8, %c1] + : !pto.tensor_view<1x1x1x32x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c8] + : !pto.tensor_view<1x1x1x32x8xf32> -> !pto.partition_tensor_view<1x1x1x32x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // f32_16x32: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f16_16x32: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f16_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf16> -> !pto.partition_tensor_view<1x1x1x16x32xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf16>) + return + } + + // f16_48x64: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f32_16x128_noeq: launchTRowExpandExpdif + func.func @TROWEXPANDEXPDIF_f32_16x128_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c128_ = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128_, %c128_, %c128_, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandexpdif ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/CMakeLists.txt new file mode 100644 index 000000000..7f6c82ffe --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandmax) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/cases.py new file mode 100644 index 000000000..be9646b0f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandmax ST test cases. + +trowexpandmax: row-wise broadcast maximum. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandMax2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandMax + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandMax + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), + "src1_valid_shape": (56, 1), + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandMax + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandMax + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # launchTRowExpandMax + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandMax + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandMax + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandMax2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - float, 20, 64, 20, 8, false (src1Col=8) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/compare.py new file mode 100644 index 000000000..d02ad1760 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandmax ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandmax uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/gen_data.py new file mode 100644 index 000000000..ba1a8a2c4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/gen_data.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandmax ST test cases. + +trowexpandmax: dst = max(src0, broadcast(src1)) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandmax uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + golden[:dst_vr, :dst_vc] = np.maximum( + input1[:src0_vr, :src0_vc], np.broadcast_to(input2[:src1_vr, 0:1], (dst_vr, dst_vc)) + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/launch.cpp new file mode 100644 index 000000000..8a16fd19a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDMAX_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDMAX_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDMAX_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMAX_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDMAX_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMAX_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDMAX_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMAX_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMAX_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDMAX_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMAX_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMAX_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDMAX_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDMAX_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDMAX_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDMAX_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMAX_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandMax2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/main.cpp new file mode 100644 index 000000000..cf2b51036 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandmax ST — row-wise broadcast maximum. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDMAX_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDMAX_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDMAX_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMAX_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMAX_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDMAX_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDMAX_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandMax2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDMAX_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDMAX_f32_56x128, 56, 128, 56, 8, 56, 128, 56, 128, sizeof(float)}, + // f16 cases + {"f16_48x64", LaunchTROWEXPANDMAX_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDMAX_f16_16x128, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDMAX_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDMAX_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDMAX_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/trowexpandmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/trowexpandmax.pto new file mode 100644 index 000000000..44b52decc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmax/trowexpandmax.pto @@ -0,0 +1,395 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandmax: row-wise broadcast maximum. +// Supports f32, f16, i32, i16 + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_16x32: launchTRowExpandMax + func.func @TROWEXPANDMAX_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: launchTRowExpandMax + func.func @TROWEXPANDMAX_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c56, %c8], + strides = [%c448, %c448, %c448, %c8, %c1] + : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c8] + : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // f16_48x64: launchTRowExpandMax + func.func @TROWEXPANDMAX_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: launchTRowExpandMax + func.func @TROWEXPANDMAX_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: launchTRowExpandMax + func.func @TROWEXPANDMAX_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: launchTRowExpandMax + func.func @TROWEXPANDMAX_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: launchTRowExpandMax + func.func @TROWEXPANDMAX_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmax ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/CMakeLists.txt new file mode 100644 index 000000000..2d154b940 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandmin) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/cases.py new file mode 100644 index 000000000..97443ca5b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandmin ST test cases. + +trowexpandmin: row-wise broadcast minimum. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandMin2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandMin + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandMin + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), + "src1_valid_shape": (56, 1), + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandMin + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandMin + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # launchTRowExpandMin + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandMin + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandMin + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandMin2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - float, 20, 64, 20, 8, false (src1Col=8) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/compare.py new file mode 100644 index 000000000..7f637101d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandmin ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandmin uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/gen_data.py new file mode 100644 index 000000000..1c88e0eef --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/gen_data.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandmin ST test cases. + +trowexpandmin: dst = min(src0, broadcast(src1)) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandmin uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + golden[:dst_vr, :dst_vc] = np.minimum( + input1[:src0_vr, :src0_vc], np.broadcast_to(input2[:src1_vr, 0:1], (dst_vr, dst_vc)) + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/launch.cpp new file mode 100644 index 000000000..ba11c02ff --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDMIN_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDMIN_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDMIN_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMIN_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDMIN_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMIN_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDMIN_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMIN_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMIN_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDMIN_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMIN_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMIN_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDMIN_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDMIN_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDMIN_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDMIN_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMIN_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandMin2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/main.cpp new file mode 100644 index 000000000..53e40102a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandmin ST — row-wise broadcast minimum. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDMIN_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDMIN_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDMIN_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMIN_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMIN_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDMIN_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDMIN_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandMin2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDMIN_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDMIN_f32_56x128, 56, 128, 56, 8, 56, 128, 56, 128, sizeof(float)}, + // f16 cases + {"f16_48x64", LaunchTROWEXPANDMIN_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDMIN_f16_16x128, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDMIN_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDMIN_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDMIN_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/trowexpandmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/trowexpandmin.pto new file mode 100644 index 000000000..c9f60ad5d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmin/trowexpandmin.pto @@ -0,0 +1,395 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandmin: row-wise broadcast minimum. +// Supports f32, f16, i32, i16 + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_16x32: launchTRowExpandMin + func.func @TROWEXPANDMIN_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: launchTRowExpandMin + func.func @TROWEXPANDMIN_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c56, %c8], + strides = [%c448, %c448, %c448, %c8, %c1] + : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c56, %c128], + strides = [%c7168, %c7168, %c7168, %c128, %c1] + : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c8] + : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c56, %c128] + : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // f16_48x64: launchTRowExpandMin + func.func @TROWEXPANDMIN_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c48, %c16], + strides = [%c768, %c768, %c768, %c16, %c1] + : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c48, %c64], + strides = [%c3072, %c3072, %c3072, %c64, %c1] + : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c16] + : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c48, %c64] + : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: launchTRowExpandMin + func.func @TROWEXPANDMIN_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: launchTRowExpandMin + func.func @TROWEXPANDMIN_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c32, %c16], + strides = [%c512, %c512, %c512, %c16, %c1] + : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c16] + : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: launchTRowExpandMin + func.func @TROWEXPANDMIN_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c8], + strides = [%c128, %c128, %c128, %c8, %c1] + : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c8] + : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: launchTRowExpandMin + func.func @TROWEXPANDMIN_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c16, %c16], + strides = [%c256, %c256, %c256, %c16, %c1] + : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c16] + : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) + outs(%src1 : !pto.tile_buf) + + pto.trowexpandmin ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true} + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/CMakeLists.txt new file mode 100644 index 000000000..5a71ec723 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandmul) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/cases.py new file mode 100644 index 000000000..883563cb9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandmul ST test cases. + +trowexpandmul: row-wise broadcast multiplication. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandMul2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandMul + { + "name": "f32_16x32", + "dtype": np.float32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (16, 1), # src1Col=1 + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 1e-6, + }, + # launchTRowExpandMul + { + "name": "f32_56x128", + "dtype": np.float32, + "src0_shape": (56, 128), + "src0_valid_shape": (56, 128), + "src1_shape": (56, 8), + "src1_valid_shape": (56, 1), + "dst_shape": (56, 128), + "dst_valid_shape": (56, 128), + "eps": 1e-6, + }, + # launchTRowExpandMul + { + "name": "f16_48x64", + "dtype": np.float16, + "src0_shape": (48, 64), + "src0_valid_shape": (48, 64), + "src1_shape": (48, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (48, 1), + "dst_shape": (48, 64), + "dst_valid_shape": (48, 64), + "eps": 1e-3, + }, + # launchTRowExpandMul + { + "name": "f16_16x128", + "dtype": np.float16, + "src0_shape": (16, 128), + "src0_valid_shape": (16, 128), + "src1_shape": (16, 16), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-3, + }, + # launchTRowExpandMul + { + "name": "f16_32x64_noeq", + "dtype": np.float16, + "src0_shape": (32, 64), # src0eqdst=false + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandMul + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandMul + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandMul2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - float, 20, 64, 20, 8, false (src1Col=8) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/compare.py new file mode 100644 index 000000000..5dea2f271 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandmul ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandmul uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/gen_data.py new file mode 100644 index 000000000..4d0dde3b6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/gen_data.py @@ -0,0 +1,53 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandmul ST test cases. + +trowexpandmul: dst = src0 * broadcast(src1) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandmul uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] * input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/launch.cpp new file mode 100644 index 000000000..3bccc18d7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDMUL_f32_16x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDMUL_f32_56x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDMUL_f32_16x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMUL_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDMUL_f32_56x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDMUL_f32_56x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDMUL_f16_48x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMUL_f16_16x128(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDMUL_f16_32x64_noeq(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDMUL_f16_48x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_f16_48x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMUL_f16_16x128(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_f16_16x128<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDMUL_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_f16_32x64_noeq<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDMUL_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDMUL_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDMUL_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDMUL_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDMUL_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandMul2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/main.cpp new file mode 100644 index 000000000..4c4cd0310 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandmul ST — row-wise broadcast multiplication. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDMUL_f32_16x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDMUL_f32_56x128(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDMUL_f16_48x64(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMUL_f16_16x128(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDMUL_f16_32x64_noeq(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDMUL_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDMUL_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandMul2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_16x32", (LaunchFn)LaunchTROWEXPANDMUL_f32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(float)}, + {"f32_56x128", (LaunchFn)LaunchTROWEXPANDMUL_f32_56x128, 56, 128, 56, 8, 56, 128, 56, 128, sizeof(float)}, + // f16 cases + {"f16_48x64", LaunchTROWEXPANDMUL_f16_48x64, 48, 64, 48, 16, 48, 64, 48, 64, sizeof(uint16_t)}, + {"f16_16x128", LaunchTROWEXPANDMUL_f16_16x128, 16, 128, 16, 16, 16, 128, 16, 128, sizeof(uint16_t)}, + {"f16_32x64_noeq", LaunchTROWEXPANDMUL_f16_32x64_noeq, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDMUL_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDMUL_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/trowexpandmul.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/trowexpandmul.pto new file mode 100644 index 000000000..82e18b922 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandmul/trowexpandmul.pto @@ -0,0 +1,212 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandmul: row-wise broadcast multiplication. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128, %c128, %c128, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + return + } + + // f32_56x128: dstRow=56, dstCol=128, src1Row=56, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f32_56x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c56 = arith.constant 56 : index + %c128 = arith.constant 128 : index + %c448 = arith.constant 448 : index + %c7168 = arith.constant 7168 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c56, %c128], strides = [%c7168, %c7168, %c7168, %c128, %c1] : !pto.tensor_view<1x1x1x56x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c56, %c8], strides = [%c448, %c448, %c448, %c8, %c1] : !pto.tensor_view<1x1x1x56x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c56, %c128], strides = [%c7168, %c7168, %c7168, %c128, %c1] : !pto.tensor_view<1x1x1x56x128xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c56, %c128] : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c56, %c8] : !pto.tensor_view<1x1x1x56x8xf32> -> !pto.partition_tensor_view<1x1x1x56x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c56, %c128] : !pto.tensor_view<1x1x1x56x128xf32> -> !pto.partition_tensor_view<1x1x1x56x128xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x56x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x56x128xf32>) + return + } + + // f16_48x64: dstRow=48, dstCol=64, src1Row=48, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f16_48x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + %c3072 = arith.constant 3072 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c48, %c64], strides = [%c3072, %c3072, %c3072, %c64, %c1] : !pto.tensor_view<1x1x1x48x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c48, %c16], strides = [%c768, %c768, %c768, %c16, %c1] : !pto.tensor_view<1x1x1x48x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c48, %c64], strides = [%c3072, %c3072, %c3072, %c64, %c1] : !pto.tensor_view<1x1x1x48x64xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c48, %c64] : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c48, %c16] : !pto.tensor_view<1x1x1x48x16xf16> -> !pto.partition_tensor_view<1x1x1x48x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c48, %c64] : !pto.tensor_view<1x1x1x48x64xf16> -> !pto.partition_tensor_view<1x1x1x48x64xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x48x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x48x64xf16>) + return + } + + // f16_16x128: dstRow=16, dstCol=128, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_f16_16x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256, %c256, %c256, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf16> -> !pto.partition_tensor_view<1x1x1x16x128xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf16>) + return + } + + // f16_32x64_noeq: dstRow=32, dstCol=64, src1Row=32, src1Col=1, src0eqdst=false + func.func @TROWEXPANDMUL_f16_32x64_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c32, %c16], strides = [%c512, %c512, %c512, %c16, %c1] : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c16] : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // i32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128, %c128, %c128, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: dstRow=16, dstCol=64, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDMUL_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256, %c256, %c256, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandmul ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/CMakeLists.txt new file mode 100644 index 000000000..fe69e4770 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowexpandsub) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/cases.py new file mode 100644 index 000000000..c9ac76001 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/cases.py @@ -0,0 +1,111 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowexpandsub ST test cases. + +trowexpandsub: row-wise broadcast subtraction. +- src1Col=1: only first column of src1 is valid, broadcast to dstCols +- src1Col>1: launchTRowExpandSub2 with different semantics (TBD) +- src1 physical cols = 32/sizeof(dtype) for NPU alignment +- src0eqdst: true means src0 shape equals dst shape +""" + +import numpy as np + +CASES = [ + # launchTRowExpandSub + { + "name": "f32_8x128", + "dtype": np.float32, + "src0_shape": (8, 128), + "src0_valid_shape": (8, 128), + "src1_shape": (8, 8), # physical: 32/sizeof(f32)=8 + "src1_valid_shape": (8, 1), # src1Col=1 + "dst_shape": (8, 128), + "dst_valid_shape": (8, 128), + "eps": 1e-6, + }, + # launchTRowExpandSub + { + "name": "f32_24x32", + "dtype": np.float32, + "src0_shape": (24, 32), + "src0_valid_shape": (24, 32), + "src1_shape": (24, 8), + "src1_valid_shape": (24, 1), + "dst_shape": (24, 32), + "dst_valid_shape": (24, 32), + "eps": 1e-6, + }, + # launchTRowExpandSub + { + "name": "f16_16x256", + "dtype": np.float16, + "src0_shape": (16, 256), + "src0_valid_shape": (16, 256), + "src1_shape": (16, 16), # physical: 32/sizeof(f16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 256), + "dst_valid_shape": (16, 256), + "eps": 1e-3, + }, + # launchTRowExpandSub + { + "name": "f16_32x64", + "dtype": np.float16, + "src0_shape": (32, 64), + "src0_valid_shape": (32, 64), + "src1_shape": (32, 16), + "src1_valid_shape": (32, 1), + "dst_shape": (32, 64), + "dst_valid_shape": (32, 64), + "eps": 1e-3, + }, + # launchTRowExpandSub + { + "name": "f32_16x128_noeq", + "dtype": np.float32, + "src0_shape": (16, 128), # src0eqdst=false + "src0_valid_shape": (16, 128), + "src1_shape": (16, 8), + "src1_valid_shape": (16, 1), + "dst_shape": (16, 128), + "dst_valid_shape": (16, 128), + "eps": 1e-6, + }, + # launchTRowExpandSub + { + "name": "i32_16x32", + "dtype": np.int32, + "src0_shape": (16, 32), + "src0_valid_shape": (16, 32), + "src1_shape": (16, 8), # physical: 32/sizeof(i32)=8 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 32), + "dst_valid_shape": (16, 32), + "eps": 0, + }, + # launchTRowExpandSub + { + "name": "i16_16x64", + "dtype": np.int16, + "src0_shape": (16, 64), + "src0_valid_shape": (16, 64), + "src1_shape": (16, 16), # physical: 32/sizeof(i16)=16 + "src1_valid_shape": (16, 1), + "dst_shape": (16, 64), + "dst_valid_shape": (16, 64), + "eps": 0, + }, + # Note: launchTRowExpandSub2 with src1Col>1 has different semantics - TBD + # - float, 24, 64, 24, 8, true (src1Col=8) + # - aclFloat16, 16, 64, 16, 16, false (src1Col=16) +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/compare.py new file mode 100644 index 000000000..3105c4da6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Compare golden and output for trowexpandsub ST test cases.""" + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass + +# Inline validation for multi-input format (trowexpandsub uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + dtype = case["dtype"] + + vr, vc = dst_valid_shape + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/gen_data.py new file mode 100644 index 000000000..64b40f040 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/gen_data.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Generate input and golden data for trowexpandsub ST test cases. + +trowexpandsub: dst = src0 - broadcast(src1) +""" + +import numpy as np +from cases import CASES +from st_common import setup_case_rng, save_case_data + +# Inline validation for multi-input format (trowexpandsub uses src0/src1/dst) +REQUIRED_KEYS = {"name", "dtype", "src0_shape", "src0_valid_shape", "src1_shape", + "src1_valid_shape", "dst_shape", "dst_valid_shape"} +for i, case in enumerate(CASES): + missing = REQUIRED_KEYS - case.keys() + if missing: + raise ValueError(f"cases[{i}] ({case.get('name', '?')}) missing keys: {missing}") + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src0_shape = case["src0_shape"] + src0_valid_shape = case["src0_valid_shape"] + src1_shape = case["src1_shape"] + src1_valid_shape = case["src1_valid_shape"] + dst_shape = case["dst_shape"] + dst_valid_shape = case["dst_valid_shape"] + + input1 = np.random.randint(1, 10, size=src0_shape).astype(dtype) + input2 = np.random.randint(1, 10, size=src1_shape).astype(dtype) + + golden = np.zeros(dst_shape, dtype=dtype) + dst_vr, dst_vc = dst_valid_shape + src0_vr, src0_vc = src0_valid_shape + src1_vr = src1_valid_shape[0] + + # dst = src0 - src1_scalar (broadcasted) + golden[:dst_vr, :dst_vc] = ( + input1[:src0_vr, :src0_vc] - input2[:src1_vr, 0:1] + ).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src0={src0_shape} src1={src1_shape} dst={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/launch.cpp new file mode 100644 index 000000000..4692348e2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// f32 kernels +extern "C" __global__ AICORE void TROWEXPANDSUB_f32_8x128(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDSUB_f32_24x32(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); +extern "C" __global__ AICORE void TROWEXPANDSUB_f32_16x128_noeq(__gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTROWEXPANDSUB_f32_8x128(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDSUB_f32_8x128<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDSUB_f32_24x32(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDSUB_f32_24x32<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} +void LaunchTROWEXPANDSUB_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream) { + TROWEXPANDSUB_f32_16x128_noeq<<<1, nullptr, stream>>>((__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// f16 kernels (use uint16_t for aclFloat16) +extern "C" __global__ AICORE void TROWEXPANDSUB_f16_16x256(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWEXPANDSUB_f16_32x64(__gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTROWEXPANDSUB_f16_16x256(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_f16_16x256<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} +void LaunchTROWEXPANDSUB_f16_32x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_f16_32x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// i32 kernels +extern "C" __global__ AICORE void TROWEXPANDSUB_i32_16x32(__gm__ int32_t *src0, __gm__ int32_t *src1, __gm__ int32_t *dst); + +void LaunchTROWEXPANDSUB_i32_16x32(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_i32_16x32<<<1, nullptr, stream>>>((__gm__ int32_t *)src0, (__gm__ int32_t *)src1, (__gm__ int32_t *)dst); +} + +// i16 kernels +extern "C" __global__ AICORE void TROWEXPANDSUB_i16_16x64(__gm__ int16_t *src0, __gm__ int16_t *src1, __gm__ int16_t *dst); + +void LaunchTROWEXPANDSUB_i16_16x64(void *src0, void *src1, void *dst, void *stream) { + TROWEXPANDSUB_i16_16x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src0, (__gm__ int16_t *)src1, (__gm__ int16_t *)dst); +} + +// Note: launchTRowExpandSub2 with src1Col>1 has different semantics - TBD \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/main.cpp new file mode 100644 index 000000000..4943e2cc7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/main.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowexpandsub ST — row-wise broadcast subtraction. +// Supports f32, f16, i32, i16 + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// f32 kernels +void LaunchTROWEXPANDSUB_f32_8x128(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDSUB_f32_24x32(float *src0, float *src1, float *dst, void *stream); +void LaunchTROWEXPANDSUB_f32_16x128_noeq(float *src0, float *src1, float *dst, void *stream); +// f16 kernels (use void* for aclFloat16) +void LaunchTROWEXPANDSUB_f16_16x256(void *src0, void *src1, void *dst, void *stream); +void LaunchTROWEXPANDSUB_f16_32x64(void *src0, void *src1, void *dst, void *stream); +// i32 kernels +void LaunchTROWEXPANDSUB_i32_16x32(void *src0, void *src1, void *dst, void *stream); +// i16 kernels +void LaunchTROWEXPANDSUB_i16_16x64(void *src0, void *src1, void *dst, void *stream); + +// Note: launchTRowExpandSub2 with src1Col>1 has different semantics - TBD + +using LaunchFn = void (*)(void *, void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t src0Rows, src0Cols, src1Rows, src1Cols, dstRows, dstCols; + size_t dstValidRows, dstValidCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_8x128", (LaunchFn)LaunchTROWEXPANDSUB_f32_8x128, 8, 128, 8, 8, 8, 128, 8, 128, sizeof(float)}, + {"f32_24x32", (LaunchFn)LaunchTROWEXPANDSUB_f32_24x32, 24, 32, 24, 8, 24, 32, 24, 32, sizeof(float)}, + {"f32_16x128_noeq", (LaunchFn)LaunchTROWEXPANDSUB_f32_16x128_noeq, 16, 128, 16, 8, 16, 128, 16, 128, sizeof(float)}, + // f16 cases + {"f16_16x256", LaunchTROWEXPANDSUB_f16_16x256, 16, 256, 16, 16, 16, 256, 16, 256, sizeof(uint16_t)}, + {"f16_32x64", LaunchTROWEXPANDSUB_f16_32x64, 32, 64, 32, 16, 32, 64, 32, 64, sizeof(uint16_t)}, + // i32 cases + {"i32_16x32", LaunchTROWEXPANDSUB_i32_16x32, 16, 32, 16, 8, 16, 32, 16, 32, sizeof(int32_t)}, + // i16 cases + {"i16_16x64", LaunchTROWEXPANDSUB_i16_16x64, 16, 64, 16, 16, 16, 64, 16, 64, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t src0FileSize = tc.src0Rows * tc.src0Cols * tc.elemSize; + size_t src1FileSize = tc.src1Rows * tc.src1Cols * tc.elemSize; + const size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src0=%zux%zu, src1=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.src0Rows, tc.src0Cols, tc.src1Rows, tc.src1Cols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), src0FileSize); + aclrtMallocHost((void **)(&src1Host), src1FileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + aclrtMalloc((void **)&src0Device, src0FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, src1FileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, src0FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, src1FileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, src0FileSize, src0Host, src0FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, src1FileSize, src1Host, src1FileSize, ACL_MEMCPY_HOST_TO_DEVICE); + tc.launch(src0Device, src1Device, dstDevice, stream); + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device) aclrtFree(src0Device); + if (src1Device) aclrtFree(src1Device); + if (dstDevice) aclrtFree(dstDevice); + if (src0Host) aclrtFreeHost(src0Host); + if (src1Host) aclrtFreeHost(src1Host); + if (dstHost) aclrtFreeHost(dstHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + int rc = 0, deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) deviceId = std::atoi(envDevice); + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter && std::strcmp(kCases[i].name, caseFilter) != 0) continue; + if (RunCase(kCases[i], deviceId, stream) != 0) { rc = 1; break; } + } + + if (stream) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/trowexpandsub.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/trowexpandsub.pto new file mode 100644 index 000000000..a28568fcf --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowexpandsub/trowexpandsub.pto @@ -0,0 +1,211 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowexpandsub: row-wise broadcast subtraction. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // f32_8x128: dstRow=8, dstCol=128, src1Row=8, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f32_8x128(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c8, %c128], strides = [%c1024, %c1024, %c1024, %c128, %c1] : !pto.tensor_view<1x1x1x8x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c8, %c8], strides = [%c64, %c64, %c64, %c8, %c1] : !pto.tensor_view<1x1x1x8x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c8, %c128], strides = [%c1024, %c1024, %c1024, %c128, %c1] : !pto.tensor_view<1x1x1x8x128xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c8, %c128] : !pto.tensor_view<1x1x1x8x128xf32> -> !pto.partition_tensor_view<1x1x1x8x128xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c8, %c8] : !pto.tensor_view<1x1x1x8x8xf32> -> !pto.partition_tensor_view<1x1x1x8x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c8, %c128] : !pto.tensor_view<1x1x1x8x128xf32> -> !pto.partition_tensor_view<1x1x1x8x128xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x8x128xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x8x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x128xf32>) + return + } + + // f32_24x32: dstRow=24, dstCol=32, src1Row=24, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f32_24x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c32 = arith.constant 32 : index + %c192 = arith.constant 192 : index + %c768 = arith.constant 768 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c24, %c32], strides = [%c768, %c768, %c768, %c32, %c1] : !pto.tensor_view<1x1x1x24x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c24, %c8], strides = [%c192, %c192, %c192, %c8, %c1] : !pto.tensor_view<1x1x1x24x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c24, %c32], strides = [%c768, %c768, %c768, %c32, %c1] : !pto.tensor_view<1x1x1x24x32xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c24, %c32] : !pto.tensor_view<1x1x1x24x32xf32> -> !pto.partition_tensor_view<1x1x1x24x32xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c24, %c8] : !pto.tensor_view<1x1x1x24x8xf32> -> !pto.partition_tensor_view<1x1x1x24x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c24, %c32] : !pto.tensor_view<1x1x1x24x32xf32> -> !pto.partition_tensor_view<1x1x1x24x32xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x24x32xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x24x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x24x32xf32>) + return + } + + // f16_16x256: dstRow=16, dstCol=256, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f16_16x256(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c256_2 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c256], strides = [%c4096, %c4096, %c4096, %c256, %c1] : !pto.tensor_view<1x1x1x16x256xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256_2, %c256_2, %c256_2, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c256], strides = [%c4096, %c4096, %c4096, %c256, %c1] : !pto.tensor_view<1x1x1x16x256xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c256] : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xf16> -> !pto.partition_tensor_view<1x1x1x16x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c256] : !pto.tensor_view<1x1x1x16x256xf16> -> !pto.partition_tensor_view<1x1x1x16x256xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x256xf16>) + return + } + + // f16_32x64: dstRow=32, dstCol=64, src1Row=32, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_f16_32x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c32, %c16], strides = [%c512, %c512, %c512, %c16, %c1] : !pto.tensor_view<1x1x1x32x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c64], strides = [%c2048, %c2048, %c2048, %c64, %c1] : !pto.tensor_view<1x1x1x32x64xf16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c16] : !pto.tensor_view<1x1x1x32x16xf16> -> !pto.partition_tensor_view<1x1x1x32x16xf16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c64] : !pto.tensor_view<1x1x1x32x64xf16> -> !pto.partition_tensor_view<1x1x1x32x64xf16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x32x16xf16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf16>) + return + } + + // f32_16x128_noeq: dstRow=16, dstCol=128, src1Row=16, src1Col=1, src0eqdst=false + func.func @TROWEXPANDSUB_f32_16x128_noeq(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c128_2 = arith.constant 128 : index + %c2048 = arith.constant 2048 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128_2, %c128_2, %c128_2, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c128], strides = [%c2048, %c2048, %c2048, %c128, %c1] : !pto.tensor_view<1x1x1x16x128xf32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xf32> -> !pto.partition_tensor_view<1x1x1x16x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c128] : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xf32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = false, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // i32_16x32: dstRow=16, dstCol=32, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_i32_16x32(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c8], strides = [%c128, %c128, %c128, %c8, %c1] : !pto.tensor_view<1x1x1x16x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c32], strides = [%c512, %c512, %c512, %c32, %c1] : !pto.tensor_view<1x1x1x16x32xi32> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c8] : !pto.tensor_view<1x1x1x16x8xi32> -> !pto.partition_tensor_view<1x1x1x16x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c32] : !pto.tensor_view<1x1x1x16x32xi32> -> !pto.partition_tensor_view<1x1x1x16x32xi32> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x8xi32>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x32xi32>) + return + } + + // i16_16x64: dstRow=16, dstCol=64, src1Row=16, src1Col=1, src0eqdst=true + func.func @TROWEXPANDSUB_i16_16x64(%src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src0_view = pto.make_tensor_view %src0_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + %src1_view = pto.make_tensor_view %src1_ptr, shape = [%c1, %c1, %c1, %c16, %c16], strides = [%c256, %c256, %c256, %c16, %c1] : !pto.tensor_view<1x1x1x16x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c16, %c64], strides = [%c1024, %c1024, %c1024, %c64, %c1] : !pto.tensor_view<1x1x1x16x64xi16> + + %src0_part = pto.partition_view %src0_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + %src1_part = pto.partition_view %src1_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c16] : !pto.tensor_view<1x1x1x16x16xi16> -> !pto.partition_tensor_view<1x1x1x16x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c16, %c64] : !pto.tensor_view<1x1x1x16x64xi16> -> !pto.partition_tensor_view<1x1x1x16x64xi16> + + %src0 = pto.alloc_tile : !pto.tile_buf + %src1 = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x16x16xi16>) outs(%src1 : !pto.tile_buf) + pto.trowexpandsub ins(%src0, %src1 : !pto.tile_buf, !pto.tile_buf) outs(%dst : !pto.tile_buf) {src1Col = 1 : i64, src0eqdst = true, highPrecision = false} + pto.tstore ins(%dst : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x64xi16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/CMakeLists.txt new file mode 100644 index 000000000..62291cfb6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowmax) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/cases.py new file mode 100644 index 000000000..f6db377f2 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/cases.py @@ -0,0 +1,224 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowmax ST test cases. + +Aligned with pto-isa tests/npu/a2a3/src/st/testcase/trowmax (28 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case5 from pto-isa) + { + "name": "f32_127x64_valid127x63", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-5, + }, + { + "name": "f32_63x64", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-5, + }, + { + "name": "f32_31x128_valid31x127", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-5, + }, + { + "name": "f32_15x192", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-5, + }, + { + "name": "f32_7x448_valid7x447", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-5, + }, + # f16 case (case6 from pto-isa) + { + "name": "f16_256x16_valid256x15", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 1e-2, + }, + # f32 more cases (case7-case14 from pto-isa) + { + "name": "f32_30x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid30x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 24), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 24), + "eps": 1e-5, + }, + { + "name": "f32_238x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid238x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 16), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 16), + "eps": 1e-5, + }, + # f32 DN dst cases (case15-case18 from pto-isa) + { + "name": "f32_64x128", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-5, + }, + { + "name": "f32_32x256", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-5, + }, + { + "name": "f32_16x512", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-5, + }, + { + "name": "f32_8x1024", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-5, + }, + + # int32 cases (case19-case23 from pto-isa) + { + "name": "i32_127x64_valid127x63", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "i32_63x64", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128_valid31x127", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "i32_15x192", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "i32_7x448_valid7x447", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case24-case28 from pto-isa) + { + "name": "i16_128x64", + "dtype": np.int16, + "shape": (128, 64), + "valid_shape": (128, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_32x128", + "dtype": np.int16, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "i16_16x192", + "dtype": np.int16, + "shape": (16, 192), + "valid_shape": (16, 192), + "eps": 0, + }, + { + "name": "i16_8x448", + "dtype": np.int16, + "shape": (8, 448), + "valid_shape": (8, 448), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/compare.py new file mode 100644 index 000000000..12d4207bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr,) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/gen_data.py new file mode 100644 index 000000000..97495c982 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/gen_data.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input1 = np.random.randint(low=-100, high=100, size=shape).astype(dtype) + else: + input1 = np.random.randint(low=-50, high=50, size=shape).astype(dtype) + else: + input1 = np.random.uniform(low=-16, high=16, size=shape).astype(dtype) + + out_shape = (valid_shape[0],) + golden = np.zeros(out_shape, dtype=dtype) + vr, vc = valid_shape + for i in range(vr): + golden[i] = np.max(input1[i, :vc]) + + golden = golden.astype(dtype, copy=False) + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/launch.cpp new file mode 100644 index 000000000..a5d840da9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/launch.cpp @@ -0,0 +1,183 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWMAX_f32_127x64_valid127x63(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_127x64_valid127x63(float *src, float *dst, void *stream) { + TROWMAX_f32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_63x64(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_63x64(float *src, float *dst, void *stream) { + TROWMAX_f32_63x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_31x128_valid31x127(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_31x128_valid31x127(float *src, float *dst, void *stream) { + TROWMAX_f32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_15x192(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_15x192(float *src, float *dst, void *stream) { + TROWMAX_f32_15x192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_7x448_valid7x447(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_7x448_valid7x447(float *src, float *dst, void *stream) { + TROWMAX_f32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f16_256x16_valid256x15(__gm__ uint16_t *src, __gm__ uint16_t *dst); + +void LaunchTROWMAX_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream) { + TROWMAX_f16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216_valid30x24(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216_valid30x24(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216_valid30x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216_valid11x216(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216_valid11x216(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216_valid11x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_30x216_valid11x24(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_30x216_valid11x24(float *src, float *dst, void *stream) { + TROWMAX_f32_30x216_valid11x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40_valid238x16(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40_valid238x16(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40_valid238x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40_valid121x40(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40_valid121x40(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40_valid121x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_238x40_valid121x16(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_238x40_valid121x16(float *src, float *dst, void *stream) { + TROWMAX_f32_238x40_valid121x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_64x128(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_64x128(float *src, float *dst, void *stream) { + TROWMAX_f32_64x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_32x256(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_32x256(float *src, float *dst, void *stream) { + TROWMAX_f32_32x256<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_16x512(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_16x512(float *src, float *dst, void *stream) { + TROWMAX_f32_16x512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_f32_8x1024(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWMAX_f32_8x1024(float *src, float *dst, void *stream) { + TROWMAX_f32_8x1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// int32 cases +extern "C" __global__ AICORE void TROWMAX_i32_127x64_valid127x63(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_63x64(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_63x64(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_63x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_31x128_valid31x127(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_15x192(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_15x192(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_15x192<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i32_7x448_valid7x447(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWMAX_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream) { + TROWMAX_i32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// int16 cases +extern "C" __global__ AICORE void TROWMAX_i16_128x64(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_128x64(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_128x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_64x64(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_64x64(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_32x128(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_32x128(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_16x192(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_16x192(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_16x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMAX_i16_8x448(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWMAX_i16_8x448(int16_t *src, int16_t *dst, void *stream) { + TROWMAX_i16_8x448<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/main.cpp new file mode 100644 index 000000000..b89132b41 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/main.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowmax ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWMAX_f32_127x64_valid127x63(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_63x64(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_31x128_valid31x127(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_15x192(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_7x448_valid7x447(float *src, float *dst, void *stream); +void LaunchTROWMAX_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWMAX_f32_30x216(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_30x216_valid30x24(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_30x216_valid11x216(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_30x216_valid11x24(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40_valid238x16(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40_valid121x40(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_238x40_valid121x16(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_64x128(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_32x256(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_16x512(float *src, float *dst, void *stream); +void LaunchTROWMAX_f32_8x1024(float *src, float *dst, void *stream); +void LaunchTROWMAX_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_63x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_15x192(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMAX_i16_128x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_64x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_32x128(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_16x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMAX_i16_8x448(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_127x64_valid127x63", DType::F32, .launchF32 = LaunchTROWMAX_f32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"f32_63x64", DType::F32, .launchF32 = LaunchTROWMAX_f32_63x64, 63, 64, 63, 64, 4}, + {"f32_31x128_valid31x127", DType::F32, .launchF32 = LaunchTROWMAX_f32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"f32_15x192", DType::F32, .launchF32 = LaunchTROWMAX_f32_15x192, 15, 192, 15, 192, 4}, + {"f32_7x448_valid7x447", DType::F32, .launchF32 = LaunchTROWMAX_f32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // f16 case + {"f16_256x16_valid256x15", DType::F16, .launchF16 = LaunchTROWMAX_f16_256x16_valid256x15, 256, 16, 256, 15, 2}, + // f32 more cases + {"f32_30x216", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216, 30, 216, 30, 216, 4}, + {"f32_30x216_valid30x24", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216_valid30x24, 30, 216, 30, 24, 4}, + {"f32_30x216_valid11x216", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216_valid11x216, 30, 216, 11, 216, 4}, + {"f32_30x216_valid11x24", DType::F32, .launchF32 = LaunchTROWMAX_f32_30x216_valid11x24, 30, 216, 11, 24, 4}, + {"f32_238x40", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40, 238, 40, 238, 40, 4}, + {"f32_238x40_valid238x16", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40_valid238x16, 238, 40, 238, 16, 4}, + {"f32_238x40_valid121x40", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40_valid121x40, 238, 40, 121, 40, 4}, + {"f32_238x40_valid121x16", DType::F32, .launchF32 = LaunchTROWMAX_f32_238x40_valid121x16, 238, 40, 121, 16, 4}, + // f32 DN dst cases + {"f32_64x128", DType::F32, .launchF32 = LaunchTROWMAX_f32_64x128, 64, 128, 64, 128, 4}, + {"f32_32x256", DType::F32, .launchF32 = LaunchTROWMAX_f32_32x256, 32, 256, 32, 256, 4}, + {"f32_16x512", DType::F32, .launchF32 = LaunchTROWMAX_f32_16x512, 16, 512, 16, 512, 4}, + {"f32_8x1024", DType::F32, .launchF32 = LaunchTROWMAX_f32_8x1024, 8, 1024,8, 1024,4}, + // int32 cases + {"i32_127x64_valid127x63", DType::I32, .launchI32 = LaunchTROWMAX_i32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"i32_63x64", DType::I32, .launchI32 = LaunchTROWMAX_i32_63x64, 63, 64, 63, 64, 4}, + {"i32_31x128_valid31x127", DType::I32, .launchI32 = LaunchTROWMAX_i32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"i32_15x192", DType::I32, .launchI32 = LaunchTROWMAX_i32_15x192, 15, 192, 15, 192, 4}, + {"i32_7x448_valid7x447", DType::I32, .launchI32 = LaunchTROWMAX_i32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // int16 cases + {"i16_128x64", DType::I16, .launchI16 = LaunchTROWMAX_i16_128x64, 128, 64, 128, 64, 2}, + {"i16_64x64", DType::I16, .launchI16 = LaunchTROWMAX_i16_64x64, 64, 64, 64, 64, 2}, + {"i16_32x128", DType::I16, .launchI16 = LaunchTROWMAX_i16_32x128, 32, 128, 32, 128, 2}, + {"i16_16x192", DType::I16, .launchI16 = LaunchTROWMAX_i16_16x192, 16, 192, 16, 192, 2}, + {"i16_8x448", DType::I16, .launchI16 = LaunchTROWMAX_i16_8x448, 8, 448, 8, 448, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.validRows * 1; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32: tc.launchF32((float *)src0Device, (float *)dstDevice, stream); break; + case DType::F16: tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); break; + case DType::I32: tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); break; + case DType::I16: tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowmax [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmax/trowmax.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/trowmax.pto new file mode 100644 index 000000000..055fdd8e1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmax/trowmax.pto @@ -0,0 +1,1321 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowmax: tload(src) + trowmax(src, tmp)->dst + tstore(dst). + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 127x64 (valid=127x63) + func.func @TROWMAX_f32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // Case 1: f32 63x64 (valid=63x64) + func.func @TROWMAX_f32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // Case 2: f32 31x128 (valid=31x127) + func.func @TROWMAX_f32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // Case 3: f32 15x192 (valid=15x192) + func.func @TROWMAX_f32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // Case 4: f32 7x448 (valid=7x447) + func.func @TROWMAX_f32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // Case 5: f16 256x16 (valid=256x15) + func.func @TROWMAX_f16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // Case 6: f32 30x216 (valid=30x216) + func.func @TROWMAX_f32_30x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 7: f32 30x216 (valid=30x24) + func.func @TROWMAX_f32_30x216_valid30x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 8: f32 30x216 (valid=11x216) + func.func @TROWMAX_f32_30x216_valid11x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 9: f32 30x216 (valid=11x24) + func.func @TROWMAX_f32_30x216_valid11x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 10: f32 238x40 (valid=238x40) + func.func @TROWMAX_f32_238x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 11: f32 238x40 (valid=238x16) + func.func @TROWMAX_f32_238x40_valid238x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 12: f32 238x40 (valid=121x40) + func.func @TROWMAX_f32_238x40_valid121x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 13: f32 238x40 (valid=121x16) + func.func @TROWMAX_f32_238x40_valid121x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 14: f32 64x128 (valid=64x128) + func.func @TROWMAX_f32_64x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // Case 15: f32 32x256 (valid=32x256) + func.func @TROWMAX_f32_32x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // Case 16: f32 16x512 (valid=16x512) + func.func @TROWMAX_f32_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // Case 17: f32 8x1024 (valid=8x1024) + func.func @TROWMAX_f32_8x1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // int32 cases (case19-case23) + // ======================================================================== + + // case19: i32 127x64 valid=127x63 + func.func @TROWMAX_i32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // case20: i32 63x64 valid=63x64 + func.func @TROWMAX_i32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // case21: i32 31x128 valid=31x127 + func.func @TROWMAX_i32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // case22: i32 15x192 valid=15x192 + func.func @TROWMAX_i32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // case23: i32 7x448 valid=7x447 + func.func @TROWMAX_i32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // int16 cases (case24-case28) + // ======================================================================== + + // case24: i16 128x64 valid=128x64 + func.func @TROWMAX_i16_128x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c1], + strides = [%c128, %c128, %c128, %c1, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xi16> -> !pto.partition_tensor_view<1x1x1x128x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> -> !pto.partition_tensor_view<1x1x1x128x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x1xi16>) + return + } + + // case25: i16 64x64 valid=64x64 + func.func @TROWMAX_i16_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> -> !pto.partition_tensor_view<1x1x1x64x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xi16>) + return + } + + // case26: i16 32x128 valid=32x128 + func.func @TROWMAX_i16_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> -> !pto.partition_tensor_view<1x1x1x32x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xi16>) + return + } + + // case27: i16 16x192 valid=16x192 + func.func @TROWMAX_i16_16x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c16r = arith.constant 16 : index + %c192 = arith.constant 192 : index + %c3072 = arith.constant 3072 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16r, %c192], + strides = [%c3072, %c3072, %c3072, %c192, %c1] + : !pto.tensor_view<1x1x1x16x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16r, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c192] + : !pto.tensor_view<1x1x1x16x192xi16> -> !pto.partition_tensor_view<1x1x1x16x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> -> !pto.partition_tensor_view<1x1x1x16x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x192xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xi16>) + return + } + + // case28: i16 8x448 valid=8x448 + func.func @TROWMAX_i16_8x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c448 = arith.constant 448 : index + %c3584 = arith.constant 3584 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c448], + strides = [%c3584, %c3584, %c3584, %c448, %c1] + : !pto.tensor_view<1x1x1x8x448xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c448] + : !pto.tensor_view<1x1x1x8x448xi16> -> !pto.partition_tensor_view<1x1x1x8x448xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> -> !pto.partition_tensor_view<1x1x1x8x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x448xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/CMakeLists.txt new file mode 100644 index 000000000..e88611a82 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowmin) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/cases.py new file mode 100644 index 000000000..903509084 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/cases.py @@ -0,0 +1,224 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowmin ST test cases. + +Aligned with pto-isa tests/npu/a2a3/src/st/testcase/trowmin (28 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case5 from pto-isa) + { + "name": "f32_127x64_valid127x63", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-5, + }, + { + "name": "f32_63x64", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-5, + }, + { + "name": "f32_31x128_valid31x127", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-5, + }, + { + "name": "f32_15x192", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-5, + }, + { + "name": "f32_7x448_valid7x447", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-5, + }, + # f16 case (case6 from pto-isa) + { + "name": "f16_256x16_valid256x15", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 1e-2, + }, + # f32 more cases (case7-case14 from pto-isa) + { + "name": "f32_30x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid30x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (30, 24), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x216", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 216), + "eps": 1e-5, + }, + { + "name": "f32_30x216_valid11x24", + "dtype": np.float32, + "shape": (30, 216), + "valid_shape": (11, 24), + "eps": 1e-5, + }, + { + "name": "f32_238x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid238x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (238, 16), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x40", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 40), + "eps": 1e-5, + }, + { + "name": "f32_238x40_valid121x16", + "dtype": np.float32, + "shape": (238, 40), + "valid_shape": (121, 16), + "eps": 1e-5, + }, + # f32 DN dst cases (case15-case18 from pto-isa) + { + "name": "f32_64x128", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-5, + }, + { + "name": "f32_32x256", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-5, + }, + { + "name": "f32_16x512", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-5, + }, + { + "name": "f32_8x1024", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-5, + }, + + # int32 cases (case19-case23 from pto-isa) + { + "name": "i32_127x64_valid127x63", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "i32_63x64", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128_valid31x127", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "i32_15x192", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "i32_7x448_valid7x447", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case24-case28 from pto-isa) + { + "name": "i16_128x64", + "dtype": np.int16, + "shape": (128, 64), + "valid_shape": (128, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_32x128", + "dtype": np.int16, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "i16_16x192", + "dtype": np.int16, + "shape": (16, 192), + "valid_shape": (16, 192), + "eps": 0, + }, + { + "name": "i16_8x448", + "dtype": np.int16, + "shape": (8, 448), + "valid_shape": (8, 448), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/compare.py new file mode 100644 index 000000000..12d4207bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr,) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/gen_data.py new file mode 100644 index 000000000..cf1bed8ac --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/gen_data.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input1 = np.random.randint(low=-100, high=100, size=shape).astype(dtype) + else: + input1 = np.random.randint(low=-50, high=50, size=shape).astype(dtype) + else: + input1 = np.random.uniform(low=-16, high=16, size=shape).astype(dtype) + + out_shape = (valid_shape[0],) + golden = np.zeros(out_shape, dtype=dtype) + vr, vc = valid_shape + for i in range(vr): + golden[i] = np.min(input1[i, :vc]) + + golden = golden.astype(dtype, copy=False) + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/launch.cpp new file mode 100644 index 000000000..e4a8f8bde --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/launch.cpp @@ -0,0 +1,155 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWMIN_f32_127x64_valid127x63(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_127x64_valid127x63(float *src, float *dst, void *stream) { + TROWMIN_f32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_63x64(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_63x64(float *src, float *dst, void *stream) { + TROWMIN_f32_63x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_31x128_valid31x127(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_31x128_valid31x127(float *src, float *dst, void *stream) { + TROWMIN_f32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_15x192(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_15x192(float *src, float *dst, void *stream) { + TROWMIN_f32_15x192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_7x448_valid7x447(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_7x448_valid7x447(float *src, float *dst, void *stream) { + TROWMIN_f32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f16_256x16_valid256x15(__gm__ uint16_t *src, __gm__ uint16_t *dst); +void LaunchTROWMIN_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream) { + TROWMIN_f16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216_valid30x24(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216_valid30x24(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216_valid30x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216_valid11x216(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216_valid11x216(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216_valid11x216<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_30x216_valid11x24(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_30x216_valid11x24(float *src, float *dst, void *stream) { + TROWMIN_f32_30x216_valid11x24<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40_valid238x16(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40_valid238x16(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40_valid238x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40_valid121x40(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40_valid121x40(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40_valid121x40<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_238x40_valid121x16(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_238x40_valid121x16(float *src, float *dst, void *stream) { + TROWMIN_f32_238x40_valid121x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_64x128(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_64x128(float *src, float *dst, void *stream) { + TROWMIN_f32_64x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_32x256(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_32x256(float *src, float *dst, void *stream) { + TROWMIN_f32_32x256<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_16x512(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_16x512(float *src, float *dst, void *stream) { + TROWMIN_f32_16x512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_f32_8x1024(__gm__ float *src, __gm__ float *dst); +void LaunchTROWMIN_f32_8x1024(float *src, float *dst, void *stream) { + TROWMIN_f32_8x1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// int32 cases +extern "C" __global__ AICORE void TROWMIN_i32_127x64_valid127x63(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_63x64(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_63x64(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_63x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_31x128_valid31x127(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_15x192(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_15x192(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_15x192<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i32_7x448_valid7x447(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWMIN_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream) { + TROWMIN_i32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// int16 cases +extern "C" __global__ AICORE void TROWMIN_i16_128x64(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_128x64(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_128x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_64x64(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_64x64(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_32x128(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_32x128(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_16x192(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_16x192(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_16x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWMIN_i16_8x448(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWMIN_i16_8x448(int16_t *src, int16_t *dst, void *stream) { + TROWMIN_i16_8x448<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/main.cpp new file mode 100644 index 000000000..f0b9f0025 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/main.cpp @@ -0,0 +1,207 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowmin ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWMIN_f32_127x64_valid127x63(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_63x64(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_31x128_valid31x127(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_15x192(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_7x448_valid7x447(float *src, float *dst, void *stream); +void LaunchTROWMIN_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWMIN_f32_30x216(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_30x216_valid30x24(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_30x216_valid11x216(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_30x216_valid11x24(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40_valid238x16(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40_valid121x40(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_238x40_valid121x16(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_64x128(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_32x256(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_16x512(float *src, float *dst, void *stream); +void LaunchTROWMIN_f32_8x1024(float *src, float *dst, void *stream); +void LaunchTROWMIN_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_63x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_15x192(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWMIN_i16_128x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_64x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_32x128(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_16x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWMIN_i16_8x448(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_127x64_valid127x63", DType::F32, .launchF32 = LaunchTROWMIN_f32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"f32_63x64", DType::F32, .launchF32 = LaunchTROWMIN_f32_63x64, 63, 64, 63, 64, 4}, + {"f32_31x128_valid31x127", DType::F32, .launchF32 = LaunchTROWMIN_f32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"f32_15x192", DType::F32, .launchF32 = LaunchTROWMIN_f32_15x192, 15, 192, 15, 192, 4}, + {"f32_7x448_valid7x447", DType::F32, .launchF32 = LaunchTROWMIN_f32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // f16 case + {"f16_256x16_valid256x15", DType::F16, .launchF16 = LaunchTROWMIN_f16_256x16_valid256x15, 256, 16, 256, 15, 2}, + // f32 more cases + {"f32_30x216", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216, 30, 216, 30, 216, 4}, + {"f32_30x216_valid30x24", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216_valid30x24, 30, 216, 30, 24, 4}, + {"f32_30x216_valid11x216", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216_valid11x216, 30, 216, 11, 216, 4}, + {"f32_30x216_valid11x24", DType::F32, .launchF32 = LaunchTROWMIN_f32_30x216_valid11x24, 30, 216, 11, 24, 4}, + {"f32_238x40", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40, 238, 40, 238, 40, 4}, + {"f32_238x40_valid238x16", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40_valid238x16, 238, 40, 238, 16, 4}, + {"f32_238x40_valid121x40", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40_valid121x40, 238, 40, 121, 40, 4}, + {"f32_238x40_valid121x16", DType::F32, .launchF32 = LaunchTROWMIN_f32_238x40_valid121x16, 238, 40, 121, 16, 4}, + // f32 DN dst cases + {"f32_64x128", DType::F32, .launchF32 = LaunchTROWMIN_f32_64x128, 64, 128, 64, 128, 4}, + {"f32_32x256", DType::F32, .launchF32 = LaunchTROWMIN_f32_32x256, 32, 256, 32, 256, 4}, + {"f32_16x512", DType::F32, .launchF32 = LaunchTROWMIN_f32_16x512, 16, 512, 16, 512, 4}, + {"f32_8x1024", DType::F32, .launchF32 = LaunchTROWMIN_f32_8x1024, 8, 1024,8, 1024,4}, + // int32 cases + {"i32_127x64_valid127x63", DType::I32, .launchI32 = LaunchTROWMIN_i32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"i32_63x64", DType::I32, .launchI32 = LaunchTROWMIN_i32_63x64, 63, 64, 63, 64, 4}, + {"i32_31x128_valid31x127", DType::I32, .launchI32 = LaunchTROWMIN_i32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"i32_15x192", DType::I32, .launchI32 = LaunchTROWMIN_i32_15x192, 15, 192, 15, 192, 4}, + {"i32_7x448_valid7x447", DType::I32, .launchI32 = LaunchTROWMIN_i32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // int16 cases + {"i16_128x64", DType::I16, .launchI16 = LaunchTROWMIN_i16_128x64, 128, 64, 128, 64, 2}, + {"i16_64x64", DType::I16, .launchI16 = LaunchTROWMIN_i16_64x64, 64, 64, 64, 64, 2}, + {"i16_32x128", DType::I16, .launchI16 = LaunchTROWMIN_i16_32x128, 32, 128, 32, 128, 2}, + {"i16_16x192", DType::I16, .launchI16 = LaunchTROWMIN_i16_16x192, 16, 192, 16, 192, 2}, + {"i16_8x448", DType::I16, .launchI16 = LaunchTROWMIN_i16_8x448, 8, 448, 8, 448, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.validRows * 1; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32: tc.launchF32((float *)src0Device, (float *)dstDevice, stream); break; + case DType::F16: tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); break; + case DType::I32: tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); break; + case DType::I16: tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowmin [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowmin/trowmin.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/trowmin.pto new file mode 100644 index 000000000..d3f7a680e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowmin/trowmin.pto @@ -0,0 +1,1321 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowmin: tload(src) + trowmin(src, tmp)->dst + tstore(dst). + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 127x64 (valid=127x63) + func.func @TROWMIN_f32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // Case 1: f32 63x64 (valid=63x64) + func.func @TROWMIN_f32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // Case 2: f32 31x128 (valid=31x127) + func.func @TROWMIN_f32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // Case 3: f32 15x192 (valid=15x192) + func.func @TROWMIN_f32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // Case 4: f32 7x448 (valid=7x447) + func.func @TROWMIN_f32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // Case 5: f16 256x16 (valid=256x15) + func.func @TROWMIN_f16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // Case 6: f32 30x216 (valid=30x216) + func.func @TROWMIN_f32_30x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 7: f32 30x216 (valid=30x24) + func.func @TROWMIN_f32_30x216_valid30x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c30, %c1], + strides = [%c30, %c30, %c30, %c1, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x30x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c30, %c1] + : !pto.tensor_view<1x1x1x30x1xf32> -> !pto.partition_tensor_view<1x1x1x30x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x30x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x30x1xf32>) + return + } + + // Case 8: f32 30x216 (valid=11x216) + func.func @TROWMIN_f32_30x216_valid11x216(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c216] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x216xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x216xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 9: f32 30x216 (valid=11x24) + func.func @TROWMIN_f32_30x216_valid11x24(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c11 = arith.constant 11 : index + %c24 = arith.constant 24 : index + %c30 = arith.constant 30 : index + %c216 = arith.constant 216 : index + %c6480 = arith.constant 6480 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c30, %c216], + strides = [%c6480, %c6480, %c6480, %c216, %c1] + : !pto.tensor_view<1x1x1x30x216xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c11, %c1], + strides = [%c11, %c11, %c11, %c1, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c24] + : !pto.tensor_view<1x1x1x30x216xf32> -> !pto.partition_tensor_view<1x1x1x11x24xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c11, %c1] + : !pto.tensor_view<1x1x1x11x1xf32> -> !pto.partition_tensor_view<1x1x1x11x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x11x24xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x11x1xf32>) + return + } + + // Case 10: f32 238x40 (valid=238x40) + func.func @TROWMIN_f32_238x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 11: f32 238x40 (valid=238x16) + func.func @TROWMIN_f32_238x40_valid238x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c238, %c1], + strides = [%c238, %c238, %c238, %c1, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x238x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c238, %c1] + : !pto.tensor_view<1x1x1x238x1xf32> -> !pto.partition_tensor_view<1x1x1x238x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x238x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x238x1xf32>) + return + } + + // Case 12: f32 238x40 (valid=121x40) + func.func @TROWMIN_f32_238x40_valid121x40(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c40] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x40xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x40xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 13: f32 238x40 (valid=121x16) + func.func @TROWMIN_f32_238x40_valid121x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %c121 = arith.constant 121 : index + %c238 = arith.constant 238 : index + %c9520 = arith.constant 9520 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c238, %c40], + strides = [%c9520, %c9520, %c9520, %c40, %c1] + : !pto.tensor_view<1x1x1x238x40xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c121, %c1], + strides = [%c121, %c121, %c121, %c1, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c16] + : !pto.tensor_view<1x1x1x238x40xf32> -> !pto.partition_tensor_view<1x1x1x121x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c121, %c1] + : !pto.tensor_view<1x1x1x121x1xf32> -> !pto.partition_tensor_view<1x1x1x121x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x121x16xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x121x1xf32>) + return + } + + // Case 14: f32 64x128 (valid=64x128) + func.func @TROWMIN_f32_64x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // Case 15: f32 32x256 (valid=32x256) + func.func @TROWMIN_f32_32x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // Case 16: f32 16x512 (valid=16x512) + func.func @TROWMIN_f32_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // Case 17: f32 8x1024 (valid=8x1024) + func.func @TROWMIN_f32_8x1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // int32 cases (case19-case23) + // ======================================================================== + + // case19: i32 127x64 valid=127x63 + func.func @TROWMIN_i32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // case20: i32 63x64 valid=63x64 + func.func @TROWMIN_i32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // case21: i32 31x128 valid=31x127 + func.func @TROWMIN_i32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // case22: i32 15x192 valid=15x192 + func.func @TROWMIN_i32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // case23: i32 7x448 valid=7x447 + func.func @TROWMIN_i32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // int16 cases (case24-case28) + // ======================================================================== + + // case24: i16 128x64 valid=128x64 + func.func @TROWMIN_i16_128x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c1], + strides = [%c128, %c128, %c128, %c1, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xi16> -> !pto.partition_tensor_view<1x1x1x128x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> -> !pto.partition_tensor_view<1x1x1x128x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x1xi16>) + return + } + + // case25: i16 64x64 valid=64x64 + func.func @TROWMIN_i16_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> -> !pto.partition_tensor_view<1x1x1x64x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xi16>) + return + } + + // case26: i16 32x128 valid=32x128 + func.func @TROWMIN_i16_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> -> !pto.partition_tensor_view<1x1x1x32x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xi16>) + return + } + + // case27: i16 16x192 valid=16x192 + func.func @TROWMIN_i16_16x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c16r = arith.constant 16 : index + %c192 = arith.constant 192 : index + %c3072 = arith.constant 3072 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16r, %c192], + strides = [%c3072, %c3072, %c3072, %c192, %c1] + : !pto.tensor_view<1x1x1x16x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16r, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c192] + : !pto.tensor_view<1x1x1x16x192xi16> -> !pto.partition_tensor_view<1x1x1x16x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16r, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> -> !pto.partition_tensor_view<1x1x1x16x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x192xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xi16>) + return + } + + // case28: i16 8x448 valid=8x448 + func.func @TROWMIN_i16_8x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c448 = arith.constant 448 : index + %c3584 = arith.constant 3584 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c448], + strides = [%c3584, %c3584, %c3584, %c448, %c1] + : !pto.tensor_view<1x1x1x8x448xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c448] + : !pto.tensor_view<1x1x1x8x448xi16> -> !pto.partition_tensor_view<1x1x1x8x448xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> -> !pto.partition_tensor_view<1x1x1x8x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x448xi16>) + outs(%src : !pto.tile_buf) + + pto.trowmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/CMakeLists.txt new file mode 100644 index 000000000..6a30d1293 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowprod) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/cases.py new file mode 100644 index 000000000..66f7176a9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/cases.py @@ -0,0 +1,153 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowprod ST test cases. + +Aligned with pto-isa tests/npu/a5/src/st/testcase/trowprod (18 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case5 from pto-isa) + { + "name": "f32_127x64_valid127x63", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-3, + }, + { + "name": "f32_63x64", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_31x128_valid31x127", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-3, + }, + { + "name": "f32_15x192", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-3, + }, + { + "name": "f32_7x448_valid7x447", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-3, + }, + # f16 case (case6 from pto-isa) + { + "name": "f16_256x16_valid256x15", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 1e-1, + }, + # f32 DN dst cases (case7-case10 from pto-isa) + { + "name": "f32_64x128", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-3, + }, + { + "name": "f32_32x256", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-3, + }, + { + "name": "f32_16x512", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "f32_8x1024", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-3, + }, + + # int32 cases (case11-case15 from pto-isa) + { + "name": "i32_127x64_valid127x63", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "i32_63x64", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128_valid31x127", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "i32_15x192", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "i32_7x448_valid7x447", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case16-case18 from pto-isa) + { + "name": "i16_256x16_valid256x15", + "dtype": np.int16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i16_31x128_valid31x127", + "dtype": np.int16, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/compare.py new file mode 100644 index 000000000..12d4207bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr,) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/gen_data.py new file mode 100644 index 000000000..b1f6092af --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/gen_data.py @@ -0,0 +1,42 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input1 = np.random.randint(low=-3, high=4, size=shape).astype(dtype) + else: + input1 = np.random.randint(low=-2, high=3, size=shape).astype(dtype) + else: + input1 = np.random.uniform(low=0.9, high=1.1, size=shape).astype(dtype) + + out_shape = (valid_shape[0],) + golden = np.ones(out_shape, dtype=dtype) + vr, vc = valid_shape + for i in range(vr): + for j in range(vc): + golden[i] *= input1[i, j] + + golden = golden.astype(dtype, copy=False) + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/launch.cpp new file mode 100644 index 000000000..0533066cd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/launch.cpp @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TROWPROD_f32_127x64_valid127x63(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_127x64_valid127x63(float *src, float *dst, void *stream) { + TROWPROD_f32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_63x64(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_63x64(float *src, float *dst, void *stream) { + TROWPROD_f32_63x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_31x128_valid31x127(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_31x128_valid31x127(float *src, float *dst, void *stream) { + TROWPROD_f32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_15x192(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_15x192(float *src, float *dst, void *stream) { + TROWPROD_f32_15x192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_7x448_valid7x447(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_7x448_valid7x447(float *src, float *dst, void *stream) { + TROWPROD_f32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f16_256x16_valid256x15(__gm__ uint16_t *src, __gm__ uint16_t *dst); +void LaunchTROWPROD_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream) { + TROWPROD_f16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_64x128(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_64x128(float *src, float *dst, void *stream) { + TROWPROD_f32_64x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_32x256(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_32x256(float *src, float *dst, void *stream) { + TROWPROD_f32_32x256<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_16x512(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_16x512(float *src, float *dst, void *stream) { + TROWPROD_f32_16x512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_f32_8x1024(__gm__ float *src, __gm__ float *dst); +void LaunchTROWPROD_f32_8x1024(float *src, float *dst, void *stream) { + TROWPROD_f32_8x1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// int32 cases +extern "C" __global__ AICORE void TROWPROD_i32_127x64_valid127x63(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_63x64(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_63x64(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_63x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_31x128_valid31x127(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_15x192(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_15x192(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_15x192<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i32_7x448_valid7x447(__gm__ int32_t *src, __gm__ int32_t *dst); +void LaunchTROWPROD_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream) { + TROWPROD_i32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// int16 cases +extern "C" __global__ AICORE void TROWPROD_i16_256x16_valid256x15(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWPROD_i16_256x16_valid256x15(int16_t *src, int16_t *dst, void *stream) { + TROWPROD_i16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWPROD_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TROWPROD_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} + +extern "C" __global__ AICORE void TROWPROD_i16_31x128_valid31x127(__gm__ int16_t *src, __gm__ int16_t *dst); +void LaunchTROWPROD_i16_31x128_valid31x127(int16_t *src, int16_t *dst, void *stream) { + TROWPROD_i16_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/main.cpp new file mode 100644 index 000000000..32981566a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/main.cpp @@ -0,0 +1,186 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowprod ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWPROD_f32_127x64_valid127x63(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_63x64(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_31x128_valid31x127(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_15x192(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_7x448_valid7x447(float *src, float *dst, void *stream); +void LaunchTROWPROD_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWPROD_f32_64x128(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_32x256(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_16x512(float *src, float *dst, void *stream); +void LaunchTROWPROD_f32_8x1024(float *src, float *dst, void *stream); +void LaunchTROWPROD_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_63x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_15x192(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWPROD_i16_256x16_valid256x15(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWPROD_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWPROD_i16_31x128_valid31x127(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_127x64_valid127x63", DType::F32, .launchF32 = LaunchTROWPROD_f32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"f32_63x64", DType::F32, .launchF32 = LaunchTROWPROD_f32_63x64, 63, 64, 63, 64, 4}, + {"f32_31x128_valid31x127", DType::F32, .launchF32 = LaunchTROWPROD_f32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"f32_15x192", DType::F32, .launchF32 = LaunchTROWPROD_f32_15x192, 15, 192, 15, 192, 4}, + {"f32_7x448_valid7x447", DType::F32, .launchF32 = LaunchTROWPROD_f32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // f16 case + {"f16_256x16_valid256x15", DType::F16, .launchF16 = LaunchTROWPROD_f16_256x16_valid256x15, 256, 16, 256, 15, 2}, + // f32 DN dst cases + {"f32_64x128", DType::F32, .launchF32 = LaunchTROWPROD_f32_64x128, 64, 128, 64, 128, 4}, + {"f32_32x256", DType::F32, .launchF32 = LaunchTROWPROD_f32_32x256, 32, 256, 32, 256, 4}, + {"f32_16x512", DType::F32, .launchF32 = LaunchTROWPROD_f32_16x512, 16, 512, 16, 512, 4}, + {"f32_8x1024", DType::F32, .launchF32 = LaunchTROWPROD_f32_8x1024, 8, 1024,8, 1024,4}, + // int32 cases + {"i32_127x64_valid127x63", DType::I32, .launchI32 = LaunchTROWPROD_i32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"i32_63x64", DType::I32, .launchI32 = LaunchTROWPROD_i32_63x64, 63, 64, 63, 64, 4}, + {"i32_31x128_valid31x127", DType::I32, .launchI32 = LaunchTROWPROD_i32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"i32_15x192", DType::I32, .launchI32 = LaunchTROWPROD_i32_15x192, 15, 192, 15, 192, 4}, + {"i32_7x448_valid7x447", DType::I32, .launchI32 = LaunchTROWPROD_i32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // int16 cases + {"i16_256x16_valid256x15", DType::I16, .launchI16 = LaunchTROWPROD_i16_256x16_valid256x15, 256, 16, 256, 15, 2}, + {"i16_63x64", DType::I16, .launchI16 = LaunchTROWPROD_i16_63x64, 63, 64, 63, 64, 2}, + {"i16_31x128_valid31x127", DType::I16, .launchI16 = LaunchTROWPROD_i16_31x128_valid31x127, 31, 128, 31, 127, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.validRows * 1; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32: tc.launchF32((float *)src0Device, (float *)dstDevice, stream); break; + case DType::F16: tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); break; + case DType::I32: tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); break; + case DType::I16: tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowprod [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowprod/trowprod.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/trowprod.pto new file mode 100644 index 000000000..b51762fd8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowprod/trowprod.pto @@ -0,0 +1,855 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowprod: tload(src) + trowprod(src, tmp)->dst + tstore(dst). + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 127x64 (valid=127x63) + func.func @TROWPROD_f32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // Case 1: f32 63x64 (valid=63x64) + func.func @TROWPROD_f32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // Case 2: f32 31x128 (valid=31x127) + func.func @TROWPROD_f32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // Case 3: f32 15x192 (valid=15x192) + func.func @TROWPROD_f32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // Case 4: f32 7x448 (valid=7x447) + func.func @TROWPROD_f32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // Case 5: f16 256x16 (valid=256x15) + func.func @TROWPROD_f16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // Case 6: f32 64x128 (valid=64x128) + func.func @TROWPROD_f32_64x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // Case 7: f32 32x256 (valid=32x256) + func.func @TROWPROD_f32_32x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // Case 8: f32 16x512 (valid=16x512) + func.func @TROWPROD_f32_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // Case 9: f32 8x1024 (valid=8x1024) + func.func @TROWPROD_f32_8x1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // int32 cases (case11-case15) + // ======================================================================== + + // case11: i32 127x64 valid=127x63 + func.func @TROWPROD_i32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // case12: i32 63x64 valid=63x64 + func.func @TROWPROD_i32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // case13: i32 31x128 valid=31x127 + func.func @TROWPROD_i32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // case14: i32 15x192 valid=15x192 + func.func @TROWPROD_i32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // case15: i32 7x448 valid=7x447 + func.func @TROWPROD_i32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // int16 cases (case16-case18) + // ======================================================================== + + // case16: i16 256x16 valid=256x15 + func.func @TROWPROD_i16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c15 = arith.constant 15 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xi16> -> !pto.partition_tensor_view<1x1x1x256x15xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xi16> -> !pto.partition_tensor_view<1x1x1x256x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xi16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xi16>) + return + } + + // case17: i16 63x64 valid=63x64 + func.func @TROWPROD_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi16> -> !pto.partition_tensor_view<1x1x1x63x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi16>) + return + } + + // case18: i16 31x128 valid=31x127 + func.func @TROWPROD_i16_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi16> -> !pto.partition_tensor_view<1x1x1x31x127xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi16> -> !pto.partition_tensor_view<1x1x1x31x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi16>) + outs(%src : !pto.tile_buf) + + pto.trowprod ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/CMakeLists.txt new file mode 100644 index 000000000..bcb316bcc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trowsum) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/cases.py new file mode 100644 index 000000000..c3a410199 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/cases.py @@ -0,0 +1,174 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trowsum ST test cases. + +Aligned with pto-isa tests/npu/a5/src/st/testcase/trowsum (20 cases). +""" + +import numpy as np + +CASES = [ + # f32 cases (case1-case10 from pto-isa) + { + "name": "f32_127x64_valid127x63", + "dtype": np.float32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 1e-3, + }, + { + "name": "f32_63x64", + "dtype": np.float32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 1e-3, + }, + { + "name": "f32_31x128_valid31x127", + "dtype": np.float32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 1e-3, + }, + { + "name": "f32_15x192", + "dtype": np.float32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 1e-3, + }, + { + "name": "f32_7x448_valid7x447", + "dtype": np.float32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 1e-3, + }, + { + "name": "f16_256x16_valid256x15", + "dtype": np.float16, + "shape": (256, 16), + "valid_shape": (256, 15), + "eps": 5e-3, + }, + { + "name": "f32_64x128", + "dtype": np.float32, + "shape": (64, 128), + "valid_shape": (64, 128), + "eps": 1e-3, + }, + { + "name": "f32_32x256", + "dtype": np.float32, + "shape": (32, 256), + "valid_shape": (32, 256), + "eps": 1e-3, + }, + { + "name": "f32_16x512", + "dtype": np.float32, + "shape": (16, 512), + "valid_shape": (16, 512), + "eps": 1e-3, + }, + { + "name": "f32_8x1024", + "dtype": np.float32, + "shape": (8, 1024), + "valid_shape": (8, 1024), + "eps": 1e-3, + }, + + # int32 cases (case11-case15 from pto-isa) + { + "name": "i32_127x64_valid127x63", + "dtype": np.int32, + "shape": (127, 64), + "valid_shape": (127, 63), + "eps": 0, + }, + { + "name": "i32_63x64", + "dtype": np.int32, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128_valid31x127", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 127), + "eps": 0, + }, + { + "name": "i32_15x192", + "dtype": np.int32, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, + { + "name": "i32_7x448_valid7x447", + "dtype": np.int32, + "shape": (7, 448), + "valid_shape": (7, 447), + "eps": 0, + }, + + # int16 cases (case16-case20 from pto-isa) + { + "name": "i16_128x64", + "dtype": np.int16, + "shape": (128, 64), + "valid_shape": (128, 64), + "eps": 0, + }, + { + "name": "i16_64x64", + "dtype": np.int16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 0, + }, + { + "name": "i16_32x128", + "dtype": np.int16, + "shape": (32, 128), + "valid_shape": (32, 128), + "eps": 0, + }, + { + "name": "i16_16x192", + "dtype": np.int16, + "shape": (16, 192), + "valid_shape": (16, 192), + "eps": 0, + }, + { + "name": "i16_8x448", + "dtype": np.int16, + "shape": (8, 448), + "valid_shape": (8, 448), + "eps": 0, + }, + # i16 overflow case to test vcvt NOSAT behavior + { + "name": "i16_1x64_overflow", + "dtype": np.int16, + "shape": (1, 64), + "valid_shape": (1, 64), + "eps": 0, + "overflow": True, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/compare.py new file mode 100644 index 000000000..b80e2549b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + vr, vc = case["valid_shape"] + out_shape = (vr, 1) + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"], count=np.prod(out_shape)).reshape(out_shape) + + ok = result_cmp(golden, output, case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/gen_data.py new file mode 100644 index 000000000..0a7041c34 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/gen_data.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import numpy as np +from cases import CASES +from st_common import validate_cases, save_case_data + +validate_cases(CASES) + +np.random.seed(42) + +for case in CASES: + dtype = case["dtype"] + row = case["shape"][0] + valid_row = case["valid_shape"][0] + col = case["shape"][1] + valid_col = case["valid_shape"][1] + + if np.issubdtype(dtype, np.integer): + if dtype == np.int32: + input_arr = np.random.randint(low=-100, high=100, size=(row, col)).astype(dtype) + elif dtype == np.int16: + if case.get("overflow"): + # Generate values that cause overflow when summed to test NOSAT behavior + # 1000 * 64 = 64000 > 32767, wraps to -1536 in int16 + input_arr = np.full((row, col), 1000, dtype=dtype) + else: + input_arr = np.random.randint(low=-50, high=50, size=(row, col)).astype(dtype) + else: + input_arr = np.random.randint(low=-10, high=10, size=(row, col)).astype(dtype) + else: + input_arr = np.random.uniform(low=-1, high=1, size=(row, col)).astype(dtype) + + output_arr = np.zeros((row,), dtype=np.int64 if np.issubdtype(dtype, np.integer) else np.float64) + for i in range(valid_row): + for j in range(valid_col): + output_arr[i] += int(input_arr[i, j]) if np.issubdtype(dtype, np.integer) else input_arr[i, j] + output_arr = output_arr.astype(dtype) + + save_case_data(case["name"], {"input": input_arr, "golden": output_arr}) + print(f"[INFO] gen_data: {case['name']} shape=({row},{col}) valid=({valid_row},{valid_col}) dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/launch.cpp new file mode 100644 index 000000000..9209d568f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/launch.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// ======================================================================== +// f32 kernels +// ======================================================================== + +extern "C" __global__ AICORE void TROWSUM_f32_127x64_valid127x63(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f32_63x64(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f32_31x128_valid31x127(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f32_15x192(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f32_7x448_valid7x447(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f16_256x16_valid256x15(__gm__ uint16_t *src, __gm__ uint16_t *dst); +extern "C" __global__ AICORE void TROWSUM_f32_64x128(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f32_32x256(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f32_16x512(__gm__ float *src, __gm__ float *dst); +extern "C" __global__ AICORE void TROWSUM_f32_8x1024(__gm__ float *src, __gm__ float *dst); + +void LaunchTROWSUM_f32_127x64_valid127x63(float *src, float *dst, void *stream) { + TROWSUM_f32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f32_63x64(float *src, float *dst, void *stream) { + TROWSUM_f32_63x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f32_31x128_valid31x127(float *src, float *dst, void *stream) { + TROWSUM_f32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f32_15x192(float *src, float *dst, void *stream) { + TROWSUM_f32_15x192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f32_7x448_valid7x447(float *src, float *dst, void *stream) { + TROWSUM_f32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream) { + TROWSUM_f16_256x16_valid256x15<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint16_t *)dst); +} +void LaunchTROWSUM_f32_64x128(float *src, float *dst, void *stream) { + TROWSUM_f32_64x128<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f32_32x256(float *src, float *dst, void *stream) { + TROWSUM_f32_32x256<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f32_16x512(float *src, float *dst, void *stream) { + TROWSUM_f32_16x512<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} +void LaunchTROWSUM_f32_8x1024(float *src, float *dst, void *stream) { + TROWSUM_f32_8x1024<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst); +} + +// ======================================================================== +// i32 kernels +// ======================================================================== + +extern "C" __global__ AICORE void TROWSUM_i32_127x64_valid127x63(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_i32_63x64(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_i32_31x128_valid31x127(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_i32_15x192(__gm__ int32_t *src, __gm__ int32_t *dst); +extern "C" __global__ AICORE void TROWSUM_i32_7x448_valid7x447(__gm__ int32_t *src, __gm__ int32_t *dst); + +void LaunchTROWSUM_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_i32_127x64_valid127x63<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_i32_63x64(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_i32_63x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_i32_31x128_valid31x127<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_i32_15x192(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_i32_15x192<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} +void LaunchTROWSUM_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream) { + TROWSUM_i32_7x448_valid7x447<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} + +// ======================================================================== +// i16 kernels +// ======================================================================== + +extern "C" __global__ AICORE void TROWSUM_i16_128x64(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_i16_64x64(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_i16_32x128(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_i16_16x192(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_i16_8x448(__gm__ int16_t *src, __gm__ int16_t *dst); +extern "C" __global__ AICORE void TROWSUM_i16_1x64_overflow(__gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchTROWSUM_i16_128x64(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_i16_128x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_i16_64x64(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_i16_64x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_i16_32x128(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_i16_32x128<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_i16_16x192(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_i16_16x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_i16_8x448(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_i16_8x448<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} +void LaunchTROWSUM_i16_1x64_overflow(int16_t *src, int16_t *dst, void *stream) { + TROWSUM_i16_1x64_overflow<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/main.cpp new file mode 100644 index 000000000..5d0950e84 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/main.cpp @@ -0,0 +1,193 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trowsum ST — case-table driven. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTROWSUM_f32_127x64_valid127x63(float *src, float *dst, void *stream); +void LaunchTROWSUM_f32_63x64(float *src, float *dst, void *stream); +void LaunchTROWSUM_f32_31x128_valid31x127(float *src, float *dst, void *stream); +void LaunchTROWSUM_f32_15x192(float *src, float *dst, void *stream); +void LaunchTROWSUM_f32_7x448_valid7x447(float *src, float *dst, void *stream); +void LaunchTROWSUM_f16_256x16_valid256x15(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTROWSUM_f32_64x128(float *src, float *dst, void *stream); +void LaunchTROWSUM_f32_32x256(float *src, float *dst, void *stream); +void LaunchTROWSUM_f32_16x512(float *src, float *dst, void *stream); +void LaunchTROWSUM_f32_8x1024(float *src, float *dst, void *stream); +void LaunchTROWSUM_i32_127x64_valid127x63(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_i32_63x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_i32_31x128_valid31x127(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_i32_15x192(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_i32_7x448_valid7x447(int32_t *src, int32_t *dst, void *stream); +void LaunchTROWSUM_i16_128x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_i16_64x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_i16_32x128(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_i16_16x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_i16_8x448(int16_t *src, int16_t *dst, void *stream); +void LaunchTROWSUM_i16_1x64_overflow(int16_t *src, int16_t *dst, void *stream); + +using LaunchFnF32 = void (*)(float *, float *, void *); +using LaunchFnF16 = void (*)(uint16_t *, uint16_t *, void *); +using LaunchFnI32 = void (*)(int32_t *, int32_t *, void *); +using LaunchFnI16 = void (*)(int16_t *, int16_t *, void *); + +enum class DType { F32, F16, I32, I16 }; + +struct TestCase { + const char *name; + DType dtype; + union { + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI32 launchI32; + LaunchFnI16 launchI16; + }; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + // f32 cases + {"f32_127x64_valid127x63", DType::F32, .launchF32 = LaunchTROWSUM_f32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"f32_63x64", DType::F32, .launchF32 = LaunchTROWSUM_f32_63x64, 63, 64, 63, 64, 4}, + {"f32_31x128_valid31x127", DType::F32, .launchF32 = LaunchTROWSUM_f32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"f32_15x192", DType::F32, .launchF32 = LaunchTROWSUM_f32_15x192, 15, 192, 15, 192, 4}, + {"f32_7x448_valid7x447", DType::F32, .launchF32 = LaunchTROWSUM_f32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // f16 case + {"f16_256x16_valid256x15", DType::F16, .launchF16 = LaunchTROWSUM_f16_256x16_valid256x15, 256, 16, 256, 15, 2}, + // f32 DN dst cases + {"f32_64x128", DType::F32, .launchF32 = LaunchTROWSUM_f32_64x128, 64, 128, 64, 128, 4}, + {"f32_32x256", DType::F32, .launchF32 = LaunchTROWSUM_f32_32x256, 32, 256, 32, 256, 4}, + {"f32_16x512", DType::F32, .launchF32 = LaunchTROWSUM_f32_16x512, 16, 512, 16, 512, 4}, + {"f32_8x1024", DType::F32, .launchF32 = LaunchTROWSUM_f32_8x1024, 8, 1024,8, 1024,4}, + // int32 cases + {"i32_127x64_valid127x63", DType::I32, .launchI32 = LaunchTROWSUM_i32_127x64_valid127x63, 127, 64, 127, 63, 4}, + {"i32_63x64", DType::I32, .launchI32 = LaunchTROWSUM_i32_63x64, 63, 64, 63, 64, 4}, + {"i32_31x128_valid31x127", DType::I32, .launchI32 = LaunchTROWSUM_i32_31x128_valid31x127, 31, 128, 31, 127, 4}, + {"i32_15x192", DType::I32, .launchI32 = LaunchTROWSUM_i32_15x192, 15, 192, 15, 192, 4}, + {"i32_7x448_valid7x447", DType::I32, .launchI32 = LaunchTROWSUM_i32_7x448_valid7x447, 7, 448, 7, 447, 4}, + // int16 cases + {"i16_128x64", DType::I16, .launchI16 = LaunchTROWSUM_i16_128x64, 128, 64, 128, 64, 2}, + {"i16_64x64", DType::I16, .launchI16 = LaunchTROWSUM_i16_64x64, 64, 64, 64, 64, 2}, + {"i16_32x128", DType::I16, .launchI16 = LaunchTROWSUM_i16_32x128, 32, 128, 32, 128, 2}, + {"i16_16x192", DType::I16, .launchI16 = LaunchTROWSUM_i16_16x192, 16, 192, 16, 192, 2}, + {"i16_8x448", DType::I16, .launchI16 = LaunchTROWSUM_i16_8x448, 8, 448, 8, 448, 2}, + // i16 overflow case + {"i16_1x64_overflow", DType::I16, .launchI16 = LaunchTROWSUM_i16_1x64_overflow, 1, 64, 1, 64, 2}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t srcElemCount = tc.rows * tc.cols; + const size_t srcFileSize = srcElemCount * tc.elemSize; + const size_t dstElemCount = tc.validRows * 1; + const size_t dstFileSize = dstElemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = srcFileSize; + + void *src0Host = nullptr, *dstHost = nullptr; + void *src0Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&src0Host, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + + aclrtMalloc(&src0Device, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), src0FileSize, src0Host, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, srcFileSize, src0Host, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + switch (tc.dtype) { + case DType::F32: tc.launchF32((float *)src0Device, (float *)dstDevice, stream); break; + case DType::F16: tc.launchF16((uint16_t *)src0Device, (uint16_t *)dstDevice, stream); break; + case DType::I32: tc.launchI32((int32_t *)src0Device, (int32_t *)dstDevice, stream); break; + case DType::I16: tc.launchI16((int16_t *)src0Device, (int16_t *)dstDevice, stream); break; + } + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trowsum [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trowsum/trowsum.pto b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/trowsum.pto new file mode 100644 index 000000000..57b3ca62a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trowsum/trowsum.pto @@ -0,0 +1,983 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trowsum: tload(src) + trowsum(src, tmp)->dst + tstore(dst). +// Aligned with pto-isa tests/npu/a5/src/st/testcase/trowsum (20 cases) + 1 overflow case. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // ======================================================================== + // f32 cases + // ======================================================================== + + // f32_127x64_valid127x63 + func.func @TROWSUM_f32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xf32> -> !pto.partition_tensor_view<1x1x1x127x63xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xf32> -> !pto.partition_tensor_view<1x1x1x127x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xf32>) + return + } + + // f32_63x64 + func.func @TROWSUM_f32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf32> -> !pto.partition_tensor_view<1x1x1x63x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xf32> -> !pto.partition_tensor_view<1x1x1x63x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xf32>) + return + } + + // f32_31x128_valid31x127 + func.func @TROWSUM_f32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xf32> -> !pto.partition_tensor_view<1x1x1x31x127xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xf32> -> !pto.partition_tensor_view<1x1x1x31x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xf32>) + return + } + + // f32_15x192 + func.func @TROWSUM_f32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xf32> -> !pto.partition_tensor_view<1x1x1x15x192xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xf32> -> !pto.partition_tensor_view<1x1x1x15x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xf32>) + return + } + + // f32_7x448_valid7x447 + func.func @TROWSUM_f32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x447xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xf32> -> !pto.partition_tensor_view<1x1x1x7x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xf32>) + return + } + + // f16_256x16_valid256x15 + func.func @TROWSUM_f16_256x16_valid256x15(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c1], + strides = [%c256, %c256, %c256, %c1, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c15] + : !pto.tensor_view<1x1x1x256x16xf16> -> !pto.partition_tensor_view<1x1x1x256x15xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c1] + : !pto.tensor_view<1x1x1x256x1xf16> -> !pto.partition_tensor_view<1x1x1x256x1xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x15xf16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x1xf16>) + return + } + + // f32_64x128 + func.func @TROWSUM_f32_64x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c128], + strides = [%c8192, %c8192, %c8192, %c128, %c1] + : !pto.tensor_view<1x1x1x64x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c128] + : !pto.tensor_view<1x1x1x64x128xf32> -> !pto.partition_tensor_view<1x1x1x64x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xf32> -> !pto.partition_tensor_view<1x1x1x64x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x128xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xf32>) + return + } + + // f32_32x256 + func.func @TROWSUM_f32_32x256(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c256], + strides = [%c8192, %c8192, %c8192, %c256, %c1] + : !pto.tensor_view<1x1x1x32x256xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c256] + : !pto.tensor_view<1x1x1x32x256xf32> -> !pto.partition_tensor_view<1x1x1x32x256xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xf32> -> !pto.partition_tensor_view<1x1x1x32x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x256xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xf32>) + return + } + + // f32_16x512 + func.func @TROWSUM_f32_16x512(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c512], + strides = [%c8192, %c8192, %c8192, %c512, %c1] + : !pto.tensor_view<1x1x1x16x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c512] + : !pto.tensor_view<1x1x1x16x512xf32> -> !pto.partition_tensor_view<1x1x1x16x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xf32> -> !pto.partition_tensor_view<1x1x1x16x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x512xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xf32>) + return + } + + // f32_8x1024 + func.func @TROWSUM_f32_8x1024(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c1024 = arith.constant 1024 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c1024], + strides = [%c8192, %c8192, %c8192, %c1024, %c1] + : !pto.tensor_view<1x1x1x8x1024xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1024] + : !pto.tensor_view<1x1x1x8x1024xf32> -> !pto.partition_tensor_view<1x1x1x8x1024xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xf32> -> !pto.partition_tensor_view<1x1x1x8x1xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x1024xf32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xf32>) + return + } + + // ======================================================================== + // i32 cases + // ======================================================================== + + // i32_127x64_valid127x63 + func.func @TROWSUM_i32_127x64_valid127x63(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c127 = arith.constant 127 : index + %c8128 = arith.constant 8128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c127, %c64], + strides = [%c8128, %c8128, %c8128, %c64, %c1] + : !pto.tensor_view<1x1x1x127x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c127, %c1], + strides = [%c127, %c127, %c127, %c1, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c63] + : !pto.tensor_view<1x1x1x127x64xi32> -> !pto.partition_tensor_view<1x1x1x127x63xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c127, %c1] + : !pto.tensor_view<1x1x1x127x1xi32> -> !pto.partition_tensor_view<1x1x1x127x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x127x63xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x127x1xi32>) + return + } + + // i32_63x64 + func.func @TROWSUM_i32_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c1], + strides = [%c63, %c63, %c63, %c1, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi32> -> !pto.partition_tensor_view<1x1x1x63x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c1] + : !pto.tensor_view<1x1x1x63x1xi32> -> !pto.partition_tensor_view<1x1x1x63x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x1xi32>) + return + } + + // i32_31x128_valid31x127 + func.func @TROWSUM_i32_31x128_valid31x127(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c127 = arith.constant 127 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c1], + strides = [%c31, %c31, %c31, %c1, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c127] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x127xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c1] + : !pto.tensor_view<1x1x1x31x1xi32> -> !pto.partition_tensor_view<1x1x1x31x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x127xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x1xi32>) + return + } + + // i32_15x192 + func.func @TROWSUM_i32_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c1], + strides = [%c15, %c15, %c15, %c1, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi32> -> !pto.partition_tensor_view<1x1x1x15x192xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c1] + : !pto.tensor_view<1x1x1x15x1xi32> -> !pto.partition_tensor_view<1x1x1x15x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x1xi32>) + return + } + + // i32_7x448_valid7x447 + func.func @TROWSUM_i32_7x448_valid7x447(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c447 = arith.constant 447 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c1], + strides = [%c7, %c7, %c7, %c1, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c447] + : !pto.tensor_view<1x1x1x7x448xi32> -> !pto.partition_tensor_view<1x1x1x7x447xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c1] + : !pto.tensor_view<1x1x1x7x1xi32> -> !pto.partition_tensor_view<1x1x1x7x1xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x447xi32>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x1xi32>) + return + } + + // ======================================================================== + // i16 cases + // ======================================================================== + + // i16_128x64 + func.func @TROWSUM_i16_128x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c8192 = arith.constant 8192 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c128, %c64], + strides = [%c8192, %c8192, %c8192, %c64, %c1] + : !pto.tensor_view<1x1x1x128x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c128, %c1], + strides = [%c128, %c128, %c128, %c1, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c64] + : !pto.tensor_view<1x1x1x128x64xi16> -> !pto.partition_tensor_view<1x1x1x128x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c128, %c1] + : !pto.tensor_view<1x1x1x128x1xi16> -> !pto.partition_tensor_view<1x1x1x128x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x128x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x128x1xi16>) + return + } + + // i16_64x64 + func.func @TROWSUM_i16_64x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c64, %c1], + strides = [%c64, %c64, %c64, %c1, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xi16> -> !pto.partition_tensor_view<1x1x1x64x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c1] + : !pto.tensor_view<1x1x1x64x1xi16> -> !pto.partition_tensor_view<1x1x1x64x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x64x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x64x1xi16>) + return + } + + // i16_32x128 + func.func @TROWSUM_i16_32x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c128], + strides = [%c4096, %c4096, %c4096, %c128, %c1] + : !pto.tensor_view<1x1x1x32x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c1], + strides = [%c32, %c32, %c32, %c1, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c128] + : !pto.tensor_view<1x1x1x32x128xi16> -> !pto.partition_tensor_view<1x1x1x32x128xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c1] + : !pto.tensor_view<1x1x1x32x1xi16> -> !pto.partition_tensor_view<1x1x1x32x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x128xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x1xi16>) + return + } + + // i16_16x192 + func.func @TROWSUM_i16_16x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c192 = arith.constant 192 : index + %c3072 = arith.constant 3072 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c192], + strides = [%c3072, %c3072, %c3072, %c192, %c1] + : !pto.tensor_view<1x1x1x16x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c1], + strides = [%c16, %c16, %c16, %c1, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c192] + : !pto.tensor_view<1x1x1x16x192xi16> -> !pto.partition_tensor_view<1x1x1x16x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c1] + : !pto.tensor_view<1x1x1x16x1xi16> -> !pto.partition_tensor_view<1x1x1x16x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x192xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x1xi16>) + return + } + + // i16_8x448 + func.func @TROWSUM_i16_8x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c448 = arith.constant 448 : index + %c3584 = arith.constant 3584 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c8, %c448], + strides = [%c3584, %c3584, %c3584, %c448, %c1] + : !pto.tensor_view<1x1x1x8x448xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c8, %c1], + strides = [%c8, %c8, %c8, %c1, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c448] + : !pto.tensor_view<1x1x1x8x448xi16> -> !pto.partition_tensor_view<1x1x1x8x448xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c8, %c1] + : !pto.tensor_view<1x1x1x8x1xi16> -> !pto.partition_tensor_view<1x1x1x8x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x8x448xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x8x1xi16>) + return + } + + // i16_1x64_overflow: test vcvt NOSAT behavior (1000*64=64000 wraps to -1536 in i16) + func.func @TROWSUM_i16_1x64_overflow(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes { pto.entry , pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c1], + strides = [%c1, %c1, %c1, %c1, %c1] + : !pto.tensor_view<1x1x1x1x1xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xi16> -> !pto.partition_tensor_view<1x1x1x1x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c1] + : !pto.tensor_view<1x1x1x1x1xi16> -> !pto.partition_tensor_view<1x1x1x1x1xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x64xi16>) + outs(%src : !pto.tile_buf) + + pto.trowsum ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x1xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/CMakeLists.txt new file mode 100644 index 000000000..7209977f8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(trsqrt) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/cases.py new file mode 100644 index 000000000..cb8b6a48d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/cases.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for trsqrt ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/gen_data.py new file mode 100644 index 000000000..9ca63c976 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/gen_data.py @@ -0,0 +1,34 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + # Positive values for rsqrt (1/sqrt(x) requires sqrt(x) > 0) + input = np.random.uniform(0.1, 100.0, size=shape).astype(dtype) + + # rsqrt = 1 / sqrt(x) + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.reciprocal(np.sqrt(input[:vr, :vc])).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/launch.cpp new file mode 100644 index 000000000..65a35f3bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TRSQRT_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTRSQRT_f32_16x64(void *a, void *b, void *stream) { + TRSQRT_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TRSQRT_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTRSQRT_f32_32x32(void *a, void *b, void *stream) { + TRSQRT_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TRSQRT_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRSQRT_f16_16x64(void *a, void *b, void *stream) { + TRSQRT_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TRSQRT_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTRSQRT_f16_32x32(void *a, void *b, void *stream) { + TRSQRT_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/main.cpp new file mode 100644 index 000000000..20c955070 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang trsqrt ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTRSQRT_f32_16x64(void *a, void *b, void *stream); +void LaunchTRSQRT_f32_32x32(void *a, void *b, void *stream); +void LaunchTRSQRT_f16_16x64(void *a, void *b, void *stream); +void LaunchTRSQRT_f16_32x32(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTRSQRT_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTRSQRT_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTRSQRT_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTRSQRT_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./trsqrt [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/trsqrt.pto b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/trsqrt.pto new file mode 100644 index 000000000..7881bb3ea --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/trsqrt/trsqrt.pto @@ -0,0 +1,181 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.trsqrt: 1/sqrt(x) +// trsqrt = vsqrt(x) -> vdiv(1.0, sqrt_result) +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TRSQRT_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TRSQRT_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TRSQRT_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TRSQRT_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.trsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsel/CMakeLists.txt new file mode 100644 index 000000000..73a77806b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsel) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsel/cases.py new file mode 100644 index 000000000..3432ab7e9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/cases.py @@ -0,0 +1,97 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsel ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_2x128", + "dtype": np.float32, + "shape": (2, 128), + "valid_shape": (2, 128), + "eps": 1e-6, + }, + { + "name": "f32_2x32", + "dtype": np.float32, + "shape": (2, 32), + "valid_shape": (2, 32), + "eps": 1e-6, + }, + { + "name": "f32_2x160", + "dtype": np.float32, + "shape": (2, 160), + "valid_shape": (2, 160), + "eps": 1e-6, + }, + { + "name": "f32_2x512", + "dtype": np.float32, + "shape": (2, 512), + "valid_shape": (2, 512), + "eps": 1e-6, + }, + { + "name": "f16_2x128", + "dtype": np.float16, + "shape": (2, 128), + "valid_shape": (2, 128), + "eps": 1e-3, + }, + { + "name": "f16_2x32", + "dtype": np.float16, + "shape": (2, 32), + "valid_shape": (2, 32), + "eps": 1e-3, + }, + { + "name": "f16_2x160", + "dtype": np.float16, + "shape": (2, 160), + "valid_shape": (2, 160), + "eps": 1e-3, + }, + { + "name": "i8_2x128", + "dtype": np.int8, + "shape": (2, 128), + "valid_shape": (2, 128), + "eps": 0, + }, + { + "name": "i8_2x32", + "dtype": np.int8, + "shape": (2, 32), + "valid_shape": (2, 32), + "eps": 0, + }, + { + "name": "i8_2x160", + "dtype": np.int8, + "shape": (2, 160), + "valid_shape": (2, 160), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsel/compare.py new file mode 100644 index 000000000..b975718e6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsel/gen_data.py new file mode 100644 index 000000000..7308ac94d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/gen_data.py @@ -0,0 +1,44 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + vr, vc = valid_shape + mask_cols = (vc + 7) // 8 + + src0 = np.random.randint(1, 10, size=shape).astype(dtype) + src1 = np.random.randint(1, 10, size=shape).astype(dtype) + mask = np.random.randint(0, 256, size=(vr, mask_cols), dtype=np.uint8) + + golden = np.zeros(shape, dtype=dtype) + src0_valid = src0[:vr, :vc] + src1_valid = src1[:vr, :vc] + for row in range(vr): + for packed_col in range(mask_cols): + byte = int(mask[row, packed_col]) + for bit in range(8): + col = packed_col * 8 + bit + if col >= vc: + break + golden[row, col] = src0_valid[row, col] if ((byte >> bit) & 1) else src1_valid[row, col] + + save_case_data(case["name"], {"input1": src0, "input2": src1, "input3": mask, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsel/launch.cpp new file mode 100644 index 000000000..dbd0edbe1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/launch.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 2x128 +extern "C" __global__ AICORE void TSEL_f32_2x128(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x128(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 1: f32 2x32 +extern "C" __global__ AICORE void TSEL_f32_2x32(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x32(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 2: f32 2x160 +extern "C" __global__ AICORE void TSEL_f32_2x160(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x160(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x160<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 3: f32 2x512 +extern "C" __global__ AICORE void TSEL_f32_2x512(__gm__ uint8_t *mask, __gm__ float *src0, __gm__ float *src1, __gm__ float *dst); + +void LaunchTSEL_f32_2x512(uint8_t *mask, float *src0, float *src1, float *dst, void *stream) { + TSEL_f32_2x512<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src0, (__gm__ float *)src1, (__gm__ float *)dst); +} + +// Case 4: f16 2x128 +extern "C" __global__ AICORE void TSEL_f16_2x128(__gm__ uint8_t *mask, __gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTSEL_f16_2x128(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TSEL_f16_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 5: f16 2x32 +extern "C" __global__ AICORE void TSEL_f16_2x32(__gm__ uint8_t *mask, __gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTSEL_f16_2x32(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TSEL_f16_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 6: f16 2x160 +extern "C" __global__ AICORE void TSEL_f16_2x160(__gm__ uint8_t *mask, __gm__ uint16_t *src0, __gm__ uint16_t *src1, __gm__ uint16_t *dst); + +void LaunchTSEL_f16_2x160(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream) { + TSEL_f16_2x160<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src0, (__gm__ uint16_t *)src1, (__gm__ uint16_t *)dst); +} + +// Case 7: i8 2x128 +extern "C" __global__ AICORE void TSEL_i8_2x128(__gm__ uint8_t *mask, __gm__ int8_t *src0, __gm__ int8_t *src1, __gm__ int8_t *dst); + +void LaunchTSEL_i8_2x128(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream) { + TSEL_i8_2x128<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ int8_t *)src0, (__gm__ int8_t *)src1, (__gm__ int8_t *)dst); +} + +// Case 8: i8 2x32 +extern "C" __global__ AICORE void TSEL_i8_2x32(__gm__ uint8_t *mask, __gm__ int8_t *src0, __gm__ int8_t *src1, __gm__ int8_t *dst); + +void LaunchTSEL_i8_2x32(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream) { + TSEL_i8_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ int8_t *)src0, (__gm__ int8_t *)src1, (__gm__ int8_t *)dst); +} + +// Case 9: i8 2x160 +extern "C" __global__ AICORE void TSEL_i8_2x160(__gm__ uint8_t *mask, __gm__ int8_t *src0, __gm__ int8_t *src1, __gm__ int8_t *dst); + +void LaunchTSEL_i8_2x160(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream) { + TSEL_i8_2x160<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ int8_t *)src0, (__gm__ int8_t *)src1, (__gm__ int8_t *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsel/main.cpp new file mode 100644 index 000000000..4bf41b7bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/main.cpp @@ -0,0 +1,312 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsel ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSEL_f32_2x128(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f32_2x32(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f32_2x160(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f32_2x512(uint8_t *mask, float *src0, float *src1, float *dst, void *stream); +void LaunchTSEL_f16_2x128(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTSEL_f16_2x32(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTSEL_f16_2x160(uint8_t *mask, uint16_t *src0, uint16_t *src1, uint16_t *dst, void *stream); +void LaunchTSEL_i8_2x128(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream); +void LaunchTSEL_i8_2x32(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream); +void LaunchTSEL_i8_2x160(uint8_t *mask, int8_t *src0, int8_t *src1, int8_t *dst, void *stream); + +enum DataType { DT_F32, DT_F16, DT_I8 }; + +using LaunchFnF32 = void (*)(uint8_t *, float *, float *, float *, void *); +using LaunchFnF16 = void (*)(uint8_t *, uint16_t *, uint16_t *, uint16_t *, void *); +using LaunchFnI8 = void (*)(uint8_t *, int8_t *, int8_t *, int8_t *, void *); + +struct TestCase { + const char *name; + DataType dtype; + LaunchFnF32 launchF32; + LaunchFnF16 launchF16; + LaunchFnI8 launchI8; + size_t rows; + size_t cols; + size_t validRows; + size_t validCols; + size_t elemSize; +}; + +static const TestCase kCases[] = { + {"f32_2x128", DT_F32, LaunchTSEL_f32_2x128, nullptr, nullptr, 2, 128, 2, 128, sizeof(float)}, + {"f32_2x32", DT_F32, LaunchTSEL_f32_2x32, nullptr, nullptr, 2, 32, 2, 32, sizeof(float)}, + {"f32_2x160", DT_F32, LaunchTSEL_f32_2x160, nullptr, nullptr, 2, 160, 2, 160, sizeof(float)}, + {"f32_2x512", DT_F32, LaunchTSEL_f32_2x512, nullptr, nullptr, 2, 512, 2, 512, sizeof(float)}, + {"f16_2x128", DT_F16, nullptr, LaunchTSEL_f16_2x128, nullptr, 2, 128, 2, 128, sizeof(uint16_t)}, + {"f16_2x32", DT_F16, nullptr, LaunchTSEL_f16_2x32, nullptr, 2, 32, 2, 32, sizeof(uint16_t)}, + {"f16_2x160", DT_F16, nullptr, LaunchTSEL_f16_2x160, nullptr, 2, 160, 2, 160, sizeof(uint16_t)}, + {"i8_2x128", DT_I8, nullptr, nullptr, LaunchTSEL_i8_2x128, 2, 128, 2, 128, sizeof(int8_t)}, + {"i8_2x32", DT_I8, nullptr, nullptr, LaunchTSEL_i8_2x32, 2, 32, 2, 32, sizeof(int8_t)}, + {"i8_2x160", DT_I8, nullptr, nullptr, LaunchTSEL_i8_2x160, 2, 160, 2, 160, sizeof(int8_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSizeConst = elemCount * tc.elemSize; + const size_t maskCols = (tc.validCols + 7) / 8; + const size_t maskFileSizeConst = tc.validRows * maskCols * sizeof(uint8_t); + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + + if (tc.dtype == DT_F32) { + uint8_t *maskHost = nullptr; + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + uint8_t *maskDevice = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&maskHost), maskFileSizeConst); + aclrtMallocHost((void **)(&src0Host), fileSizeConst); + aclrtMallocHost((void **)(&src1Host), fileSizeConst); + aclrtMallocHost((void **)(&dstHost), fileSizeConst); + + aclrtMalloc((void **)&maskDevice, maskFileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = fileSizeConst; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = fileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + size_t maskFileSize = maskFileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), maskFileSize, maskHost, maskFileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSizeConst, maskHost, maskFileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSizeConst, src0Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSizeConst, src1Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launchF32(maskDevice, src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSizeConst, dstDevice, fileSizeConst, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) + aclrtFree(maskDevice); + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (maskHost != nullptr) + aclrtFreeHost(maskHost); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + } else if (tc.dtype == DT_F16) { + uint8_t *maskHost = nullptr; + uint16_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + uint8_t *maskDevice = nullptr; + uint16_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&maskHost), maskFileSizeConst); + aclrtMallocHost((void **)(&src0Host), fileSizeConst); + aclrtMallocHost((void **)(&src1Host), fileSizeConst); + aclrtMallocHost((void **)(&dstHost), fileSizeConst); + + aclrtMalloc((void **)&maskDevice, maskFileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = fileSizeConst; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = fileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + size_t maskFileSize = maskFileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), maskFileSize, maskHost, maskFileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSizeConst, maskHost, maskFileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSizeConst, src0Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSizeConst, src1Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launchF16(maskDevice, src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSizeConst, dstDevice, fileSizeConst, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) + aclrtFree(maskDevice); + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (maskHost != nullptr) + aclrtFreeHost(maskHost); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + } else { + uint8_t *maskHost = nullptr; + int8_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + uint8_t *maskDevice = nullptr; + int8_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&maskHost), maskFileSizeConst); + aclrtMallocHost((void **)(&src0Host), fileSizeConst); + aclrtMallocHost((void **)(&src1Host), fileSizeConst); + aclrtMallocHost((void **)(&dstHost), fileSizeConst); + + aclrtMalloc((void **)&maskDevice, maskFileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src0Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSizeConst, ACL_MEM_MALLOC_HUGE_FIRST); + + size_t fileSize = fileSizeConst; + if (!ReadFile((caseDir + "/input1.bin").c_str(), fileSize, src0Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + fileSize = fileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), fileSize, src1Host, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + size_t maskFileSize = maskFileSizeConst; + if (rc == 0 && !ReadFile((caseDir + "/input3.bin").c_str(), maskFileSize, maskHost, maskFileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input3.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSizeConst, maskHost, maskFileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src0Device, fileSizeConst, src0Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSizeConst, src1Host, fileSizeConst, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launchI8(maskDevice, src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSizeConst, dstDevice, fileSizeConst, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSizeConst)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) + aclrtFree(maskDevice); + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (maskHost != nullptr) + aclrtFreeHost(maskHost); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + } + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsel/tsel.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsel/tsel.pto new file mode 100644 index 000000000..8023c12db --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsel/tsel.pto @@ -0,0 +1,744 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsel: packed mask tload + tload(src0) + tload(src1) + tsel(mask,src0,src1,tmp,dst) + tstore(dst). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case: f32 2x128 + func.func @TSEL_f32_2x128(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xi8> -> !pto.partition_tensor_view<1x1x1x2x16xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case: f32 2x32 + func.func @TSEL_f32_2x32(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c4], + strides = [%c8, %c8, %c8, %c4, %c1] + : !pto.tensor_view<1x1x1x2x4xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c4] + : !pto.tensor_view<1x1x1x2x4xi8> -> !pto.partition_tensor_view<1x1x1x2x4xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x32xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x32xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x32xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x4xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x32xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x32xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xf32>) + return + } + + // Case: f32 2x160 + func.func @TSEL_f32_2x160(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c20 = arith.constant 20 : index + %c40 = arith.constant 40 : index + %c160 = arith.constant 160 : index + %c320 = arith.constant 320 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c20], + strides = [%c40, %c40, %c40, %c20, %c1] + : !pto.tensor_view<1x1x1x2x20xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c20] + : !pto.tensor_view<1x1x1x2x20xi8> -> !pto.partition_tensor_view<1x1x1x2x20xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf32> -> !pto.partition_tensor_view<1x1x1x2x160xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf32> -> !pto.partition_tensor_view<1x1x1x2x160xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf32> -> !pto.partition_tensor_view<1x1x1x2x160xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x20xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x160xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x160xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x160xf32>) + return + } + + // Case: f32 2x512 + func.func @TSEL_f32_2x512(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c512], + strides = [%c1024, %c1024, %c1024, %c512, %c1] + : !pto.tensor_view<1x1x1x2x512xf32> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c512], + strides = [%c1024, %c1024, %c1024, %c512, %c1] + : !pto.tensor_view<1x1x1x2x512xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c512], + strides = [%c1024, %c1024, %c1024, %c512, %c1] + : !pto.tensor_view<1x1x1x2x512xf32> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c512] + : !pto.tensor_view<1x1x1x2x512xf32> -> !pto.partition_tensor_view<1x1x1x2x512xf32> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c512] + : !pto.tensor_view<1x1x1x2x512xf32> -> !pto.partition_tensor_view<1x1x1x2x512xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c512] + : !pto.tensor_view<1x1x1x2x512xf32> -> !pto.partition_tensor_view<1x1x1x2x512xf32> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x512xf32>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x512xf32>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x512xf32>) + return + } + + // Case: f16 2x128 + func.func @TSEL_f16_2x128(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf16> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xi8> -> !pto.partition_tensor_view<1x1x1x2x16xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf16> -> !pto.partition_tensor_view<1x1x1x2x128xf16> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf16>) + return + } + + // Case: f16 2x32 + func.func @TSEL_f16_2x32(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c4], + strides = [%c8, %c8, %c8, %c4, %c1] + : !pto.tensor_view<1x1x1x2x4xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf16> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c4] + : !pto.tensor_view<1x1x1x2x4xi8> -> !pto.partition_tensor_view<1x1x1x2x4xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf16> -> !pto.partition_tensor_view<1x1x1x2x32xf16> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x4xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xf16>) + return + } + + // Case: f16 2x160 + func.func @TSEL_f16_2x160(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c20 = arith.constant 20 : index + %c40 = arith.constant 40 : index + %c160 = arith.constant 160 : index + %c320 = arith.constant 320 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c20], + strides = [%c40, %c40, %c40, %c20, %c1] + : !pto.tensor_view<1x1x1x2x20xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf16> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xf16> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c20] + : !pto.tensor_view<1x1x1x2x20xi8> -> !pto.partition_tensor_view<1x1x1x2x20xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf16> -> !pto.partition_tensor_view<1x1x1x2x160xf16> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf16> -> !pto.partition_tensor_view<1x1x1x2x160xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xf16> -> !pto.partition_tensor_view<1x1x1x2x160xf16> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x20xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x160xf16>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x160xf16>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x160xf16>) + return + } + + // Case: i8 2x128 + func.func @TSEL_i8_2x128(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi8> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xi8> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c16] + : !pto.tensor_view<1x1x1x2x16xi8> -> !pto.partition_tensor_view<1x1x1x2x16xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) + return + } + + // Case: i8 2x32 + func.func @TSEL_i8_2x32(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c4], + strides = [%c8, %c8, %c8, %c4, %c1] + : !pto.tensor_view<1x1x1x2x4xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xi8> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c4] + : !pto.tensor_view<1x1x1x2x4xi8> -> !pto.partition_tensor_view<1x1x1x2x4xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x4xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + // Case: i8 2x160 + func.func @TSEL_i8_2x160(%mask_ptr: !pto.ptr, %src0_ptr: !pto.ptr, %src1_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c20 = arith.constant 20 : index + %c40 = arith.constant 40 : index + %c160 = arith.constant 160 : index + %c320 = arith.constant 320 : index + + %mask_view = pto.make_tensor_view %mask_ptr, + shape = [%c1, %c1, %c1, %c2, %c20], + strides = [%c40, %c40, %c40, %c20, %c1] + : !pto.tensor_view<1x1x1x2x20xi8> + %src0_view = pto.make_tensor_view %src0_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xi8> + %src1_view = pto.make_tensor_view %src1_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xi8> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c160], + strides = [%c320, %c320, %c320, %c160, %c1] + : !pto.tensor_view<1x1x1x2x160xi8> + + %mask_part = pto.partition_view %mask_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c20] + : !pto.tensor_view<1x1x1x2x20xi8> -> !pto.partition_tensor_view<1x1x1x2x20xi8> + %src0_part = pto.partition_view %src0_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xi8> -> !pto.partition_tensor_view<1x1x1x2x160xi8> + %src1_part = pto.partition_view %src1_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xi8> -> !pto.partition_tensor_view<1x1x1x2x160xi8> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c160] + : !pto.tensor_view<1x1x1x2x160xi8> -> !pto.partition_tensor_view<1x1x1x2x160xi8> + + %mask = pto.alloc_tile + : !pto.tile_buf + %src0 = pto.alloc_tile + : !pto.tile_buf + %src1 = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x20xi8>) + outs(%mask : !pto.tile_buf) + pto.tload ins(%src0_part : !pto.partition_tensor_view<1x1x1x2x160xi8>) + outs(%src0 : !pto.tile_buf) + pto.tload ins(%src1_part : !pto.partition_tensor_view<1x1x1x2x160xi8>) + outs(%src1 : !pto.tile_buf) + + pto.tsel ins(%mask, %src0, %src1, %tmp : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x160xi8>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsels/CMakeLists.txt new file mode 100644 index 000000000..d699a3c35 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsels) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsels/cases.py new file mode 100644 index 000000000..496b37c96 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/cases.py @@ -0,0 +1,50 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsels ST test cases. + +Each case defines: + - name: case identifier + - dtype: numpy dtype for data (src/dst) + - dtype_mask: numpy dtype for mask + - dst_shape: (dst_rows, dst_cols) — allocated dst tile dimensions + - mask_shape: (mask_rows, mask_cols) — allocated mask tile dimensions + - src_shape: (src_rows, src_cols) — allocated src tile dimensions + - valid_shape: (valid_rows, valid_cols) — effective computation region + - eps: tolerance for numpy.allclose (atol and rtol) +""" + +import numpy as np + +CASES = [ + {"name": "uint8_uint8_2x32_2x32_2x32_2x32", "dtype": np.uint8, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 32), "mask_shape": (2, 32), "src_shape": (2, 32), "valid_shape": (2, 32), "eps": 0}, + {"name": "uint8_uint16_2x32_2x16_2x32_2x32", "dtype": np.uint8, "dtype_mask": np.uint16, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 32), "mask_shape": (2, 16), "src_shape": (2, 32), "valid_shape": (2, 32), "eps": 0}, + {"name": "uint8_uint32_2x32_2x8_2x32_2x32", "dtype": np.uint8, "dtype_mask": np.uint32, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 32), "mask_shape": (2, 8), "src_shape": (2, 32), "valid_shape": (2, 32), "eps": 0}, + {"name": "uint16_uint8_2x16_2x32_2x16_2x16", "dtype": np.uint16, "dtype_mask": np.uint8, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 32), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 0}, + {"name": "uint16_uint16_2x16_2x16_2x16_2x16", "dtype": np.uint16, "dtype_mask": np.uint16, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 16), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 0}, + {"name": "uint16_uint32_2x16_2x8_2x16_2x16", "dtype": np.uint16, "dtype_mask": np.uint32, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 8), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 0}, + {"name": "uint32_uint8_2x8_2x32_2x8_2x8", "dtype": np.uint32, "dtype_mask": np.uint8, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 32), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 0}, + {"name": "uint32_uint16_2x8_2x16_2x8_2x8", "dtype": np.uint32, "dtype_mask": np.uint16, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 16), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 0}, + {"name": "uint32_uint32_2x8_2x8_2x8_2x8", "dtype": np.uint32, "dtype_mask": np.uint32, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 8), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 0}, + {"name": "f16_uint8_2x16_2x32_2x16_2x16", "dtype": np.float16, "dtype_mask": np.uint8, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 32), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 1e-3}, + {"name": "f16_uint16_2x16_2x16_2x16_2x16", "dtype": np.float16, "dtype_mask": np.uint16, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 16), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 1e-3}, + {"name": "f16_uint32_2x16_2x8_2x16_2x16", "dtype": np.float16, "dtype_mask": np.uint32, "shape": (2, 16), "dst_shape": (2, 16), "dst_valid_shape": (2, 16), "mask_shape": (2, 8), "src_shape": (2, 16), "valid_shape": (2, 16), "eps": 1e-3}, + {"name": "f32_uint8_2x8_2x32_2x8_2x8", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 32), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 1e-6}, + {"name": "f32_uint16_2x8_2x16_2x8_2x8", "dtype": np.float32, "dtype_mask": np.uint16, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 16), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 1e-6}, + {"name": "f32_uint32_2x8_2x8_2x8_2x8", "dtype": np.float32, "dtype_mask": np.uint32, "shape": (2, 8), "dst_shape": (2, 8), "dst_valid_shape": (2, 8), "mask_shape": (2, 8), "src_shape": (2, 8), "valid_shape": (2, 8), "eps": 1e-6}, + {"name": "uint8_uint8_2x32_2x64_2x128_2x31", "dtype": np.uint8, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 31), "mask_shape": (2, 64), "src_shape": (2, 128), "valid_shape": (2, 31), "eps": 0}, + {"name": "uint16_uint8_2x32_2x64_2x128_2x31", "dtype": np.uint16, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 31), "mask_shape": (2, 64), "src_shape": (2, 128), "valid_shape": (2, 31), "eps": 0}, + {"name": "f32_uint8_2x32_2x64_2x128_2x31", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (2, 32), "dst_shape": (2, 32), "dst_valid_shape": (2, 31), "mask_shape": (2, 64), "src_shape": (2, 128), "valid_shape": (2, 31), "eps": 1e-6}, + {"name": "uint8_uint8_32x672_32x96_32x672_32x666", "dtype": np.uint8, "dtype_mask": np.uint8, "shape": (32, 672), "dst_shape": (32, 672), "dst_valid_shape": (32, 666), "mask_shape": (32, 96), "src_shape": (32, 672), "valid_shape": (32, 666), "eps": 0}, + {"name": "f16_uint8_32x672_32x96_32x672_32x666", "dtype": np.float16, "dtype_mask": np.uint8, "shape": (32, 672), "dst_shape": (32, 672), "dst_valid_shape": (32, 666), "mask_shape": (32, 96), "src_shape": (32, 672), "valid_shape": (32, 666), "eps": 1e-3}, + {"name": "f32_uint8_32x672_32x96_32x672_32x666", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (32, 672), "dst_shape": (32, 672), "dst_valid_shape": (32, 666), "mask_shape": (32, 96), "src_shape": (32, 672), "valid_shape": (32, 666), "eps": 1e-6}, + {"name": "f32_uint8_1x8192_1x4096_1x8192_1x8192", "dtype": np.float32, "dtype_mask": np.uint8, "shape": (1, 8192), "dst_shape": (1, 8192), "dst_valid_shape": (1, 8192), "mask_shape": (1, 4096), "src_shape": (1, 8192), "valid_shape": (1, 8192), "eps": 1e-6}, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsels/compare.py new file mode 100644 index 000000000..6af6f6d5c --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dst_shape = case["dst_shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(dst_shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsels/gen_data.py new file mode 100644 index 000000000..b462425c6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/gen_data.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + dtype_mask = case["dtype_mask"] + dst_shape = case["dst_shape"] + mask_shape = case["mask_shape"] + src_shape = case["src_shape"] + valid_shape = case["valid_shape"] + height, width = valid_shape + + if dtype in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32): + dtype_info = np.iinfo(dtype) + input1 = np.random.randint(dtype_info.min, dtype_info.max, size=src_shape).astype(dtype) + input2 = np.random.randint(dtype_info.min, dtype_info.max, size=[1]).astype(dtype) + else: + dtype_info = np.finfo(dtype) + input1 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=src_shape).astype(dtype) + input2 = np.random.uniform(low=dtype_info.min, high=dtype_info.max, size=[1]).astype(dtype) + + mask_dtype_info = np.iinfo(dtype_mask) + mask = np.random.randint(mask_dtype_info.min, mask_dtype_info.max, size=mask_shape).astype(dtype_mask) + mask_u8view = mask.view(np.uint8).reshape(mask_shape[0], -1) + golden = np.zeros(dst_shape, dtype=dtype) + + for y in range(height): + for x in range(width): + do_select = (1 << (x & 7)) & mask_u8view[y, x >> 3] + golden[y, x] = input1[y, x] if do_select != 0 else input2[0] + + save_case_data(case["name"], {"mask": mask, "input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} dst={dst_shape} mask={mask_shape} src={src_shape} valid={valid_shape} dtype={dtype.__name__} mask_dtype={dtype_mask.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsels/launch.cpp new file mode 100644 index 000000000..30372dbb8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/launch.cpp @@ -0,0 +1,168 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +extern "C" __global__ AICORE void TSELS_uint8_uint8_2x32_2x32_2x32_2x32(__gm__ uint8_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint8_2x32_2x32_2x32_2x32(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint8_2x32_2x32_2x32_2x32<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint16_2x32_2x16_2x32_2x32(__gm__ uint16_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint16_2x32_2x16_2x32_2x32(uint16_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint16_2x32_2x16_2x32_2x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint32_2x32_2x8_2x32_2x32(__gm__ uint32_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint32_2x32_2x8_2x32_2x32(uint32_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint32_2x32_2x8_2x32_2x32<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint8_2x16_2x32_2x16_2x16(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint8_2x16_2x32_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint16_2x16_2x16_2x16_2x16(__gm__ uint16_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint16_2x16_2x16_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint32_2x16_2x8_2x16_2x16(__gm__ uint32_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint32_2x16_2x8_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint32_uint8_2x8_2x32_2x8_2x8(__gm__ uint8_t *mask, __gm__ uint32_t *src, __gm__ uint32_t *dst, uint32_t scalar); +void LaunchTSELS_uint32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream) { + uint32_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint32_t)); + TSELS_uint32_uint8_2x8_2x32_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint32_t *)src, (__gm__ uint32_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint32_uint16_2x8_2x16_2x8_2x8(__gm__ uint16_t *mask, __gm__ uint32_t *src, __gm__ uint32_t *dst, uint32_t scalar); +void LaunchTSELS_uint32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream) { + uint32_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint32_t)); + TSELS_uint32_uint16_2x8_2x16_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint32_t *)src, (__gm__ uint32_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint32_uint32_2x8_2x8_2x8_2x8(__gm__ uint32_t *mask, __gm__ uint32_t *src, __gm__ uint32_t *dst, uint32_t scalar); +void LaunchTSELS_uint32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream) { + uint32_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint32_t)); + TSELS_uint32_uint32_2x8_2x8_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint32_t *)src, (__gm__ uint32_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint8_2x16_2x32_2x16_2x16(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint8_2x16_2x32_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint16_2x16_2x16_2x16_2x16(__gm__ uint16_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint16_2x16_2x16_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint32_2x16_2x8_2x16_2x16(__gm__ uint32_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint32_2x16_2x8_2x16_2x16<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_2x8_2x32_2x8_2x8(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_2x8_2x32_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint16_2x8_2x16_2x8_2x8(__gm__ uint16_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint16_2x8_2x16_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint16_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint32_2x8_2x8_2x8_2x8(__gm__ uint32_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint32_2x8_2x8_2x8_2x8<<<1, nullptr, stream>>>((__gm__ uint32_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint8_2x32_2x64_2x128_2x31(__gm__ uint8_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint8_2x32_2x64_2x128_2x31<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint16_uint8_2x32_2x64_2x128_2x31(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_uint16_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_uint16_uint8_2x32_2x64_2x128_2x31<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_2x32_2x64_2x128_2x31(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_2x32_2x64_2x128_2x31<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_uint8_uint8_32x672_32x96_32x672_32x666(__gm__ uint8_t *mask, __gm__ uint8_t *src, __gm__ uint8_t *dst, uint8_t scalar); +void LaunchTSELS_uint8_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream) { + uint8_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint8_t)); + TSELS_uint8_uint8_32x672_32x96_32x672_32x666<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f16_uint8_32x672_32x96_32x672_32x666(__gm__ uint8_t *mask, __gm__ uint16_t *src, __gm__ uint16_t *dst, uint16_t scalar); +void LaunchTSELS_f16_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream) { + uint16_t scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(uint16_t)); + TSELS_f16_uint8_32x672_32x96_32x672_32x666<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ uint16_t *)src, (__gm__ uint16_t *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_32x672_32x96_32x672_32x666(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_32x672_32x96_32x672_32x666<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} + +extern "C" __global__ AICORE void TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(__gm__ uint8_t *mask, __gm__ float *src, __gm__ float *dst, float scalar); +void LaunchTSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream) { + float scalar; + std::memcpy(&scalar, scalar_ptr, sizeof(float)); + TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192<<<1, nullptr, stream>>>((__gm__ uint8_t *)mask, (__gm__ float *)src, (__gm__ float *)dst, scalar); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsels/main.cpp new file mode 100644 index 000000000..351e822cc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/main.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +void LaunchTSELS_uint8_uint8_2x32_2x32_2x32_2x32(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint16_2x32_2x16_2x32_2x32(uint16_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint32_2x32_2x8_2x32_2x32(uint32_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, uint32_t *src, uint32_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint8_2x16_2x32_2x16_2x16(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint16_2x16_2x16_2x16_2x16(uint16_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint32_2x16_2x8_2x16_2x16(uint32_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_2x8_2x32_2x8_2x8(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint16_2x8_2x16_2x8_2x8(uint16_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint32_2x8_2x8_2x8_2x8(uint32_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint16_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_2x32_2x64_2x128_2x31(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_uint8_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint8_t *src, uint8_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f16_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, uint16_t *src, uint16_t *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_32x672_32x96_32x672_32x666(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); +void LaunchTSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(uint8_t *mask, float *src, float *dst, void *scalar_ptr, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void*, void*, void*, void*, void*); + size_t dstRows, dstCols; + size_t maskRows, maskCols; + size_t srcRows, srcCols; + size_t validRows, validCols; + size_t dstElemSize; + size_t maskElemSize; + size_t srcElemSize; +}; + +static const TestCase kCases[] = { + {"uint8_uint8_2x32_2x32_2x32_2x32", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint8_2x32_2x32_2x32_2x32, 2, 32, 2, 32, 2, 32, 2, 32, 1, 1, 1}, + {"uint8_uint16_2x32_2x16_2x32_2x32", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint16_2x32_2x16_2x32_2x32, 2, 32, 2, 16, 2, 32, 2, 32, 1, 2, 1}, + {"uint8_uint32_2x32_2x8_2x32_2x32", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint32_2x32_2x8_2x32_2x32, 2, 32, 2, 8, 2, 32, 2, 32, 1, 4, 1}, + {"uint16_uint8_2x16_2x32_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint8_2x16_2x32_2x16_2x16, 2, 16, 2, 32, 2, 16, 2, 16, 2, 1, 2}, + {"uint16_uint16_2x16_2x16_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint16_2x16_2x16_2x16_2x16, 2, 16, 2, 16, 2, 16, 2, 16, 2, 2, 2}, + {"uint16_uint32_2x16_2x8_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint32_2x16_2x8_2x16_2x16, 2, 16, 2, 8, 2, 16, 2, 16, 2, 4, 2}, + {"uint32_uint8_2x8_2x32_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint32_uint8_2x8_2x32_2x8_2x8, 2, 8, 2, 32, 2, 8, 2, 8, 4, 1, 4}, + {"uint32_uint16_2x8_2x16_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint32_uint16_2x8_2x16_2x8_2x8, 2, 8, 2, 16, 2, 8, 2, 8, 4, 2, 4}, + {"uint32_uint32_2x8_2x8_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint32_uint32_2x8_2x8_2x8_2x8, 2, 8, 2, 8, 2, 8, 2, 8, 4, 4, 4}, + {"f16_uint8_2x16_2x32_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint8_2x16_2x32_2x16_2x16, 2, 16, 2, 32, 2, 16, 2, 16, 2, 1, 2}, + {"f16_uint16_2x16_2x16_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint16_2x16_2x16_2x16_2x16, 2, 16, 2, 16, 2, 16, 2, 16, 2, 2, 2}, + {"f16_uint32_2x16_2x8_2x16_2x16", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint32_2x16_2x8_2x16_2x16, 2, 16, 2, 8, 2, 16, 2, 16, 2, 4, 2}, + {"f32_uint8_2x8_2x32_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_2x8_2x32_2x8_2x8, 2, 8, 2, 32, 2, 8, 2, 8, 4, 1, 4}, + {"f32_uint16_2x8_2x16_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint16_2x8_2x16_2x8_2x8, 2, 8, 2, 16, 2, 8, 2, 8, 4, 2, 4}, + {"f32_uint32_2x8_2x8_2x8_2x8", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint32_2x8_2x8_2x8_2x8, 2, 8, 2, 8, 2, 8, 2, 8, 4, 4, 4}, + {"uint8_uint8_2x32_2x64_2x128_2x31", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint8_2x32_2x64_2x128_2x31, 2, 32, 2, 64, 2, 128, 2, 31, 1, 1, 1}, + {"uint16_uint8_2x32_2x64_2x128_2x31", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint16_uint8_2x32_2x64_2x128_2x31, 2, 32, 2, 64, 2, 128, 2, 31, 2, 1, 2}, + {"f32_uint8_2x32_2x64_2x128_2x31", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_2x32_2x64_2x128_2x31, 2, 32, 2, 64, 2, 128, 2, 31, 4, 1, 4}, + {"uint8_uint8_32x672_32x96_32x672_32x666", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_uint8_uint8_32x672_32x96_32x672_32x666, 32, 672, 32, 96, 32, 672, 32, 666, 1, 1, 1}, + {"f16_uint8_32x672_32x96_32x672_32x666", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f16_uint8_32x672_32x96_32x672_32x666, 32, 672, 32, 96, 32, 672, 32, 666, 2, 1, 2}, + {"f32_uint8_32x672_32x96_32x672_32x666", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_32x672_32x96_32x672_32x666, 32, 672, 32, 96, 32, 672, 32, 666, 4, 1, 4}, + {"f32_uint8_1x8192_1x4096_1x8192_1x8192", (void(*)(void*,void*,void*,void*,void*))LaunchTSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192, 1, 8192, 1, 4096, 1, 8192, 1, 8192, 4, 1, 4}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.dstElemSize; + size_t maskFileSize = tc.maskRows * tc.maskCols * tc.maskElemSize; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.srcElemSize; + size_t scalarFileSize = tc.dstElemSize; + + std::printf("[INFO] === case: %s (dst=%zux%zu, mask=%zux%zu, src=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.dstRows, tc.dstCols, tc.maskRows, tc.maskCols, tc.srcRows, tc.srcCols, tc.validRows, tc.validCols); + + std::string caseDir = std::string("./") + tc.name; + const size_t maskFileSizeBuf = maskFileSize; + const size_t srcFileSizeBuf = srcFileSize; + const size_t scalarFileSizeBuf = scalarFileSize; + + void *maskHost = nullptr, *srcHost = nullptr, *dstHost = nullptr, *scalarHost = nullptr; + void *maskDevice = nullptr, *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&maskHost, maskFileSize); + aclrtMallocHost(&srcHost, srcFileSize); + aclrtMallocHost(&dstHost, dstFileSize); + aclrtMallocHost(&scalarHost, scalarFileSize); + + aclrtMalloc(&maskDevice, maskFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + memset(dstHost, 0, dstFileSize); + + if (!ReadFile(caseDir + "/mask.bin", maskFileSize, maskHost, maskFileSizeBuf)) { + std::fprintf(stderr, "[ERROR] failed to read %s/mask.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile(caseDir + "/input1.bin", srcFileSize, srcHost, srcFileSizeBuf)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile(caseDir + "/input2.bin", scalarFileSize, scalarHost, scalarFileSizeBuf)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(maskDevice, maskFileSize, maskHost, maskFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(maskDevice, srcDevice, dstDevice, scalarHost, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (maskDevice != nullptr) aclrtFree(maskDevice); + if (srcDevice != nullptr) aclrtFree(srcDevice); + if (dstDevice != nullptr) aclrtFree(dstDevice); + if (maskHost != nullptr) aclrtFreeHost(maskHost); + if (srcHost != nullptr) aclrtFreeHost(srcHost); + if (dstHost != nullptr) aclrtFreeHost(dstHost); + if (scalarHost != nullptr) aclrtFreeHost(scalarHost); + + if (rc == 0) std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto new file mode 100644 index 000000000..5ffbb6ce6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto @@ -0,0 +1,654 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsels: tload(mask) + tload(src) + tsels(mask,src,tmp,scalar)->dst + tstore(dst) +// 22 cases from pto-isa tests. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @TSELS_uint8_uint8_2x32_2x32_2x32_2x32(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + func.func @TSELS_uint8_uint16_2x32_2x16_2x32_2x32(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + func.func @TSELS_uint8_uint32_2x32_2x8_2x32_2x32(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) + return + } + + func.func @TSELS_uint16_uint8_2x16_2x32_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_uint16_uint16_2x16_2x16_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_uint16_uint32_2x16_2x8_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_uint32_uint8_2x8_2x32_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%src_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) + return + } + + func.func @TSELS_uint32_uint16_2x8_2x16_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%src_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) + return + } + + func.func @TSELS_uint32_uint32_2x8_2x8_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%src_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) + return + } + + func.func @TSELS_f16_uint8_2x16_2x32_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_f16_uint16_2x16_2x16_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_f16_uint32_2x16_2x8_2x16_2x16(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) + return + } + + func.func @TSELS_f32_uint8_2x8_2x32_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c32] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x32xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x32xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) + return + } + + func.func @TSELS_f32_uint16_2x8_2x16_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c16], strides = [%c32, %c32, %c32, %c16, %c1] : !pto.tensor_view<1x1x1x2x16xi16> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c16] : !pto.tensor_view<1x1x1x2x16xi16> -> !pto.partition_tensor_view<1x1x1x2x16xi16> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x16xi16>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) + return + } + + func.func @TSELS_f32_uint32_2x8_2x8_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xi32> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c8], strides = [%c16, %c16, %c16, %c8, %c1] : !pto.tensor_view<1x1x1x2x8xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xi32> -> !pto.partition_tensor_view<1x1x1x2x8xi32> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c8] : !pto.tensor_view<1x1x1x2x8xf32> -> !pto.partition_tensor_view<1x1x1x2x8xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x8xi32>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) + return + } + + func.func @TSELS_uint8_uint8_2x32_2x64_2x128_2x31(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x2x64xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c128], strides = [%c128, %c128, %c128, %c128, %c1] : !pto.tensor_view<1x1x1x2x128xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c32, %c32, %c32, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c64] : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c128] : !pto.tensor_view<1x1x1x2x128xi8> -> !pto.partition_tensor_view<1x1x1x2x128xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c31] : !pto.tensor_view<1x1x1x2x32xi8> -> !pto.partition_tensor_view<1x1x1x2x31xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x31xi8>) + return + } + + func.func @TSELS_uint16_uint8_2x32_2x64_2x128_2x31(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x2x64xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c128], strides = [%c256, %c256, %c256, %c128, %c1] : !pto.tensor_view<1x1x1x2x128xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c64, %c64, %c64, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c64] : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c128] : !pto.tensor_view<1x1x1x2x128xi16> -> !pto.partition_tensor_view<1x1x1x2x128xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c31] : !pto.tensor_view<1x1x1x2x32xi16> -> !pto.partition_tensor_view<1x1x1x2x31xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x31xi16>) + return + } + + func.func @TSELS_f32_uint8_2x32_2x64_2x128_2x31(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c31 = arith.constant 31 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c2, %c64], strides = [%c64, %c64, %c64, %c64, %c1] : !pto.tensor_view<1x1x1x2x64xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c2, %c128], strides = [%c512, %c512, %c512, %c128, %c1] : !pto.tensor_view<1x1x1x2x128xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c2, %c32], strides = [%c128, %c128, %c128, %c32, %c1] : !pto.tensor_view<1x1x1x2x32xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c64] : !pto.tensor_view<1x1x1x2x64xi8> -> !pto.partition_tensor_view<1x1x1x2x64xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c128] : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c2, %c31] : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x31xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x2x64xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) + return + } + + func.func @TSELS_uint8_uint8_32x672_32x96_32x672_32x666(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c96 = arith.constant 96 : index + %c672 = arith.constant 672 : index + %c666 = arith.constant 666 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c32, %c96], strides = [%c96, %c96, %c96, %c96, %c1] : !pto.tensor_view<1x1x1x32x96xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c672, %c672, %c672, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi8> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c672, %c672, %c672, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi8> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c96] : !pto.tensor_view<1x1x1x32x96xi8> -> !pto.partition_tensor_view<1x1x1x32x96xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c672] : !pto.tensor_view<1x1x1x32x672xi8> -> !pto.partition_tensor_view<1x1x1x32x672xi8> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c666] : !pto.tensor_view<1x1x1x32x672xi8> -> !pto.partition_tensor_view<1x1x1x32x666xi8> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x32x96xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x672xi8>) outs(%src_tile : !pto.tile_buf) + %scalar_i8 = arith.trunci %scalar : i32 to i8 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i8) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x666xi8>) + return + } + + func.func @TSELS_f16_uint8_32x672_32x96_32x672_32x666(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c96 = arith.constant 96 : index + %c672 = arith.constant 672 : index + %c666 = arith.constant 666 : index + %c1344 = arith.constant 1344 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c32, %c96], strides = [%c96, %c96, %c96, %c96, %c1] : !pto.tensor_view<1x1x1x32x96xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c1344, %c1344, %c1344, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi16> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c1344, %c1344, %c1344, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xi16> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c96] : !pto.tensor_view<1x1x1x32x96xi8> -> !pto.partition_tensor_view<1x1x1x32x96xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c672] : !pto.tensor_view<1x1x1x32x672xi16> -> !pto.partition_tensor_view<1x1x1x32x672xi16> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c666] : !pto.tensor_view<1x1x1x32x672xi16> -> !pto.partition_tensor_view<1x1x1x32x666xi16> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x32x96xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x672xi16>) outs(%src_tile : !pto.tile_buf) + %scalar_i16 = arith.trunci %scalar : i32 to i16 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_i16 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, i16) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x666xi16>) + return + } + + func.func @TSELS_f32_uint8_32x672_32x96_32x672_32x666(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c96 = arith.constant 96 : index + %c672 = arith.constant 672 : index + %c666 = arith.constant 666 : index + %c2688 = arith.constant 2688 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c32, %c96], strides = [%c96, %c96, %c96, %c96, %c1] : !pto.tensor_view<1x1x1x32x96xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c2688, %c2688, %c2688, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c32, %c672], strides = [%c2688, %c2688, %c2688, %c672, %c1] : !pto.tensor_view<1x1x1x32x672xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c96] : !pto.tensor_view<1x1x1x32x96xi8> -> !pto.partition_tensor_view<1x1x1x32x96xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c672] : !pto.tensor_view<1x1x1x32x672xf32> -> !pto.partition_tensor_view<1x1x1x32x672xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c32, %c666] : !pto.tensor_view<1x1x1x32x672xf32> -> !pto.partition_tensor_view<1x1x1x32x666xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x32x96xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x672xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x666xf32>) + return + } + + func.func @TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + + %mask_view = pto.make_tensor_view %mask_ptr, shape = [%c1, %c1, %c1, %c1, %c4096], strides = [%c4096, %c4096, %c4096, %c4096, %c1] : !pto.tensor_view<1x1x1x1x4096xi8> + %src_view = pto.make_tensor_view %src_ptr, shape = [%c1, %c1, %c1, %c1, %c8192], strides = [%c8192, %c8192, %c8192, %c8192, %c1] : !pto.tensor_view<1x1x1x1x8192xf32> + %dst_view = pto.make_tensor_view %dst_ptr, shape = [%c1, %c1, %c1, %c1, %c8192], strides = [%c8192, %c8192, %c8192, %c8192, %c1] : !pto.tensor_view<1x1x1x1x8192xf32> + + %mask_part = pto.partition_view %mask_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c4096] : !pto.tensor_view<1x1x1x1x4096xi8> -> !pto.partition_tensor_view<1x1x1x1x4096xi8> + %src_part = pto.partition_view %src_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c8192] : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0, %c0, %c0, %c0, %c0], sizes = [%c1, %c1, %c1, %c1, %c8192] : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + + %mask_tile = pto.alloc_tile : !pto.tile_buf + %src_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%mask_part : !pto.partition_tensor_view<1x1x1x1x4096xi8>) outs(%mask_tile : !pto.tile_buf) + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) outs(%src_tile : !pto.tile_buf) + %scalar_f32 = arith.bitcast %scalar : i32 to f32 + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshl/CMakeLists.txt new file mode 100644 index 000000000..42b3fa0bd --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshl) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py new file mode 100644 index 000000000..4bc308400 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tshl ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py new file mode 100644 index 000000000..811ffc995 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 8, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] << input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshl/launch.cpp new file mode 100644 index 000000000..d58d324a4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TSHL_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHL_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHL_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TSHL_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHL_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHL_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshl/main.cpp new file mode 100644 index 000000000..cb35c3a31 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshl ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHL_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTSHL_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTSHL_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTSHL_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshl [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto new file mode 100644 index 000000000..c002895d9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshl/tshl.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshl: tload(a) + tload(b) + tshl(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: i32 16x64 (1024 elements) + func.func @TSHL_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tshl ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TSHL_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tshl ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshls/CMakeLists.txt new file mode 100644 index 000000000..ae8289e40 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshls) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshls/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshls/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshls/gen_data.py new file mode 100644 index 000000000..9b4624bfc --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/gen_data.py @@ -0,0 +1,34 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for left shift (must match launch.cpp) +SCALAR = 2 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] << SCALAR).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshls/launch.cpp new file mode 100644 index 000000000..5e7343071 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for left shift (must match gen_data.py SCALAR) +static constexpr int16_t TSHLS_SCALAR = 2; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TSHLS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHLS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TSHLS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHLS_SCALAR); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TSHLS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHLS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TSHLS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHLS_SCALAR); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TSHLS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHLS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TSHLS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHLS_SCALAR); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TSHLS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHLS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TSHLS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHLS_SCALAR); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshls/main.cpp new file mode 100644 index 000000000..ca0e4d73e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshls ST — case-table driven. +// tshls: dst = src << scalar (single input + scalar, left shift). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHLS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHLS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTSHLS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHLS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTSHLS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTSHLS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTSHLS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTSHLS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshls [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshls/tshls.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshls/tshls.pto new file mode 100644 index 000000000..6bd6e4fd3 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshls/tshls.pto @@ -0,0 +1,176 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshls: tload(src) + tshls(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: i32 32x64 (2048 elements) + func.func @TSHLS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TSHLS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TSHLS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TSHLS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tshls ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshr/CMakeLists.txt new file mode 100644 index 000000000..ed8592e59 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshr) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py new file mode 100644 index 000000000..36075525b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tshr ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py new file mode 100644 index 000000000..5737627f7 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 8, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] >> input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshr/launch.cpp new file mode 100644 index 000000000..1d4f9cd1d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TSHR_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHR_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TSHR_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTSHR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TSHR_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshr/main.cpp new file mode 100644 index 000000000..e99634eae --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshr ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTSHR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTSHR_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTSHR_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshr [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto new file mode 100644 index 000000000..ab6a9f36f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshr/tshr.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshr: tload(a) + tload(b) + tshr(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: i32 16x64 (1024 elements) + func.func @TSHR_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.tshr ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TSHR_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.tshr ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/CMakeLists.txt new file mode 100644 index 000000000..c8e37c793 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tshrs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/cases.py new file mode 100644 index 000000000..18cc99178 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/cases.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/gen_data.py new file mode 100644 index 000000000..6f269f96a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/gen_data.py @@ -0,0 +1,34 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for right shift (must match launch.cpp) +SCALAR = 2 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] >> SCALAR).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/launch.cpp new file mode 100644 index 000000000..e80c03e16 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value for right shift (must match gen_data.py SCALAR) +static constexpr int16_t TSHRS_SCALAR = 2; + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TSHRS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHRS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TSHRS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHRS_SCALAR); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TSHRS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHRS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TSHRS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHRS_SCALAR); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TSHRS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int16_t scalar); + +void LaunchTSHRS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TSHRS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, TSHRS_SCALAR); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TSHRS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSHRS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TSHRS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, TSHRS_SCALAR); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/main.cpp new file mode 100644 index 000000000..3afd710f9 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tshrs ST — case-table driven. +// tshrs: dst = src >> scalar (single input + scalar, right shift). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSHRS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHRS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTSHRS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTSHRS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTSHRS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTSHRS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTSHRS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTSHRS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tshrs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tshrs/tshrs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/tshrs.pto new file mode 100644 index 000000000..94fe2c4f4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tshrs/tshrs.pto @@ -0,0 +1,176 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tshrs: tload(src) + tshrs(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: i32 32x64 (2048 elements) + func.func @TSHRS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TSHRS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TSHRS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TSHRS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tshrs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsort32/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/CMakeLists.txt new file mode 100644 index 000000000..5feb4848a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsort32) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsort32/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/cases.py new file mode 100644 index 000000000..d8062ff18 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/cases.py @@ -0,0 +1,199 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsort32 ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - src_shape: (rows, cols) — allocated source tile dimensions. + - idx_shape: (rows, cols) — allocated index tile dimensions (can be 1 x cols for shared idx). + - tmp_shape: (rows, cols) — allocated tmp tile dimensions (optional, only for unaligned cases). + None for aligned cases (valid_cols % 32 == 0). + For unaligned cases: tmp_rows = 1, tmp_cols = ceil(valid_cols, 32). + - dst_shape: (rows, cols) — allocated destination tile dimensions. + For f32: dst_cols = src_cols * 4 (buffer allocation, but valid region is src_cols * 2). + For f16: dst_cols = src_cols * 2. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + For aligned cases: valid_cols must be multiple of 32 (BLOCK_SIZE). + For unaligned cases: valid_cols can be any value (requires tmp). + - idx_vshape: (idx_valid_rows, idx_valid_cols) — idx valid region. + If idx_valid_rows == 1, same idx is used for all rows. + - dst_vshape: (dst_valid_rows, dst_valid_cols) — dst valid region. + For f32: dst_vcols = src_vcols * 2 (stride coef = 2, interleaved value+index). + - eps: tolerance for numpy.allclose (atol and rtol). + +tsort32 semantics: + - Sorts data in 32-element blocks using vbitsort. + - Output format: interleaved (sorted_value, original_index) pairs with stride coef = 2. + - For each 32-element block, the output contains sorted values and their original indices. + - Each pair occupies 2 element positions: [value0, idx0, value1, idx1, ...] + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + # f32 cases - basic shapes (aligned, no tmp needed) + { + "name": "f32_1x32", + "dtype": np.float32, + "src_shape": (1, 32), + "idx_shape": (1, 32), + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (1, 128), # buffer allocation (src_cols * 4) + "valid_shape": (1, 32), + "idx_vshape": (1, 32), + "dst_vshape": (1, 64), # actual valid output: src_cols * stride_coef = 32 * 2 + "eps": 1e-6, + }, + { + "name": "f32_1x64", + "dtype": np.float32, + "src_shape": (1, 64), + "idx_shape": (1, 64), + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (1, 256), # buffer allocation (src_cols * 4) + "valid_shape": (1, 64), + "idx_vshape": (1, 64), + "dst_vshape": (1, 128), # actual valid output: src_cols * stride_coef = 64 * 2 + "eps": 1e-6, + }, + # f32 cases - multiple rows (aligned, no tmp needed) + { + "name": "f32_2x32", + "dtype": np.float32, + "src_shape": (2, 32), + "idx_shape": (2, 32), + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (2, 128), # buffer allocation (src_cols * 4) + "valid_shape": (2, 32), + "idx_vshape": (2, 32), + "dst_vshape": (2, 64), # actual valid output: src_cols * stride_coef = 32 * 2 + "eps": 1e-6, + }, + { + "name": "f32_16x32", + "dtype": np.float32, + "src_shape": (16, 32), + "idx_shape": (16, 32), + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (16, 128), # buffer allocation (src_cols * 4) + "valid_shape": (16, 32), + "idx_vshape": (16, 32), + "dst_vshape": (16, 64), # actual valid output: src_cols * stride_coef = 32 * 2 + "eps": 1e-6, + }, + # f32 cases - shared idx (aligned, no tmp needed) + { + "name": "f32_2x64_shared_idx", + "dtype": np.float32, + "src_shape": (2, 64), + "idx_shape": (1, 64), # shared idx for all rows + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (2, 256), # buffer allocation (src_cols * 4) + "valid_shape": (2, 64), + "idx_vshape": (1, 64), # idx_valid_rows = 1 means shared idx + "dst_vshape": (2, 128), # actual valid output: src_cols * stride_coef = 64 * 2 + "eps": 1e-6, + }, + { + "name": "f32_16x64_shared_idx", + "dtype": np.float32, + "src_shape": (16, 64), + "idx_shape": (1, 64), # shared idx for all rows + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (16, 256), # buffer allocation (src_cols * 4) + "valid_shape": (16, 64), + "idx_vshape": (1, 64), # idx_valid_rows = 1 means shared idx + "dst_vshape": (16, 128), # actual valid output: src_cols * stride_coef = 64 * 2 + "eps": 1e-6, + }, + # f32 cases - large shape (multiple vbitsort calls, aligned, no tmp needed) + { + "name": "f32_1x8192", + "dtype": np.float32, + "src_shape": (1, 8192), # 256 * 32, requires loop_num > 1 + "idx_shape": (1, 8192), + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (1, 32768), # buffer allocation (src_cols * 4) + "valid_shape": (1, 8192), + "idx_vshape": (1, 8192), + "dst_vshape": (1, 16384), # actual valid output: src_cols * stride_coef = 8192 * 2 + "eps": 1e-6, + }, + # f32 cases - non-32-aligned (requires tmp buffer for padding) + # Case 4 from C++: VALID_C=13, requires padding to 32-element block + { + "name": "f32_2x13", + "dtype": np.float32, + "src_shape": (2, 16), # ALIGN_C = ceil(13*4, 32) / 4 = 16 + "idx_shape": (2, 16), + "tmp_shape": (1, 16), # unaligned: tmp_cols = ceil(13, 32) = 16 + "dst_shape": (2, 64), # 4 * ALIGN_C = 64 + "valid_shape": (2, 13), # non-32-aligned + "idx_vshape": (2, 13), + "dst_vshape": (2, 26), # VALID_C * stride_coef = 13 * 2 + "eps": 1e-6, + }, + # Case 5 from C++: VALID_C=4164, large non-aligned shape + { + "name": "f32_1x4164", + "dtype": np.float32, + "src_shape": (1, 8192), # ALIGN_C = 8192 (from C++ hardcoded) + "idx_shape": (1, 8192), + "tmp_shape": (1, 4168), # unaligned: tmp_cols = ceil(4164, 32) = 4168 + "dst_shape": (1, 32768), # 4 * ALIGN_C = 32768 + "valid_shape": (1, 4164), # non-32-aligned + "idx_vshape": (1, 4164), + "dst_vshape": (1, 8328), # VALID_C * stride_coef = 4164 * 2 + "eps": 1e-6, + }, + # Case 6 from C++: VALID_C=2084, multi-row non-aligned shape + { + "name": "f32_2x2084", + "dtype": np.float32, + "src_shape": (2, 3072), # ALIGN_C = 3072 (from C++ hardcoded) + "idx_shape": (2, 3072), + "tmp_shape": (1, 2088), # unaligned: tmp_cols = ceil(2084, 32) = 2088 + "dst_shape": (2, 12288), # 4 * ALIGN_C = 12288 + "valid_shape": (2, 2084), # non-32-aligned + "idx_vshape": (2, 2084), + "dst_vshape": (2, 4168), # VALID_C * stride_coef = 2084 * 2 + "eps": 1e-6, + }, + # f16 cases - basic shapes (aligned, no tmp needed) + { + "name": "f16_1x32", + "dtype": np.float16, + "src_shape": (1, 32), + "idx_shape": (1, 32), + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (1, 128), # buffer allocation (src_cols * 4 for f16) + "valid_shape": (1, 32), + "idx_vshape": (1, 32), + "dst_vshape": (1, 128), # actual valid output: src_cols * stride_coef = 32 * 4 + "eps": 1e-3, + }, + { + "name": "f16_4x64", + "dtype": np.float16, + "src_shape": (4, 64), + "idx_shape": (4, 64), + "tmp_shape": None, # aligned: valid_cols % 32 == 0, no tmp + "dst_shape": (4, 256), # buffer allocation (src_cols * 4 for f16) + "valid_shape": (4, 64), + "idx_vshape": (4, 64), + "dst_vshape": (4, 256), # actual valid output: src_cols * stride_coef = 64 * 4 + "eps": 1e-3, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsort32/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/compare.py new file mode 100644 index 000000000..e4184b9f4 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import result_cmp, style_fail, style_pass + +from cases import CASES + + +def main(): + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + dtype = case["dtype"] + src_shape = case["src_shape"] + dst_shape = case["dst_shape"] + dst_vr, dst_vc = case["dst_vshape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=dtype).reshape(dst_shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=dtype).reshape(dst_shape) + + # Compare only the dst valid region + ok = result_cmp(golden[:dst_vr, :dst_vc], output[:dst_vr, :dst_vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsort32/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/gen_data.py new file mode 100644 index 000000000..415cbc40b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/gen_data.py @@ -0,0 +1,140 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +import os +import sys + +# Add parent directory to path for st_common import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from st_common import setup_case_rng, save_case_data + +from cases import CASES + +BLOCK_SIZE = 32 +FLOAT_DST_STRIDE_COEF = 2 # for f32 +HALF_DST_STRIDE_COEF = 4 # for f16 + + +def _to_tuple(shape): + """Convert shape to tuple if needed.""" + if isinstance(shape, tuple): + return shape + return tuple(shape) + + +def get_stride_coef(dtype): + """Get stride coefficient based on dtype.""" + if dtype == np.float16: + return HALF_DST_STRIDE_COEF + return FLOAT_DST_STRIDE_COEF + + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + src_shape = _to_tuple(case["src_shape"]) + idx_shape = _to_tuple(case["idx_shape"]) + dst_shape = _to_tuple(case["dst_shape"]) + src_valid = _to_tuple(case["valid_shape"]) + idx_valid = _to_tuple(case["idx_vshape"]) + + src_rows, src_cols = src_shape + src_vr, src_vc = src_valid + idx_vr, idx_vc = idx_valid + + # Generate random input data + input_data = np.random.randint(1, 100, size=src_shape).astype(dtype) + + # Generate index data (0, 1, 2, ... for each row) + # If idx_valid_rows == 1, same index is used for all rows + if idx_vr == 1: + idx_data = np.arange(src_cols, dtype=np.int32).reshape(1, src_cols) + else: + idx_data = np.arange(src_cols, dtype=np.int32).reshape(1, src_cols) + idx_data = np.tile(idx_data, (src_rows, 1)) + + # Compute golden: for each 32-element block, sort and output interleaved (value, index) + # Output stride coef depends on dtype: + # - f32 uses stride_coef=2 (value+index pair occupies 2 f32 elements) + # - f16 uses stride_coef=4 (value occupies 1 f16, index stored as ui32 = 4 f16 positions) + stride_coef = get_stride_coef(dtype) + golden = np.zeros(dst_shape, dtype=dtype) + + for row in range(src_vr): + for block_start in range(0, src_vc, BLOCK_SIZE): + block_end = min(block_start + BLOCK_SIZE, src_vc) + block_size = block_end - block_start + + block_data = input_data[row, block_start:block_end].copy() + block_idx = idx_data[0 if idx_vr == 1 else row, block_start:block_end].astype(np.int32) + + # For partial blocks, pad with NaN (negative NaN = max value) to make 32 elements + if block_size < BLOCK_SIZE: + # Use the same padding value as in tsort32_template.py + # f16: 0x7C00 (+inf), bf16: 0x7FC0, f32: 0x7FC00000 (negative NaN) + if dtype == np.float16: + pad_val = np.float16(0xFC00) # +inf for f16 + elif hasattr(np, 'bfloat16') and dtype == np.bfloat16: + pad_val = np.bfloat16(0xFF80) + else: + pad_val = np.float32(0xFF800000) # negative NaN for f32 + + # Pad block to 32 elements with +inf (will be sorted to end) + padded_data = np.full(BLOCK_SIZE, pad_val, dtype=dtype) + padded_data[:block_size] = block_data + + # Pad indices to 32 elements (indices for padding elements don't matter) + padded_idx = np.zeros(BLOCK_SIZE, dtype=np.int32) + padded_idx[:block_size] = block_idx + + # Sort the padded 32-element block in descending order + # +inf values will be at the end after sorting + sorted_indices = np.argsort(-padded_data) + sorted_values = padded_data[sorted_indices] + sorted_original_idx = padded_idx[sorted_indices] + + # Output interleaved (value, index) pairs for the full 32-element block + # but only the first block_size elements are valid (padding elements at the end) + dst_offset = block_start * stride_coef + for i in range(BLOCK_SIZE): + golden[row, dst_offset + i * stride_coef] = sorted_values[i] + # Store index as int32 bit pattern + idx_u32 = np.array(sorted_original_idx[i], dtype=np.uint32) + if dtype == np.float16: + idx_bytes = idx_u32.tobytes() + golden[row, dst_offset + i * stride_coef + 1] = np.frombuffer(idx_bytes[:2], dtype=np.float16)[0] + golden[row, dst_offset + i * stride_coef + 2] = np.frombuffer(idx_bytes[2:], dtype=np.float16)[0] + else: + golden[row, dst_offset + i * stride_coef + 1] = idx_u32.view(np.float32) + else: + # Full 32-element block + # Sort by value in descending order (largest to smallest) + sorted_indices = np.argsort(-block_data) + sorted_values = block_data[sorted_indices] + sorted_original_idx = block_idx[sorted_indices] + + # Output interleaved (value, index) pairs with stride_coef + dst_offset = block_start * stride_coef + for i in range(BLOCK_SIZE): + golden[row, dst_offset + i * stride_coef] = sorted_values[i] + # Store index as int32 bit pattern + idx_u32 = np.array(sorted_original_idx[i], dtype=np.uint32) + if dtype == np.float16: + idx_bytes = idx_u32.tobytes() + golden[row, dst_offset + i * stride_coef + 1] = np.frombuffer(idx_bytes[:2], dtype=np.float16)[0] + golden[row, dst_offset + i * stride_coef + 2] = np.frombuffer(idx_bytes[2:], dtype=np.float16)[0] + else: + golden[row, dst_offset + i * stride_coef + 1] = idx_u32.view(np.float32) + + save_case_data(case["name"], {"input": input_data, "idx": idx_data.astype(np.uint32), "golden": golden}) + print(f"[INFO] gen_data: {case['name']} src_shape={src_shape} idx_shape={idx_shape} dst_shape={dst_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsort32/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/launch.cpp new file mode 100644 index 000000000..f1e212c10 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/launch.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case: f32 1x32 +extern "C" __global__ AICORE void TSORT32_f32_1x32(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_1x32(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_1x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 1x64 +extern "C" __global__ AICORE void TSORT32_f32_1x64(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_1x64(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_1x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 2x32 +extern "C" __global__ AICORE void TSORT32_f32_2x32(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_2x32(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_2x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 16x32 +extern "C" __global__ AICORE void TSORT32_f32_16x32(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_16x32(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_16x32<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 2x64 shared_idx +extern "C" __global__ AICORE void TSORT32_f32_2x64_shared_idx(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_2x64_shared_idx(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_2x64_shared_idx<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 16x64 shared_idx +extern "C" __global__ AICORE void TSORT32_f32_16x64_shared_idx(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_16x64_shared_idx(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_16x64_shared_idx<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 1x8192 (large shape) +extern "C" __global__ AICORE void TSORT32_f32_1x8192(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_1x8192(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_1x8192<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f16 1x32 +extern "C" __global__ AICORE void TSORT32_f16_1x32(__gm__ uint16_t *src, __gm__ uint32_t *idx, __gm__ uint16_t *dst); + +void LaunchTSORT32_f16_1x32(uint16_t *src, uint32_t *idx, uint16_t *dst, void *stream) { + TSORT32_f16_1x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)idx, (__gm__ uint16_t *)dst); +} + +// Case: f16 4x64 +extern "C" __global__ AICORE void TSORT32_f16_4x64(__gm__ uint16_t *src, __gm__ uint32_t *idx, __gm__ uint16_t *dst); + +void LaunchTSORT32_f16_4x64(uint16_t *src, uint32_t *idx, uint16_t *dst, void *stream) { + TSORT32_f16_4x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)src, (__gm__ uint32_t *)idx, (__gm__ uint16_t *)dst); +} + +// Case: f32 2x13 (non-32-aligned, requires tmp buffer for padding) +extern "C" __global__ AICORE void TSORT32_f32_2x13(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_2x13(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_2x13<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 1x4164 (non-32-aligned large shape) +extern "C" __global__ AICORE void TSORT32_f32_1x4164(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_1x4164(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_1x4164<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} + +// Case: f32 2x2084 (non-32-aligned multi-row shape) +extern "C" __global__ AICORE void TSORT32_f32_2x2084(__gm__ float *src, __gm__ uint32_t *idx, __gm__ float *dst); + +void LaunchTSORT32_f32_2x2084(float *src, uint32_t *idx, float *dst, void *stream) { + TSORT32_f32_2x2084<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ uint32_t *)idx, (__gm__ float *)dst); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsort32/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/main.cpp new file mode 100644 index 000000000..e03400591 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/main.cpp @@ -0,0 +1,166 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsort32 ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSORT32_f32_1x32(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_1x64(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_2x32(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_16x32(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_2x64_shared_idx(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_16x64_shared_idx(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_1x8192(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f16_1x32(uint16_t *src, uint32_t *idx, uint16_t *dst, void *stream); +void LaunchTSORT32_f16_4x64(uint16_t *src, uint32_t *idx, uint16_t *dst, void *stream); +void LaunchTSORT32_f32_2x13(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_1x4164(float *src, uint32_t *idx, float *dst, void *stream); +void LaunchTSORT32_f32_2x2084(float *src, uint32_t *idx, float *dst, void *stream); + +using LaunchFn = void (*)(void *, uint32_t *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t srcRows; + size_t srcCols; + size_t idxRows; + size_t idxCols; + size_t dstRows; + size_t dstCols; + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_1x32", reinterpret_cast(LaunchTSORT32_f32_1x32), 1, 32, 1, 32, 1, 128, sizeof(float)}, + {"f32_1x64", reinterpret_cast(LaunchTSORT32_f32_1x64), 1, 64, 1, 64, 1, 256, sizeof(float)}, + {"f32_2x32", reinterpret_cast(LaunchTSORT32_f32_2x32), 2, 32, 2, 32, 2, 128, sizeof(float)}, + {"f32_16x32", reinterpret_cast(LaunchTSORT32_f32_16x32), 16, 32, 16, 32, 16, 128, sizeof(float)}, + {"f32_2x64_shared_idx", reinterpret_cast(LaunchTSORT32_f32_2x64_shared_idx), 2, 64, 1, 64, 2, 256, sizeof(float)}, + {"f32_16x64_shared_idx", reinterpret_cast(LaunchTSORT32_f32_16x64_shared_idx), 16, 64, 1, 64, 16, 256, sizeof(float)}, + {"f32_1x8192", reinterpret_cast(LaunchTSORT32_f32_1x8192), 1, 8192, 1, 8192, 1, 32768, sizeof(float)}, + {"f16_1x32", reinterpret_cast(LaunchTSORT32_f16_1x32), 1, 32, 1, 32, 1, 128, sizeof(uint16_t)}, + {"f16_4x64", reinterpret_cast(LaunchTSORT32_f16_4x64), 4, 64, 4, 64, 4, 256, sizeof(uint16_t)}, + {"f32_2x13", reinterpret_cast(LaunchTSORT32_f32_2x13), 2, 16, 2, 16, 2, 64, sizeof(float)}, + {"f32_1x4164", reinterpret_cast(LaunchTSORT32_f32_1x4164), 1, 8192, 1, 8192, 1, 32768, sizeof(float)}, + {"f32_2x2084", reinterpret_cast(LaunchTSORT32_f32_2x2084), 2, 3072, 2, 3072, 2, 12288, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, aclrtStream stream) { + int rc = 0; + size_t srcFileSize = tc.srcRows * tc.srcCols * tc.elemSize; + size_t idxFileSize = tc.idxRows * tc.idxCols * sizeof(uint32_t); + size_t dstFileSize = tc.dstRows * tc.dstCols * tc.elemSize; + + std::printf("[INFO] === case: %s (src=%zux%zu, idx=%zux%zu, dst=%zux%zu) ===\n", + tc.name, tc.srcRows, tc.srcCols, tc.idxRows, tc.idxCols, tc.dstRows, tc.dstCols); + + std::string caseDir = std::string("./") + tc.name; + + void *srcHost = nullptr, *dstHost = nullptr; + uint32_t *idxHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + uint32_t *idxDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), srcFileSize); + aclrtMallocHost((void **)(&idxHost), idxFileSize); + aclrtMallocHost((void **)(&dstHost), dstFileSize); + + aclrtMalloc((void **)&srcDevice, srcFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&idxDevice, idxFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, dstFileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, srcFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/idx.bin").c_str(), idxFileSize, idxHost, idxFileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/idx.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, srcFileSize, srcHost, srcFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(idxDevice, idxFileSize, idxHost, idxFileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, idxDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, dstFileSize, dstDevice, dstFileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, dstFileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (idxDevice != nullptr) + aclrtFree(idxDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (idxHost != nullptr) + aclrtFreeHost(idxHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsort32/tsort32.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/tsort32.pto new file mode 100644 index 000000000..406df69ec --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsort32/tsort32.pto @@ -0,0 +1,663 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You can not use the file except of compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsort32: sort 32-element blocks with interleaved output. +// Multiple cases with different shapes and shared/broadcast index patterns. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 1x32 - single row, one 32-element block + func.func @TSORT32_f32_1x32(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf32> -> !pto.partition_tensor_view<1x1x1x1x32xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xui32> -> !pto.partition_tensor_view<1x1x1x1x32xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf32> -> !pto.partition_tensor_view<1x1x1x1x128xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x32xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x1x32xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf32>) + return + } + + // Case 1: f32 1x64 - single row, two 32-element blocks + func.func @TSORT32_f32_1x64(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c256], + strides = [%c256, %c256, %c256, %c256, %c1] + : !pto.tensor_view<1x1x1x1x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xf32> -> !pto.partition_tensor_view<1x1x1x1x64xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xui32> -> !pto.partition_tensor_view<1x1x1x1x64xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c256] + : !pto.tensor_view<1x1x1x1x256xf32> -> !pto.partition_tensor_view<1x1x1x1x256xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x64xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x1x64xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x256xf32>) + return + } + + // Case 2: f32 16x32 - multiple rows, one 32-element block per row + func.func @TSORT32_f32_16x32(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c16, %c32], + strides = [%c512, %c512, %c512, %c32, %c1] + : !pto.tensor_view<1x1x1x16x32xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c128], + strides = [%c2048, %c2048, %c2048, %c128, %c1] + : !pto.tensor_view<1x1x1x16x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xf32> -> !pto.partition_tensor_view<1x1x1x16x32xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c32] + : !pto.tensor_view<1x1x1x16x32xui32> -> !pto.partition_tensor_view<1x1x1x16x32xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c128] + : !pto.tensor_view<1x1x1x16x128xf32> -> !pto.partition_tensor_view<1x1x1x16x128xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x32xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x16x32xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x128xf32>) + return + } + + // Case 3: f32 16x64 shared_idx - multiple rows with shared index (idx rows=1) + func.func @TSORT32_f32_16x64_shared_idx(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c16, %c256], + strides = [%c4096, %c4096, %c4096, %c256, %c1] + : !pto.tensor_view<1x1x1x16x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xui32> -> !pto.partition_tensor_view<1x1x1x1x64xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c256] + : !pto.tensor_view<1x1x1x16x256xf32> -> !pto.partition_tensor_view<1x1x1x16x256xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x1x64xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x16x256xf32>) + return + } + + // Case: f32 2x32 - 2 rows, one 32-element block per row + func.func @TSORT32_f32_2x32(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c2, %c32], + strides = [%c64, %c64, %c64, %c32, %c1] + : !pto.tensor_view<1x1x1x2x32xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c128], + strides = [%c256, %c256, %c256, %c128, %c1] + : !pto.tensor_view<1x1x1x2x128xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xf32> -> !pto.partition_tensor_view<1x1x1x2x32xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c32] + : !pto.tensor_view<1x1x1x2x32xui32> -> !pto.partition_tensor_view<1x1x1x2x32xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c128] + : !pto.tensor_view<1x1x1x2x128xf32> -> !pto.partition_tensor_view<1x1x1x2x128xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x32xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x2x32xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) + return + } + + // Case: f32 2x64 shared_idx - 2 rows with shared index + func.func @TSORT32_f32_2x64_shared_idx(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c1, %c64], + strides = [%c64, %c64, %c64, %c64, %c1] + : !pto.tensor_view<1x1x1x1x64xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c256], + strides = [%c512, %c512, %c512, %c256, %c1] + : !pto.tensor_view<1x1x1x2x256xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c64] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x64xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c64] + : !pto.tensor_view<1x1x1x1x64xui32> -> !pto.partition_tensor_view<1x1x1x1x64xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c256] + : !pto.tensor_view<1x1x1x2x256xf32> -> !pto.partition_tensor_view<1x1x1x2x256xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x64xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x1x64xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x256xf32>) + return + } + + // Case: f32 1x8192 - large shape (256 * 32), requires multiple vbitsort calls + func.func @TSORT32_f32_1x8192(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8192 = arith.constant 8192 : index + %c32768 = arith.constant 32768 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c32768], + strides = [%c32768, %c32768, %c32768, %c32768, %c1] + : !pto.tensor_view<1x1x1x1x32768xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x8192xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8192] + : !pto.tensor_view<1x1x1x1x8192xui32> -> !pto.partition_tensor_view<1x1x1x1x8192xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32768] + : !pto.tensor_view<1x1x1x1x32768xf32> -> !pto.partition_tensor_view<1x1x1x1x32768xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x1x8192xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x32768xf32>) + return + } + + // Case: f16 1x32 - f16 dtype, single row, one 32-element block + func.func @TSORT32_f16_1x32(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xf16> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c1, %c32], + strides = [%c32, %c32, %c32, %c32, %c1] + : !pto.tensor_view<1x1x1x1x32xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c128], + strides = [%c128, %c128, %c128, %c128, %c1] + : !pto.tensor_view<1x1x1x1x128xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xf16> -> !pto.partition_tensor_view<1x1x1x1x32xf16> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c32] + : !pto.tensor_view<1x1x1x1x32xui32> -> !pto.partition_tensor_view<1x1x1x1x32xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c128] + : !pto.tensor_view<1x1x1x1x128xf16> -> !pto.partition_tensor_view<1x1x1x1x128xf16> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x32xf16>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x1x32xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x128xf16>) + return + } + + // Case: f16 4x64 - f16 dtype, 4 rows, two 32-element blocks per row + func.func @TSORT32_f16_4x64(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c4, %c64], + strides = [%c256, %c256, %c256, %c64, %c1] + : !pto.tensor_view<1x1x1x4x64xf16> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c4, %c64], + strides = [%c256, %c256, %c256, %c64, %c1] + : !pto.tensor_view<1x1x1x4x64xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c4, %c256], + strides = [%c1024, %c1024, %c1024, %c256, %c1] + : !pto.tensor_view<1x1x1x4x256xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c64] + : !pto.tensor_view<1x1x1x4x64xf16> -> !pto.partition_tensor_view<1x1x1x4x64xf16> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c64] + : !pto.tensor_view<1x1x1x4x64xui32> -> !pto.partition_tensor_view<1x1x1x4x64xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c4, %c256] + : !pto.tensor_view<1x1x1x4x256xf16> -> !pto.partition_tensor_view<1x1x1x4x256xf16> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x4x64xf16>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x4x64xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile : !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x4x256xf16>) + return + } + + // Case: f32 2x13 - non-32-aligned, requires tmp buffer for padding + // VALID_C=13, tmp_cols=ceil(13,8)=16 (32-byte aligned) + func.func @TSORT32_f32_2x13(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c13 = arith.constant 13 : index + %c16 = arith.constant 16 : index + %c26 = arith.constant 26 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c2, %c16], + strides = [%c32, %c32, %c32, %c16, %c1] + : !pto.tensor_view<1x1x1x2x16xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c64], + strides = [%c128, %c128, %c128, %c64, %c1] + : !pto.tensor_view<1x1x1x2x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c13] + : !pto.tensor_view<1x1x1x2x16xf32> -> !pto.partition_tensor_view<1x1x1x2x13xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c13] + : !pto.tensor_view<1x1x1x2x16xui32> -> !pto.partition_tensor_view<1x1x1x2x13xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c26] + : !pto.tensor_view<1x1x1x2x64xf32> -> !pto.partition_tensor_view<1x1x1x2x26xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x13xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x2x13xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x26xf32>) + return + } + + // Case: f32 1x4164 - non-32-aligned large shape + // VALID_C=4164, tmp_cols=ceil(4164,8)=4168 (32-byte aligned) + func.func @TSORT32_f32_1x4164(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4164 = arith.constant 4164 : index + %c4168 = arith.constant 4168 : index + %c8192 = arith.constant 8192 : index + %c8328 = arith.constant 8328 : index + %c32768 = arith.constant 32768 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c1, %c8192], + strides = [%c8192, %c8192, %c8192, %c8192, %c1] + : !pto.tensor_view<1x1x1x1x8192xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c1, %c32768], + strides = [%c32768, %c32768, %c32768, %c32768, %c1] + : !pto.tensor_view<1x1x1x1x32768xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c4164] + : !pto.tensor_view<1x1x1x1x8192xf32> -> !pto.partition_tensor_view<1x1x1x1x4164xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c4164] + : !pto.tensor_view<1x1x1x1x8192xui32> -> !pto.partition_tensor_view<1x1x1x1x4164xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c1, %c8328] + : !pto.tensor_view<1x1x1x1x32768xf32> -> !pto.partition_tensor_view<1x1x1x1x8328xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x4164xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x1x4164xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x8328xf32>) + return + } + + // Case: f32 2x2084 - non-32-aligned multi-row shape + // VALID_C=2084, tmp_cols=ceil(2084,8)=2088 (32-byte aligned) + func.func @TSORT32_f32_2x2084(%src_ptr: !pto.ptr, %idx_ptr: !pto.ptr, %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c2084 = arith.constant 2084 : index + %c2088 = arith.constant 2088 : index + %c3072 = arith.constant 3072 : index + %c4168 = arith.constant 4168 : index + %c6144 = arith.constant 6144 : index + %c12288 = arith.constant 12288 : index + %c24576 = arith.constant 24576 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c2, %c3072], + strides = [%c6144, %c6144, %c6144, %c3072, %c1] + : !pto.tensor_view<1x1x1x2x3072xf32> + %idx_view = pto.make_tensor_view %idx_ptr, + shape = [%c1, %c1, %c1, %c2, %c3072], + strides = [%c6144, %c6144, %c6144, %c3072, %c1] + : !pto.tensor_view<1x1x1x2x3072xui32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c2, %c12288], + strides = [%c24576, %c24576, %c24576, %c12288, %c1] + : !pto.tensor_view<1x1x1x2x12288xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c2084] + : !pto.tensor_view<1x1x1x2x3072xf32> -> !pto.partition_tensor_view<1x1x1x2x2084xf32> + %idx_part = pto.partition_view %idx_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c2084] + : !pto.tensor_view<1x1x1x2x3072xui32> -> !pto.partition_tensor_view<1x1x1x2x2084xui32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c2, %c4168] + : !pto.tensor_view<1x1x1x2x12288xf32> -> !pto.partition_tensor_view<1x1x1x2x4168xf32> + + %src_tile = pto.alloc_tile : !pto.tile_buf + %idx_tile = pto.alloc_tile : !pto.tile_buf + %dst_tile = pto.alloc_tile : !pto.tile_buf + %tmp_tile = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x2084xf32>) + outs(%src_tile : !pto.tile_buf) + pto.tload ins(%idx_part : !pto.partition_tensor_view<1x1x1x2x2084xui32>) + outs(%idx_tile : !pto.tile_buf) + + pto.tsort32 ins(%src_tile, %idx_tile, %tmp_tile : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%dst_tile : !pto.tile_buf) + + pto.tstore ins(%dst_tile : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x4168xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/CMakeLists.txt new file mode 100644 index 000000000..83de2cda8 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsqrt) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/cases.py new file mode 100644 index 000000000..24fb10786 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/cases.py @@ -0,0 +1,75 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsqrt ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + "high_precision": False, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + "high_precision": False, + }, + { + "name": "f16_16x64", + "dtype": np.float16, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-3, + "high_precision": False, + }, + { + "name": "f16_32x32", + "dtype": np.float16, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-3, + "high_precision": False, + }, + { + "name": "f32_64x64_hp1", + "dtype": np.float32, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-7, + "high_precision": True, + }, + { + "name": "f16_64x64_hp2", + "dtype": np.float16, + "shape": (64, 64), + "valid_shape": (64, 64), + "eps": 1e-7, + "high_precision": True, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/compare.py new file mode 100644 index 000000000..428604929 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/gen_data.py new file mode 100644 index 000000000..a5744d36d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/gen_data.py @@ -0,0 +1,36 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + high_precision = case.get("high_precision", False) + + if high_precision: + input = np.random.uniform(0.001, 1.0, size=shape).astype(dtype) + else: + input = np.random.uniform(0.1, 100.0, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = np.sqrt(input[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input": input, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} high_precision={high_precision}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/launch.cpp new file mode 100644 index 000000000..cb1d3cb3d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TSQRT_f32_16x64(__gm__ float *a, __gm__ float *b); + +void LaunchTSQRT_f32_16x64(void *a, void *b, void *stream) { + TSQRT_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TSQRT_f32_32x32(__gm__ float *a, __gm__ float *b); + +void LaunchTSQRT_f32_32x32(void *a, void *b, void *stream) { + TSQRT_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 2: f16 16x64 +extern "C" __global__ AICORE void TSQRT_f16_16x64(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTSQRT_f16_16x64(void *a, void *b, void *stream) { + TSQRT_f16_16x64<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 3: f16 32x32 +extern "C" __global__ AICORE void TSQRT_f16_32x32(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTSQRT_f16_32x32(void *a, void *b, void *stream) { + TSQRT_f16_32x32<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} + +// Case 4: f32 64x64 hp1 +extern "C" __global__ AICORE void TSQRT_f32_64x64_hp1(__gm__ float *a, __gm__ float *b); + +void LaunchTSQRT_f32_64x64_hp1(void *a, void *b, void *stream) { + TSQRT_f32_64x64_hp1<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b); +} + +// Case 5: f16 64x64 hp2 +extern "C" __global__ AICORE void TSQRT_f16_64x64_hp2(__gm__ uint16_t *a, __gm__ uint16_t *b); + +void LaunchTSQRT_f16_64x64_hp2(void *a, void *b, void *stream) { + TSQRT_f16_64x64_hp2<<<1, nullptr, stream>>>((__gm__ uint16_t *)a, (__gm__ uint16_t *)b); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/main.cpp new file mode 100644 index 000000000..d7be5f315 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsqrt ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSQRT_f32_16x64(void *a, void *b, void *stream); +void LaunchTSQRT_f32_32x32(void *a, void *b, void *stream); +void LaunchTSQRT_f16_16x64(void *a, void *b, void *stream); +void LaunchTSQRT_f16_32x32(void *a, void *b, void *stream); +void LaunchTSQRT_f32_64x64_hp1(void *a, void *b, void *stream); +void LaunchTSQRT_f16_64x64_hp2(void *a, void *b, void *stream); + +using LaunchFn = void (*)(void *, void *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTSQRT_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTSQRT_f32_32x32, 32, 32, 32, 32, sizeof(float)}, + {"f16_16x64", LaunchTSQRT_f16_16x64, 16, 64, 16, 64, sizeof(uint16_t)}, + {"f16_32x32", LaunchTSQRT_f16_32x32, 32, 32, 32, 32, sizeof(uint16_t)}, + {"f32_64x64_hp1", LaunchTSQRT_f32_64x64_hp1, 64, 64, 64, 64, sizeof(float)}, + {"f16_64x64_hp2", LaunchTSQRT_f16_64x64_hp2, 64, 64, 64, 64, sizeof(uint16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&srcHost), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tsqrt [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/tsqrt.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/tsqrt.pto new file mode 100644 index 000000000..f7e49af9f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsqrt/tsqrt.pto @@ -0,0 +1,264 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsqrt: tload(a) + tsqrt(a)->b + tstore(b). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TSQRT_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TSQRT_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + +pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } + + // Case 2: f16 16x64 (1024 elements) + func.func @TSQRT_f16_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf16> -> !pto.partition_tensor_view<1x1x1x16x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf16>) + return + } + + // Case 3: f16 32x32 (1024 elements) + func.func @TSQRT_f16_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf16> -> !pto.partition_tensor_view<1x1x1x32x32xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf16>) + return + } + + // Case 4: f32 64x64 hp1 (4096 elements) + func.func @TSQRT_f32_64x64_hp1(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf32> -> !pto.partition_tensor_view<1x1x1x64x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf32>) + return + } + + // Case 5: f16 64x64 hp2 (4096 elements) + func.func @TSQRT_f16_64x64_hp2(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c4096 = arith.constant 4096 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c64, %c64], + strides = [%c4096, %c4096, %c4096, %c64, %c1] + : !pto.tensor_view<1x1x1x64x64xf16> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c64, %c64] + : !pto.tensor_view<1x1x1x64x64xf16> -> !pto.partition_tensor_view<1x1x1x64x64xf16> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + outs(%a : !pto.tile_buf) + + pto.tsqrt ins(%a : !pto.tile_buf) + outs(%b : !pto.tile_buf) + {precision_mode = #pto} + + pto.tstore ins(%b : !pto.tile_buf) + outs(%b_part : !pto.partition_tensor_view<1x1x1x64x64xf16>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsub/CMakeLists.txt new file mode 100644 index 000000000..da60b1f64 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsub) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py new file mode 100644 index 000000000..b71da2e9b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsub ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.float32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "f32_16x64", + "dtype": np.float32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 1e-6, + }, + { + "name": "f32_32x32", + "dtype": np.float32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 1e-6, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py new file mode 100644 index 000000000..95cccfd2a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + input2 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] - input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp new file mode 100644 index 000000000..256c0ed07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: f32 16x64 +extern "C" __global__ AICORE void TSUB_f32_16x64(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream) { + TSUB_f32_16x64<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} + +// Case 1: f32 32x32 +extern "C" __global__ AICORE void TSUB_f32_32x32(__gm__ float *a, __gm__ float *b, __gm__ float *c); + +void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream) { + TSUB_f32_32x32<<<1, nullptr, stream>>>((__gm__ float *)a, (__gm__ float *)b, (__gm__ float *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp new file mode 100644 index 000000000..b5e338d4b --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsub ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSUB_f32_16x64(float *a, float *b, float *c, void *stream); +void LaunchTSUB_f32_32x32(float *a, float *b, float *c, void *stream); + +using LaunchFn = void (*)(float *, float *, float *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_16x64", LaunchTSUB_f32_16x64, 16, 64, 16, 64, sizeof(float)}, + {"f32_32x32", LaunchTSUB_f32_32x32, 32, 32, 32, 32, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + float *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + float *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tsub [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto new file mode 100644 index 000000000..3246f3bc1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsub/tsub.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsub: tload(a) + tload(b) + tsub(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: f32 16x64 (1024 elements) + func.func @TSUB_f32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xf32> -> !pto.partition_tensor_view<1x1x1x16x64xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xf32>) + return + } + + // Case 1: f32 32x32 (1024 elements) + func.func @TSUB_f32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xf32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xf32> -> !pto.partition_tensor_view<1x1x1x32x32xf32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + outs(%b : !pto.tile_buf) + + pto.tsub ins(%a, %b : !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xf32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/CMakeLists.txt new file mode 100644 index 000000000..3ccdb0fb1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(tsubs) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/cases.py new file mode 100644 index 000000000..af6b6b425 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/cases.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for tsubs ST test cases.""" + +import numpy as np + +CASES = [ + {"name": "f32_32x64", "dtype": np.float32, "shape": (32, 64), "valid_shape": (32, 64), "eps": 1e-6}, + {"name": "f16_63x64", "dtype": np.float16, "shape": (63, 64), "valid_shape": (63, 64), "eps": 1e-3}, + {"name": "i32_31x128", "dtype": np.int32, "shape": (31, 128), "valid_shape": (31, 128), "eps": 0}, + {"name": "i16_15x192", "dtype": np.int16, "shape": (15, 192), "valid_shape": (15, 192), "eps": 0}, + {"name": "f32_7x448", "dtype": np.float32, "shape": (7, 448), "valid_shape": (7, 448), "eps": 1e-6}, + {"name": "f32_256x16", "dtype": np.float32, "shape": (256, 16), "valid_shape": (256, 16), "eps": 1e-6}, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/gen_data.py new file mode 100644 index 000000000..20d55e1d6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value subtracted from every element (matches the scalar passed in launch.cpp) +SCALAR = 3.0 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] - scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/launch.cpp new file mode 100644 index 000000000..d511cf09d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Scalar value subtracted from every element (must match gen_data.py SCALAR) +static constexpr float TSUBS_SCALAR_F32 = 3.0f; + +// Case 0: f32 32x64 +extern "C" __global__ AICORE void TSUBS_f32_32x64(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTSUBS_f32_32x64(float *src, float *dst, void *stream) { + TSUBS_f32_32x64<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TSUBS_SCALAR_F32); +} + +// Case 1: f16 63x64 +extern "C" __global__ AICORE void TSUBS_f16_63x64(__gm__ unsigned short *src, __gm__ unsigned short *dst, unsigned short scalar); + +void LaunchTSUBS_f16_63x64(unsigned short *src, unsigned short *dst, void *stream) { + TSUBS_f16_63x64<<<1, nullptr, stream>>>((__gm__ unsigned short *)src, (__gm__ unsigned short *)dst, (unsigned short)0x4200); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TSUBS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTSUBS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TSUBS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TSUBS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTSUBS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TSUBS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 4: f32 7x448 +extern "C" __global__ AICORE void TSUBS_f32_7x448(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTSUBS_f32_7x448(float *src, float *dst, void *stream) { + TSUBS_f32_7x448<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TSUBS_SCALAR_F32); +} + +// Case 5: f32 256x16 +extern "C" __global__ AICORE void TSUBS_f32_256x16(__gm__ float *src, __gm__ float *dst, float scalar); + +void LaunchTSUBS_f32_256x16(float *src, float *dst, void *stream) { + TSUBS_f32_256x16<<<1, nullptr, stream>>>((__gm__ float *)src, (__gm__ float *)dst, TSUBS_SCALAR_F32); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/main.cpp new file mode 100644 index 000000000..40509f578 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/main.cpp @@ -0,0 +1,139 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang tsubs ST — case-table driven. +// tsubs: dst = src - scalar (single input + scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTSUBS_f32_32x64(float *src, float *dst, void *stream); +void LaunchTSUBS_f16_63x64(uint16_t *src, uint16_t *dst, void *stream); +void LaunchTSUBS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTSUBS_i16_15x192(int16_t *src, int16_t *dst, void *stream); +void LaunchTSUBS_f32_7x448(float *src, float *dst, void *stream); +void LaunchTSUBS_f32_256x16(float *src, float *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"f32_32x64", (void (*)(void*,void*,void*))LaunchTSUBS_f32_32x64, 32, 64, 32, 64, sizeof(float)}, + {"f16_63x64", (void (*)(void*,void*,void*))LaunchTSUBS_f16_63x64, 63, 64, 63, 64, sizeof(uint16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTSUBS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTSUBS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, + {"f32_7x448", (void (*)(void*,void*,void*))LaunchTSUBS_f32_7x448, 7, 448, 7, 448, sizeof(float)}, + {"f32_256x16", (void (*)(void*,void*,void*))LaunchTSUBS_f32_256x16, 256, 16, 256, 16, sizeof(float)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./tsubs [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsubs/tsubs.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/tsubs.pto new file mode 100644 index 000000000..4e9dcaef1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsubs/tsubs.pto @@ -0,0 +1,256 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.tsubs: tload(src) + tsubs(src, scalar)->dst + tstore(dst). +// Multiple cases with different shapes/dtypes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: f32 32x64 (2048 elements) + func.func @TSUBS_f32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xf32> -> !pto.partition_tensor_view<1x1x1x32x64xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xf32>) + return + } + + // Case 1: f16 63x64 (4032 elements) + func.func @TSUBS_f16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xf16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xf16> -> !pto.partition_tensor_view<1x1x1x63x64xf16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xf16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TSUBS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, i32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TSUBS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + + // Case 4: f32 7x448 (3136 elements) + func.func @TSUBS_f32_7x448(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c7 = arith.constant 7 : index + %c448 = arith.constant 448 : index + %c3136 = arith.constant 3136 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c7, %c448], + strides = [%c3136, %c3136, %c3136, %c448, %c1] + : !pto.tensor_view<1x1x1x7x448xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c7, %c448] + : !pto.tensor_view<1x1x1x7x448xf32> -> !pto.partition_tensor_view<1x1x1x7x448xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x7x448xf32>) + return + } + + // Case 5: f32 256x16 (4096 elements) + func.func @TSUBS_f32_256x16(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c16 = arith.constant 16 : index + %c4096 = arith.constant 4096 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c256, %c16], + strides = [%c4096, %c4096, %c4096, %c16, %c1] + : !pto.tensor_view<1x1x1x256x16xf32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c256, %c16] + : !pto.tensor_view<1x1x1x256x16xf32> -> !pto.partition_tensor_view<1x1x1x256x16xf32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + outs(%src : !pto.tile_buf) + pto.tsubs ins(%src, %scalar : !pto.tile_buf, f32) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x256x16xf32>) + return + } + +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/txor/CMakeLists.txt new file mode 100644 index 000000000..e54d015e1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(txor) \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py new file mode 100644 index 000000000..c710ea612 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/cases.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for txor ST test cases. + +Each case defines: + - name: case identifier, used as subdirectory name and by main.cpp kCases[]. + - dtype: numpy dtype (e.g. np.int32). + - shape: (rows, cols) — allocated tile dimensions. + - valid_shape: (valid_rows, valid_cols) — effective computation region. + - eps: tolerance for numpy.allclose (atol and rtol). + +gen_data.py and compare.py both import this list to avoid redundant definitions. +""" + +import numpy as np + +CASES = [ + { + "name": "i32_16x64", + "dtype": np.int32, + "shape": (16, 64), + "valid_shape": (16, 64), + "eps": 0, + }, + { + "name": "i32_32x32", + "dtype": np.int32, + "shape": (32, 32), + "valid_shape": (32, 32), + "eps": 0, + }, +] \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py new file mode 100644 index 000000000..4eae3bc07 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/compare.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py new file mode 100644 index 000000000..2d2fbe7b6 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/gen_data.py @@ -0,0 +1,32 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(0, 100, size=shape).astype(dtype) + input2 = np.random.randint(0, 100, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + golden[:vr, :vc] = (input1[:vr, :vc] ^ input2[:vr, :vc]).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "input2": input2, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__}") \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp new file mode 100644 index 000000000..90fd20459 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/launch.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 16x64 +extern "C" __global__ AICORE void TXOR_i32_16x64(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTXOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TXOR_i32_16x64<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} + +// Case 1: i32 32x32 +extern "C" __global__ AICORE void TXOR_i32_32x32(__gm__ int32_t *a, __gm__ int32_t *b, __gm__ int32_t *c); + +void LaunchTXOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream) { + TXOR_i32_32x32<<<1, nullptr, stream>>>((__gm__ int32_t *)a, (__gm__ int32_t *)b, (__gm__ int32_t *)c); +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp new file mode 100644 index 000000000..838ff0de1 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang txor ST — case-table driven. +// Each case launches a different kernel variant, reads/writes from per-case subdirectory. +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTXOR_i32_16x64(int32_t *a, int32_t *b, int32_t *c, void *stream); +void LaunchTXOR_i32_32x32(int32_t *a, int32_t *b, int32_t *c, void *stream); + +using LaunchFn = void (*)(int32_t *, int32_t *, int32_t *, void *); + +struct TestCase { + const char *name; + LaunchFn launch; + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_16x64", LaunchTXOR_i32_16x64, 16, 64, 16, 64, sizeof(int32_t)}, + {"i32_32x32", LaunchTXOR_i32_32x32, 32, 32, 32, 32, sizeof(int32_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t src0FileSize = fileSize; + size_t src1FileSize = fileSize; + + int32_t *src0Host = nullptr, *src1Host = nullptr, *dstHost = nullptr; + int32_t *src0Device = nullptr, *src1Device = nullptr, *dstDevice = nullptr; + + aclrtMallocHost((void **)(&src0Host), fileSize); + aclrtMallocHost((void **)(&src1Host), fileSize); + aclrtMallocHost((void **)(&dstHost), fileSize); + + aclrtMalloc((void **)&src0Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&src1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc((void **)&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), src0FileSize, src0Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + if (rc == 0 && !ReadFile((caseDir + "/input2.bin").c_str(), src1FileSize, src1Host, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input2.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(src0Device, fileSize, src0Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + aclrtMemcpy(src1Device, fileSize, src1Host, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(src0Device, src1Device, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (src0Device != nullptr) + aclrtFree(src0Device); + if (src1Device != nullptr) + aclrtFree(src1Device); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (src0Host != nullptr) + aclrtFreeHost(src0Host); + if (src1Host != nullptr) + aclrtFreeHost(src1Host); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./txor [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto b/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto new file mode 100644 index 000000000..41de710fa --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txor/txor.pto @@ -0,0 +1,125 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.txor: tload(a) + tload(b) + txor(a,b)->c + tstore(c). +// Multiple cases with different shapes in a single module. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + // Case 0: i32 16x64 (1024 elements) + func.func @TXOR_i32_16x64(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c16, %c64], + strides = [%c1024, %c1024, %c1024, %c64, %c1] + : !pto.tensor_view<1x1x1x16x64xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c16, %c64] + : !pto.tensor_view<1x1x1x16x64xi32> -> !pto.partition_tensor_view<1x1x1x16x64xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + outs(%b : !pto.tile_buf) + + pto.txor ins(%a, %b, %c : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x16x64xi32>) + return + } + + // Case 1: i32 32x32 (1024 elements) + func.func @TXOR_i32_32x32(%a_ptr: !pto.ptr, %b_ptr: !pto.ptr, %c_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + + %a_view = pto.make_tensor_view %a_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %b_view = pto.make_tensor_view %b_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + %c_view = pto.make_tensor_view %c_ptr, + shape = [%c1, %c1, %c1, %c32, %c32], + strides = [%c1024, %c1024, %c1024, %c32, %c1] + : !pto.tensor_view<1x1x1x32x32xi32> + + %a_part = pto.partition_view %a_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %b_part = pto.partition_view %b_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + %c_part = pto.partition_view %c_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c32] + : !pto.tensor_view<1x1x1x32x32xi32> -> !pto.partition_tensor_view<1x1x1x32x32xi32> + + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %c = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%a_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%a : !pto.tile_buf) + pto.tload ins(%b_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + outs(%b : !pto.tile_buf) + + pto.txor ins(%a, %b, %c : !pto.tile_buf, + !pto.tile_buf, + !pto.tile_buf) + outs(%c : !pto.tile_buf) + + pto.tstore ins(%c : !pto.tile_buf) + outs(%c_part : !pto.partition_tensor_view<1x1x1x32x32xi32>) + return + } +} \ No newline at end of file diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/CMakeLists.txt b/test/tilelang_st/npu/a5/src/st/testcase/txors/CMakeLists.txt new file mode 100644 index 000000000..1bcd9e681 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +pto_tilelang_vec_st(txors) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/cases.py b/test/tilelang_st/npu/a5/src/st/testcase/txors/cases.py new file mode 100644 index 000000000..9b652056d --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/cases.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +"""Single source of truth for txors ST test cases. + +txors: bitwise XOR with scalar, dst = src ^ scalar. +Integer only: i32, i16. +""" + +import numpy as np + +CASES = [ + { + "name": "i32_32x64", + "dtype": np.int32, + "shape": (32, 64), + "valid_shape": (32, 64), + "eps": 0, + }, + { + "name": "i16_63x64", + "dtype": np.int16, + "shape": (63, 64), + "valid_shape": (63, 64), + "eps": 0, + }, + { + "name": "i32_31x128", + "dtype": np.int32, + "shape": (31, 128), + "valid_shape": (31, 128), + "eps": 0, + }, + { + "name": "i16_15x192", + "dtype": np.int16, + "shape": (15, 192), + "valid_shape": (15, 192), + "eps": 0, + }, +] diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/compare.py b/test/tilelang_st/npu/a5/src/st/testcase/txors/compare.py new file mode 100644 index 000000000..50186777e --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/compare.py @@ -0,0 +1,46 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + +from cases import CASES +from st_common import result_cmp, style_fail, style_pass, validate_cases + +def main(): + validate_cases(CASES) + case_filter = sys.argv[1] if len(sys.argv) > 1 else None + + all_passed = True + for case in CASES: + if case_filter is not None and case["name"] != case_filter: + continue + + case_dir = case["name"] + shape = case["shape"] + vr, vc = case["valid_shape"] + + golden = np.fromfile(os.path.join(case_dir, "golden.bin"), dtype=case["dtype"]).reshape(shape) + output = np.fromfile(os.path.join(case_dir, "output.bin"), dtype=case["dtype"]).reshape(shape) + + ok = result_cmp(golden[:vr, :vc], output[:vr, :vc], case["eps"]) + if ok: + print(style_pass(f"[INFO] {case['name']}: compare passed")) + else: + print(style_fail(f"[ERROR] {case['name']}: compare failed")) + all_passed = False + + if not all_passed: + sys.exit(2) + print(style_pass("[INFO] all cases passed")) + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/txors/gen_data.py new file mode 100644 index 000000000..5c12edd5a --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/gen_data.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import numpy as np +from cases import CASES +from st_common import validate_cases, setup_case_rng, save_case_data + +# Scalar value for bitwise XOR (matches the scalar passed in launch.cpp) +SCALAR = 3 + +validate_cases(CASES) + +for case in CASES: + setup_case_rng(case) + + dtype = case["dtype"] + shape = case["shape"] + valid_shape = case["valid_shape"] + + input1 = np.random.randint(1, 10, size=shape).astype(dtype) + + golden = np.zeros(shape, dtype=dtype) + vr, vc = valid_shape + scalar_val = dtype(SCALAR) + golden[:vr, :vc] = (input1[:vr, :vc] ^ scalar_val).astype(dtype, copy=False) + + save_case_data(case["name"], {"input1": input1, "golden": golden}) + print(f"[INFO] gen_data: {case['name']} shape={shape} valid_shape={valid_shape} dtype={dtype.__name__} scalar={SCALAR}") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/launch.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txors/launch.cpp new file mode 100644 index 000000000..f61619d9f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef AICORE +#define AICORE [aicore] +#endif + +// Case 0: i32 32x64 +extern "C" __global__ AICORE void TXORS_i32_32x64(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTXORS_i32_32x64(int32_t *src, int32_t *dst, void *stream) { + TXORS_i32_32x64<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 1: i16 63x64 +extern "C" __global__ AICORE void TXORS_i16_63x64(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTXORS_i16_63x64(int16_t *src, int16_t *dst, void *stream) { + TXORS_i16_63x64<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} + +// Case 2: i32 31x128 +extern "C" __global__ AICORE void TXORS_i32_31x128(__gm__ int32_t *src, __gm__ int32_t *dst, int32_t scalar); + +void LaunchTXORS_i32_31x128(int32_t *src, int32_t *dst, void *stream) { + TXORS_i32_31x128<<<1, nullptr, stream>>>((__gm__ int32_t *)src, (__gm__ int32_t *)dst, (int32_t)3); +} + +// Case 3: i16 15x192 +extern "C" __global__ AICORE void TXORS_i16_15x192(__gm__ int16_t *src, __gm__ int16_t *dst, int16_t scalar); + +void LaunchTXORS_i16_15x192(int16_t *src, int16_t *dst, void *stream) { + TXORS_i16_15x192<<<1, nullptr, stream>>>((__gm__ int16_t *)src, (__gm__ int16_t *)dst, (int16_t)3); +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/main.cpp b/test/tilelang_st/npu/a5/src/st/testcase/txors/main.cpp new file mode 100644 index 000000000..f46282f01 --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/main.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Host driver for TileLang txors ST — case-table driven. +// txors: dst = src ^ scalar (bitwise XOR with scalar). +// Numerical comparison is done externally by compare.py. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include +#include +#include +#include + +using namespace PtoTestCommon; + +// Kernel launch wrappers (defined in launch.cpp) +void LaunchTXORS_i32_32x64(int32_t *src, int32_t *dst, void *stream); +void LaunchTXORS_i16_63x64(int16_t *src, int16_t *dst, void *stream); +void LaunchTXORS_i32_31x128(int32_t *src, int32_t *dst, void *stream); +void LaunchTXORS_i16_15x192(int16_t *src, int16_t *dst, void *stream); + +struct TestCase { + const char *name; + void (*launch)(void *, void *, void *); // src, dst, stream + size_t rows; // allocated tile rows + size_t cols; // allocated tile cols + size_t validRows; // effective computation rows (<= rows) + size_t validCols; // effective computation cols (<= cols) + size_t elemSize; // bytes per element +}; + +static const TestCase kCases[] = { + {"i32_32x64", (void (*)(void*,void*,void*))LaunchTXORS_i32_32x64, 32, 64, 32, 64, sizeof(int32_t)}, + {"i16_63x64", (void (*)(void*,void*,void*))LaunchTXORS_i16_63x64, 63, 64, 63, 64, sizeof(int16_t)}, + {"i32_31x128", (void (*)(void*,void*,void*))LaunchTXORS_i32_31x128, 31, 128, 31, 128, sizeof(int32_t)}, + {"i16_15x192", (void (*)(void*,void*,void*))LaunchTXORS_i16_15x192, 15, 192, 15, 192, sizeof(int16_t)}, +}; +static constexpr size_t kNumCases = sizeof(kCases) / sizeof(kCases[0]); + +static int RunCase(const TestCase &tc, int deviceId, aclrtStream stream) { + int rc = 0; + const size_t elemCount = tc.rows * tc.cols; + const size_t fileSize = elemCount * tc.elemSize; + + std::printf("[INFO] === case: %s (shape=%zux%zu, valid=%zux%zu) ===\n", + tc.name, tc.rows, tc.cols, tc.validRows, tc.validCols); + + // Per-case data directory + std::string caseDir = std::string("./") + tc.name; + size_t srcFileSize = fileSize; + + void *srcHost = nullptr, *dstHost = nullptr; + void *srcDevice = nullptr, *dstDevice = nullptr; + + aclrtMallocHost(&srcHost, fileSize); + aclrtMallocHost(&dstHost, fileSize); + + aclrtMalloc(&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + aclrtMalloc(&dstDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST); + + if (!ReadFile((caseDir + "/input1.bin").c_str(), srcFileSize, srcHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to read %s/input1.bin\n", caseDir.c_str()); + rc = 1; + } + + if (rc == 0) { + aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE); + + tc.launch(srcDevice, dstDevice, stream); + + aclrtSynchronizeStream(stream); + aclrtMemcpy(dstHost, fileSize, dstDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST); + } + + if (rc == 0 && !WriteFile((caseDir + "/output.bin").c_str(), dstHost, fileSize)) { + std::fprintf(stderr, "[ERROR] failed to write %s/output.bin\n", caseDir.c_str()); + rc = 1; + } + + if (srcDevice != nullptr) + aclrtFree(srcDevice); + if (dstDevice != nullptr) + aclrtFree(dstDevice); + if (srcHost != nullptr) + aclrtFreeHost(srcHost); + if (dstHost != nullptr) + aclrtFreeHost(dstHost); + + if (rc == 0) + std::printf("[INFO] case %s done\n", tc.name); + return rc; +} + +int main(int argc, char *argv[]) { + // Optional case filter: ./txors [case_name] + const char *caseFilter = (argc > 1) ? argv[1] : nullptr; + + int rc = 0; + int deviceId = 0; + aclrtStream stream = nullptr; + + aclInit(nullptr); + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + aclrtSetDevice(deviceId); + aclrtCreateStream(&stream); + + for (size_t i = 0; i < kNumCases; ++i) { + if (caseFilter != nullptr && std::strcmp(kCases[i].name, caseFilter) != 0) { + continue; + } + int ret = RunCase(kCases[i], deviceId, stream); + if (ret != 0) { + std::fprintf(stderr, "[ERROR] case %s failed\n", kCases[i].name); + rc = 1; + break; + } + } + + if (stream != nullptr) + aclrtDestroyStream(stream); + aclrtResetDevice(deviceId); + aclFinalize(); + + return rc; +} diff --git a/test/tilelang_st/npu/a5/src/st/testcase/txors/txors.pto b/test/tilelang_st/npu/a5/src/st/testcase/txors/txors.pto new file mode 100644 index 000000000..17ec4041f --- /dev/null +++ b/test/tilelang_st/npu/a5/src/st/testcase/txors/txors.pto @@ -0,0 +1,189 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// TileLang ST kernels for pto.txors: tload(src) + txors(src, scalar, tmp)->dst + tstore(dst). +// Bitwise XOR with scalar: dst = src ^ scalar. +// Integer only: i32, i16. +// Compiled by ptoas --enable-insert-sync --enable-tile-op-expand --pto-backend=vpto +// to produce a fatobj object. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + + // Case 0: i32 32x64 (2048 elements) + func.func @TXORS_i32_32x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c32, %c64], + strides = [%c2048, %c2048, %c2048, %c64, %c1] + : !pto.tensor_view<1x1x1x32x64xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c32, %c64] + : !pto.tensor_view<1x1x1x32x64xi32> -> !pto.partition_tensor_view<1x1x1x32x64xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x64xi32>) + return + } + + // Case 1: i16 63x64 (4032 elements) + func.func @TXORS_i16_63x64(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c63 = arith.constant 63 : index + %c64 = arith.constant 64 : index + %c4032 = arith.constant 4032 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c63, %c64], + strides = [%c4032, %c4032, %c4032, %c64, %c1] + : !pto.tensor_view<1x1x1x63x64xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c63, %c64] + : !pto.tensor_view<1x1x1x63x64xi16> -> !pto.partition_tensor_view<1x1x1x63x64xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i16, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x63x64xi16>) + return + } + + // Case 2: i32 31x128 (3968 elements) + func.func @TXORS_i32_31x128(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c31 = arith.constant 31 : index + %c128 = arith.constant 128 : index + %c3968 = arith.constant 3968 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c31, %c128], + strides = [%c3968, %c3968, %c3968, %c128, %c1] + : !pto.tensor_view<1x1x1x31x128xi32> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c31, %c128] + : !pto.tensor_view<1x1x1x31x128xi32> -> !pto.partition_tensor_view<1x1x1x31x128xi32> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i32, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x31x128xi32>) + return + } + + // Case 3: i16 15x192 (2880 elements) + func.func @TXORS_i16_15x192(%src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i16) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c15 = arith.constant 15 : index + %c192 = arith.constant 192 : index + %c2880 = arith.constant 2880 : index + + %src_view = pto.make_tensor_view %src_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c1, %c1, %c1, %c15, %c192], + strides = [%c2880, %c2880, %c2880, %c192, %c1] + : !pto.tensor_view<1x1x1x15x192xi16> + + %src_part = pto.partition_view %src_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0, %c0, %c0, %c0], + sizes = [%c1, %c1, %c1, %c15, %c192] + : !pto.tensor_view<1x1x1x15x192xi16> -> !pto.partition_tensor_view<1x1x1x15x192xi16> + + %src = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + %tmp = pto.alloc_tile + : !pto.tile_buf + + pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + outs(%src : !pto.tile_buf) + pto.txors ins(%src, %scalar, %tmp : !pto.tile_buf, i16, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<1x1x1x15x192xi16>) + return + } + +} diff --git a/test/tilelang_st/script/run_all_st.py b/test/tilelang_st/script/run_all_st.py new file mode 100755 index 000000000..e1b47c54f --- /dev/null +++ b/test/tilelang_st/script/run_all_st.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Batch runner for TileLang ST, suitable for CI/self-hosted runner usage.""" + +import argparse +import concurrent.futures +import importlib.util +import os +import subprocess +import sys +import traceback + +import run_st + + +SOC_VERSION_MAP = { + "a5": "Ascend950PR_9599", +} + +SMOKE_CASE_LIMIT = 1 + + +def discover_testcases(testcase_root): + testcases = [] + for entry in sorted(os.listdir(testcase_root)): + testcase_dir = os.path.join(testcase_root, entry) + if not os.path.isdir(testcase_dir): + continue + pto_file = os.path.join(testcase_dir, f"{entry}.pto") + if os.path.isfile(pto_file): + testcases.append(entry) + return testcases + + +def load_case_names(testcase_root, testcase): + cases_path = os.path.join(testcase_root, testcase, "cases.py") + if not os.path.isfile(cases_path): + raise FileNotFoundError(f"cases.py not found: {cases_path}") + + spec = importlib.util.spec_from_file_location(f"_tilelang_{testcase}_cases", cases_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return [case["name"] for case in module.CASES] + + +def resolve_case_filters(testcase_root, testcase, smoke_mode): + if not smoke_mode: + return [] + case_names = load_case_names(testcase_root, testcase) + if not case_names: + raise ValueError(f"no cases found for smoke testcase: {testcase}") + return case_names[:SMOKE_CASE_LIMIT] + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run all TileLang ST testcases for CI or local batch validation." + ) + parser.add_argument( + "-r", "--run-mode", default="sim", + help="Run mode: sim or npu (default: sim)", + ) + parser.add_argument( + "-v", "--soc-version", default="a5", + help="SoC version: a5 (default: a5)", + ) + parser.add_argument( + "-p", "--ptoas-bin", default=None, + help="Path to ptoas binary (auto-detected if omitted)", + ) + parser.add_argument( + "-t", "--testcase", action="append", default=[], + help="Run only selected testcase(s). Can be passed multiple times.", + ) + parser.add_argument( + "-w", "--without-build", action="store_true", + help="Skip build and reuse the existing build directory.", + ) + parser.add_argument( + "--fail-fast", action="store_true", + help="Stop immediately after the first failed testcase.", + ) + parser.add_argument( + "--list", action="store_true", + help="List discovered testcases and exit.", + ) + parser.add_argument( + "-j", "--jobs", type=int, default=1, + help="Number of testcases to run in parallel after the shared build (default: 1).", + ) + parser.add_argument( + "--smoke", action="store_true", + help="Run only a representative smoke subset of cases for each testcase.", + ) + return parser.parse_args() + + +def resolve_selected_testcases(all_testcases, requested): + if not requested: + return all_testcases + + requested_set = [] + seen = set() + for testcase in requested: + if testcase not in seen: + requested_set.append(testcase) + seen.add(testcase) + + missing = [testcase for testcase in requested_set if testcase not in all_testcases] + if missing: + raise ValueError( + f"Unsupported testcase(s): {', '.join(missing)}; " + f"supported: {', '.join(all_testcases)}" + ) + return requested_set + + +def run_testcase_subprocess(run_st_script_path, run_mode, soc_version, ptoas_bin, testcase, case_filters=None): + command = [ + sys.executable, + run_st_script_path, + "-r", run_mode, + "-v", soc_version, + "-t", testcase, + "-p", ptoas_bin, + "-w", + ] + for case_filter in case_filters or []: + command.extend(["-c", case_filter]) + env = os.environ.copy() + result = subprocess.run( + command, + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + env=env, + ) + return testcase, result.returncode, result.stdout + + +def main(): + args = parse_args() + + if args.soc_version not in SOC_VERSION_MAP: + print( + f"[ERROR] Unsupported soc-version: {args.soc_version}, " + f"supported: {', '.join(sorted(SOC_VERSION_MAP))}", + file=sys.stderr, + ) + sys.exit(1) + if args.jobs < 1: + print("[ERROR] --jobs must be >= 1", file=sys.stderr) + sys.exit(1) + + batch_script_path = os.path.abspath(__file__) + run_st_script_path = os.path.abspath(run_st.__file__) + tilelang_st_root = os.path.dirname(os.path.dirname(batch_script_path)) + testcase_root = os.path.join( + tilelang_st_root, "npu", args.soc_version, "src", "st", "testcase" + ) + target_dir = os.path.dirname(testcase_root) + + if not os.path.isdir(testcase_root): + print(f"[ERROR] Testcase root not found: {testcase_root}", file=sys.stderr) + sys.exit(1) + + all_testcases = discover_testcases(testcase_root) + if not all_testcases: + print(f"[ERROR] No testcases found in: {testcase_root}", file=sys.stderr) + sys.exit(1) + + if args.list: + for testcase in all_testcases: + print(testcase) + return + + try: + selected_testcases = resolve_selected_testcases(all_testcases, args.testcase) + except ValueError as exc: + print(f"[ERROR] {exc}", file=sys.stderr) + sys.exit(1) + + ptoas_bin = args.ptoas_bin or run_st.find_ptoas_bin() + if not ptoas_bin: + print( + "[ERROR] Cannot find ptoas binary. Set PTOAS_BIN env or use -p flag.", + file=sys.stderr, + ) + sys.exit(1) + ptoas_bin = os.path.abspath(ptoas_bin) + + default_soc_version = SOC_VERSION_MAP[args.soc_version] + print(f"[INFO] run_mode={args.run_mode}") + print(f"[INFO] soc_version={args.soc_version} ({default_soc_version})") + print(f"[INFO] ptoas={ptoas_bin}") + print(f"[INFO] target_dir={target_dir}") + print(f"[INFO] selected_testcases={', '.join(selected_testcases)}") + print(f"[INFO] smoke={args.smoke}") + print(f"[INFO] jobs={args.jobs}") + + original_dir = os.getcwd() + failures = [] + try: + os.chdir(target_dir) + run_st.set_env_variables(args.run_mode, default_soc_version) + + if not args.without_build: + build_target = "all" if selected_testcases == all_testcases else ";".join(selected_testcases) + print(f"[INFO] build requested for {build_target}") + run_st.build_project(args.run_mode, default_soc_version, "all", ptoas_bin) + + total = len(selected_testcases) + if args.jobs == 1: + for index, testcase in enumerate(selected_testcases, start=1): + case_filters = resolve_case_filters(testcase_root, testcase, args.smoke) + print(f"[INFO] [{index}/{total}] running testcase: {testcase}") + if case_filters: + print(f"[INFO] smoke cases: {', '.join(case_filters)}") + try: + run_st.run_gen_data(testcase, case_filters) + run_st.run_binary(testcase, case_filters) + run_st.run_compare(testcase, case_filters) + except Exception as exc: # pragma: no cover - CI-side aggregation path + failures.append((testcase, str(exc))) + print(f"[ERROR] testcase failed: {testcase}") + traceback.print_exc() + if args.fail_fast: + break + else: + print(f"[INFO] running testcases in parallel with jobs={args.jobs}") + max_workers = min(args.jobs, total) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_testcase = {} + for index, testcase in enumerate(selected_testcases, start=1): + case_filters = resolve_case_filters(testcase_root, testcase, args.smoke) + print(f"[INFO] [{index}/{total}] queue testcase: {testcase}") + if case_filters: + print(f"[INFO] smoke cases: {', '.join(case_filters)}") + future = executor.submit( + run_testcase_subprocess, + run_st_script_path, + args.run_mode, + args.soc_version, + ptoas_bin, + testcase, + case_filters, + ) + future_to_testcase[future] = testcase + + for future in concurrent.futures.as_completed(future_to_testcase): + testcase = future_to_testcase[future] + try: + _, returncode, output = future.result() + except Exception as exc: # pragma: no cover - executor/host failure + failures.append((testcase, str(exc))) + print(f"[ERROR] testcase runner crashed: {testcase}") + traceback.print_exc() + if args.fail_fast: + break + continue + + print(f"[INFO] ===== testcase {testcase} output begin =====") + if output: + print(output, end="" if output.endswith("\n") else "\n") + print(f"[INFO] ===== testcase {testcase} output end =====") + + if returncode != 0: + failures.append((testcase, f"subprocess exited with {returncode}")) + print(f"[ERROR] testcase failed: {testcase}") + if args.fail_fast: + break + + except Exception as exc: + print(f"[ERROR] batch run failed: {exc}", file=sys.stderr) + sys.exit(1) + finally: + os.chdir(original_dir) + + passed = len(selected_testcases) - len(failures) + print("[INFO] TileLang ST summary") + print(f"[INFO] passed={passed} failed={len(failures)} total={len(selected_testcases)}") + if failures: + for testcase, reason in failures: + print(f"[INFO] failed testcase: {testcase} ({reason})") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/test/tilelang_st/script/run_ci.sh b/test/tilelang_st/script/run_ci.sh new file mode 100755 index 000000000..385a7bda9 --- /dev/null +++ b/test/tilelang_st/script/run_ci.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." && pwd)" + +if [[ -f "${REPO_ROOT}/scripts/ptoas_env.sh" ]]; then + # shellcheck source=/dev/null + source "${REPO_ROOT}/scripts/ptoas_env.sh" +fi + +export PYTHONUNBUFFERED=1 + +exec python3 "${SCRIPT_DIR}/run_all_st.py" "$@" diff --git a/test/tilelang_st/script/run_st.py b/test/tilelang_st/script/run_st.py new file mode 100755 index 000000000..44bf191ac --- /dev/null +++ b/test/tilelang_st/script/run_st.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +""" +TileLang ST runner — validates TileLang DSL template library on NPU / simulator. + +Usage: + python3 test/tilelang_st/script/run_st.py -r npu -v a5 -t tadd + python3 test/tilelang_st/script/run_st.py -r sim -v a5 -t tadd +""" + +import os +import sys +import subprocess +import shutil +import argparse + + +def run_command(command, cwd=None, check=True): + try: + print(f"run command: {' '.join(command)}") + subprocess.run(command, cwd=cwd, check=check, stdout=None, stderr=None, text=True) + except subprocess.CalledProcessError as e: + print(f"run command failed with return code {e.returncode}") + raise + + +def _normalize_case_filters(case_filters): + if not case_filters: + return [] + if isinstance(case_filters, str): + return [case_filters] + normalized = [] + seen = set() + for case_filter in case_filters: + if not case_filter or case_filter in seen: + continue + normalized.append(case_filter) + seen.add(case_filter) + return normalized + + +def find_ptoas_bin(): + """Locate the ptoas binary by walking up from this script to the repo root.""" + env_bin = os.environ.get("PTOAS_BIN") + if env_bin and os.path.isfile(env_bin): + return os.path.abspath(env_bin) + + search_dir = os.path.dirname(os.path.abspath(__file__)) + for _ in range(8): + candidate = os.path.join(search_dir, "build", "tools", "ptoas", "ptoas") + if os.path.isfile(candidate): + return os.path.abspath(candidate) + parent = os.path.dirname(search_dir) + if parent == search_dir: + break + search_dir = parent + return None + + +def set_env_variables(run_mode, soc_version): + if run_mode == "sim": + ld_lib_path = os.environ.get("LD_LIBRARY_PATH", "") + if ld_lib_path: + filtered_paths = [ + path for path in ld_lib_path.split(":") + if "/runtime/lib64" not in path + ] + os.environ["LD_LIBRARY_PATH"] = ":".join(filtered_paths) + + ascend_home = os.environ.get("ASCEND_HOME_PATH") + if not ascend_home: + raise EnvironmentError("ASCEND_HOME_PATH is not set") + + os.environ["LD_LIBRARY_PATH"] = ( + f"{ascend_home}/runtime/lib64/stub:{os.environ.get('LD_LIBRARY_PATH', '')}" + ) + setenv_path = os.path.join(ascend_home, "bin", "setenv.bash") + if os.path.exists(setenv_path): + print(f"run env shell: {setenv_path}") + result = subprocess.run( + f"source {setenv_path} && env", + shell=True, + executable=shutil.which("bash") or "bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + for line in result.stdout.splitlines(): + if "=" in line: + key, value = line.split("=", 1) + os.environ[key] = value + else: + print(f"warning: not found {setenv_path}") + + simulator_lib_path = os.path.join( + ascend_home, "tools", "simulator", soc_version, "lib" + ) + os.environ["LD_LIBRARY_PATH"] = ( + f"{simulator_lib_path}:{os.environ.get('LD_LIBRARY_PATH', '')}" + ) + + +def get_testcase_work_dir(testcase): + return os.path.join("build", "testcase", testcase) + + +def build_project(run_mode, soc_version, testcase, ptoas_bin): + build_dir = "build" + if os.path.exists(build_dir): + print(f"clean build: {build_dir}") + shutil.rmtree(build_dir) + os.makedirs(build_dir, exist_ok=True) + + try: + cmake_cmd = [ + "cmake", + f"-DRUN_MODE={run_mode}", + f"-DSOC_VERSION={soc_version}", + f"-DTEST_CASE={testcase}", + f"-DPTOAS_BIN={ptoas_bin}", + "..", + ] + subprocess.run( + cmake_cmd, + cwd=build_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + cpu_count = os.cpu_count() or 4 + make_cmd = ["make", "VERBOSE=1", "-j", str(cpu_count)] + result = subprocess.run( + make_cmd, + cwd=build_dir, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + print("compile process:\n", result.stdout) + except subprocess.CalledProcessError as e: + print(f"build failed: {e.stdout}") + raise + + +def _write_filtered_cases_module(src_path, dst_path, selected_case_names): + selected_case_names = _normalize_case_filters(selected_case_names) + wrapper = f"""#!/usr/bin/python3 +# Auto-generated by run_st.py; do not edit. + +import importlib.util as _importlib_util + +_SOURCE_PATH = {src_path!r} +_SELECTED_CASE_NAMES = {selected_case_names!r} + +_SPEC = _importlib_util.spec_from_file_location("_tilelang_st_source_cases", _SOURCE_PATH) +_MODULE = _importlib_util.module_from_spec(_SPEC) +_SPEC.loader.exec_module(_MODULE) + +for _name, _value in vars(_MODULE).items(): + if not _name.startswith("_"): + globals()[_name] = _value + +if _SELECTED_CASE_NAMES: + _selected = set(_SELECTED_CASE_NAMES) + _all_case_names = [case.get("name") for case in CASES] + _missing = sorted(name for name in _selected if name not in _all_case_names) + if _missing: + raise RuntimeError("unknown case filter(s): " + ", ".join(_missing)) + CASES = [case for case in CASES if case.get("name") in _selected] +""" + with open(dst_path, "w", encoding="utf-8") as handle: + handle.write(wrapper) + + +def _copy_testcase_scripts(testcase, case_filters=None): + """Copy shared and per-testcase Python scripts into the build work directory.""" + work_dir = get_testcase_work_dir(testcase) + os.makedirs(work_dir, exist_ok=True) + # Shared scripts (testcase/ level). + for name in ("st_common.py",): + src = os.path.join("testcase", name) + if os.path.isfile(src): + run_command(["cp", src, os.path.join(work_dir, name)]) + # Per-testcase scripts. + testcase_src = f"testcase/{testcase}" + for name in ("cases.py", "gen_data.py", "compare.py"): + src = os.path.join(testcase_src, name) + if os.path.isfile(src): + dst = os.path.join(work_dir, name) + if name == "cases.py": + selected_case_names = _normalize_case_filters(case_filters) + if selected_case_names: + _write_filtered_cases_module(os.path.abspath(src), dst, selected_case_names) + continue + run_command(["cp", src, dst]) + + +def run_gen_data(testcase, case_filters=None): + original_dir = os.getcwd() + try: + work_dir = get_testcase_work_dir(testcase) + _copy_testcase_scripts(testcase, case_filters) + os.chdir(work_dir) + run_command([sys.executable, "gen_data.py"]) + except Exception as e: + print(f"gen golden failed: {e}") + raise + finally: + os.chdir(original_dir) + + +def run_binary(testcase, case_filters=None): + original_dir = os.getcwd() + try: + os.chdir(get_testcase_work_dir(testcase)) + binary = os.path.join("..", "..", "bin", testcase) + normalized_case_filters = _normalize_case_filters(case_filters) + if not normalized_case_filters: + run_command([binary]) + else: + for case_filter in normalized_case_filters: + run_command([binary, case_filter]) + except Exception as e: + print(f"run binary failed: {e}") + raise + finally: + os.chdir(original_dir) + + +def run_compare(testcase, case_filters=None): + original_dir = os.getcwd() + try: + work_dir = get_testcase_work_dir(testcase) + _copy_testcase_scripts(testcase, case_filters) + os.chdir(work_dir) + cmd = [sys.executable, "compare.py"] + run_command(cmd) + except Exception as e: + print(f"compare failed: {e}") + raise + finally: + os.chdir(original_dir) + + +def main(): + parser = argparse.ArgumentParser(description="TileLang ST runner") + parser.add_argument("-r", "--run-mode", required=True, + help="Run mode: sim or npu") + parser.add_argument("-v", "--soc-version", required=True, + help="SoC version: a5") + parser.add_argument("-t", "--testcase", required=True, + help="Test case name (e.g. tadd)") + parser.add_argument("-p", "--ptoas-bin", required=False, + help="Path to ptoas binary (auto-detected if omitted)") + parser.add_argument("-c", "--case", action="append", default=[], + help="Run one or more specific cases within the testcase. Can be passed multiple times.") + parser.add_argument("-w", "--without-build", action="store_true", + help="Skip build (requires prior build)") + + args = parser.parse_args() + + if args.soc_version == "a5": + default_soc_version = "Ascend950PR_9599" + else: + print(f"[ERROR] Unsupported soc-version: {args.soc_version}, only a5 is supported", + file=sys.stderr) + sys.exit(1) + + testcase = args.testcase + + ptoas_bin = args.ptoas_bin or find_ptoas_bin() + if not ptoas_bin: + print("[ERROR] Cannot find ptoas binary. " + "Set PTOAS_BIN env or use -p flag.", file=sys.stderr) + sys.exit(1) + ptoas_bin = os.path.abspath(ptoas_bin) + print(f"[INFO] ptoas: {ptoas_bin}") + + original_dir = os.getcwd() + try: + script_path = os.path.abspath(__file__) + tilelang_st_root = os.path.dirname(os.path.dirname(script_path)) + target_dir = os.path.join(tilelang_st_root, "npu", args.soc_version, "src", "st") + + if not os.path.isdir(target_dir): + print(f"[ERROR] Target dir not found: {target_dir}", file=sys.stderr) + sys.exit(1) + + print(f"target_dir: {target_dir}") + os.chdir(target_dir) + + set_env_variables(args.run_mode, default_soc_version) + + if not args.without_build: + build_project(args.run_mode, default_soc_version, testcase, ptoas_bin) + + # gen golden → run binary → compare + run_gen_data(testcase, args.case) + run_binary(testcase, args.case) + run_compare(testcase, args.case) + + except Exception as e: + print(f"run failed: {str(e)}", file=sys.stderr) + sys.exit(1) + finally: + os.chdir(original_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/kernels/online-softmax-update/compare.py b/test/vpto/cases/kernels/online-softmax-update/compare.py new file mode 100644 index 000000000..0130f47d0 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/compare.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: kernels/online-softmax-update +# family: kernels +# target_ops: pto.mte_gm_ub, pto.mte_ub_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# scenarios: online-softmax-update, 16x128-f32, oldmax-oldsum-qk-to-newmax-newsum-expmax-out + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(abs_diff)) + print( + f"[ERROR] Mismatch: max diff={float(abs_diff[idx])} at idx={idx} " + f"(golden={float(golden[idx])}, out={float(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def compare_matrix_valid(golden_path, output_path, rows, cols, valid_cols, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + expected_elems = rows * cols + if golden.size != expected_elems or output.size != expected_elems: + print( + f"[ERROR] Shape mismatch: expected elems={expected_elems}, " + f"golden={golden.size}, out={output.size}" + ) + return False + golden = golden.reshape(rows, cols) + output = output.reshape(rows, cols) + if not np.allclose( + golden[:, :valid_cols], + output[:, :valid_cols], + atol=eps, + rtol=eps, + equal_nan=True, + ): + abs_diff = np.abs( + golden[:, :valid_cols].astype(np.float64) + - output[:, :valid_cols].astype(np.float64) + ) + flat_idx = int(np.argmax(abs_diff)) + row, col = divmod(flat_idx, valid_cols) + print( + f"[ERROR] Mismatch in valid region: max diff={float(abs_diff[row, col])} " + f"at row={row}, col={col} " + f"(golden={float(golden[row, col])}, out={float(output[row, col])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + rows = int(np.fromfile("v9.bin", dtype=np.int32)[0]) + seq = int(np.fromfile("v8.bin", dtype=np.int32)[0]) + ok = True + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v5.bin", "v5.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_v6.bin", "v6.bin", np.float32, 1e-4) and ok + ok = compare_matrix_valid( + "golden_v7.bin", "v7.bin", rows, 128, seq, np.float32, 1e-4 + ) and ok + if not ok: + print("[ERROR] compare failed") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/kernels/online-softmax-update/golden.py b/test/vpto/cases/kernels/online-softmax-update/golden.py new file mode 100644 index 000000000..47cdddad6 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/golden.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: kernels/online-softmax-update +# family: kernels +# target_ops: pto.get_block_idx, pto.mte_gm_ub, pto.mte_ub_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +# scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 24 +COLS = 128 +SEED = 19 +SEQ = 73 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + seq = SEQ + oldmax = rng.uniform(-3.0, 1.5, size=(ROWS,)).astype(np.float32) + oldsum = rng.uniform(0.5, 4.0, size=(ROWS,)).astype(np.float32) + qk = rng.normal(loc=0.0, scale=1.5, size=(ROWS, COLS)).astype(np.float32) + + qk_active = qk[:, :seq] + qk_rowmax = np.max(qk_active, axis=1) + newmax = np.maximum(qk_rowmax, oldmax) + tmp_active = np.exp(qk_active - newmax[:, None], dtype=np.float32) + cursum = np.sum(tmp_active, axis=1, dtype=np.float32) + raw_expmax = np.exp(oldmax - newmax, dtype=np.float32) + newsum = raw_expmax * oldsum + cursum + expmax = (raw_expmax * oldsum) / newsum + out = np.zeros((ROWS, COLS), dtype=np.float32) + out[:, :seq] = tmp_active / newsum[:, None] + + zeros_state = np.zeros((ROWS,), dtype=np.float32) + zeros_out = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + oldmax.tofile(output_dir / "v1.bin") + oldsum.tofile(output_dir / "v2.bin") + qk.reshape(-1).tofile(output_dir / "v3.bin") + zeros_state.tofile(output_dir / "v4.bin") + zeros_state.tofile(output_dir / "v5.bin") + zeros_state.tofile(output_dir / "v6.bin") + zeros_out.reshape(-1).tofile(output_dir / "v7.bin") + np.array([seq], dtype=np.int32).tofile(output_dir / "v8.bin") + np.array([ROWS], dtype=np.int32).tofile(output_dir / "v9.bin") + newmax.tofile(output_dir / "golden_v4.bin") + newsum.tofile(output_dir / "golden_v5.bin") + expmax.tofile(output_dir / "golden_v6.bin") + out.astype(np.float32, copy=False).reshape(-1).tofile(output_dir / "golden_v7.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/kernels/online-softmax-update/kernel.pto b/test/vpto/cases/kernels/online-softmax-update/kernel.pto new file mode 100644 index 000000000..96c889643 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/kernel.pto @@ -0,0 +1,164 @@ +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.mte_gm_ub, pto.mte_ub_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @online_softmax_update_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr, + %arg4: !pto.ptr, + %arg5: !pto.ptr, + %arg6: !pto.ptr, + %arg7: i32, + %arg8: i32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c8448_i64 = arith.constant 8448 : i64 + %c16640_i64 = arith.constant 16640 : i64 + %c16768_i64 = arith.constant 16768 : i64 + %c16896_i64 = arith.constant 16896 : i64 + + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %block = pto.get_block_idx + %block_idx = arith.index_cast %block : i64 to index + %row_base = arith.muli %block_idx, %c8 : index + %qk_base = arith.muli %row_base, %c128 : index + %block_rows_i32 = arith.index_cast %c8 : index to i32 + %row_base_i32 = arith.index_cast %row_base : index to i32 + %remaining_rows = arith.subi %arg8, %row_base_i32 : i32 + %has_rows = arith.cmpi sgt, %remaining_rows, %c0_i32 : i32 + %too_many_rows = arith.cmpi sgt, %remaining_rows, %c8_i32 : i32 + %row_count_i32 = arith.select %too_many_rows, %c8_i32, %remaining_rows : i32 + %row_count = arith.index_cast %row_count_i32 : i32 to index + %row_count_i64 = arith.extui %row_count_i32 : i32 to i64 + %gm_oldmax = pto.addptr %arg0, %row_base : !pto.ptr -> !pto.ptr + %gm_oldsum = pto.addptr %arg1, %row_base : !pto.ptr -> !pto.ptr + %gm_qk = pto.addptr %arg2, %qk_base : !pto.ptr -> !pto.ptr + %gm_qk_hi = pto.addptr %gm_qk, %c64 : !pto.ptr -> !pto.ptr + %gm_newmax = pto.addptr %arg3, %row_base : !pto.ptr -> !pto.ptr + %gm_newsum = pto.addptr %arg4, %row_base : !pto.ptr -> !pto.ptr + %gm_expmax = pto.addptr %arg5, %row_base : !pto.ptr -> !pto.ptr + %gm_out = pto.addptr %arg6, %qk_base : !pto.ptr -> !pto.ptr + + %ub_oldmax = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_oldsum = pto.castptr %c128_i64 : i64 -> !pto.ptr + %ub_qk = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_qk_hi = pto.addptr %ub_qk, %c64 : !pto.ptr -> !pto.ptr + %ub_out = pto.castptr %c8448_i64 : i64 -> !pto.ptr + %ub_newmax = pto.castptr %c16640_i64 : i64 -> !pto.ptr + %ub_newsum = pto.castptr %c16768_i64 : i64 -> !pto.ptr + %ub_expmax = pto.castptr %c16896_i64 : i64 -> !pto.ptr + + scf.if %has_rows { + pto.mte_gm_ub %gm_oldmax, %ub_oldmax, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %gm_oldsum, %ub_oldsum, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %gm_qk, %ub_qk, %c0_i64, %c256_i64 + nburst(%row_count_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %gm_qk_hi, %ub_qk_hi, %c0_i64, %c256_i64 + nburst(%row_count_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + %one_mask, %one_remaining = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + scf.for %row = %c0 to %row_count step %c1 { + %row_qk = arith.muli %row, %c128 : index + %oldmax_bc = pto.vlds %ub_oldmax[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + %oldsum_bc = pto.vlds %ub_oldsum[%row] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + + %final_max, %final_sum = scf.for %chunk = %c0 to %c128 step %c64 + iter_args(%running_max = %oldmax_bc, %running_sum = %oldsum_bc) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + %next_max, %next_sum = scf.if %has_chunk -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %chunk_max = pto.vcmax %vec, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_max_bc = pto.vdup %chunk_max, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_max = pto.vmax %running_max, %chunk_max_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_running = pto.vexpdif %running_max, %merged_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %running_sum_scaled = pto.vmul %scaled_running, %running_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_exp = pto.vexpdif %vec, %merged_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum = pto.vcadd %chunk_exp, %chunk_mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %chunk_sum_bc = pto.vdup %chunk_sum, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %merged_sum = pto.vadd %running_sum_scaled, %chunk_sum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %merged_max, %merged_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + scf.yield %running_max, %running_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + scf.yield %next_max, %next_sum : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + %raw_expmax = pto.vexpdif %oldmax_bc, %final_max, %active, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %scaled_oldsum = pto.vmul %raw_expmax, %oldsum_bc, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %expmax = pto.vdiv %scaled_oldsum, %final_sum, %active : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %final_max, %ub_newmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %final_sum, %ub_newsum[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %expmax, %ub_expmax[%row], %one_mask {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + + scf.for %chunk = %c0 to %c128 step %c64 { + %chunk_i32 = arith.index_cast %chunk : index to i32 + %remaining_cols = arith.subi %arg7, %chunk_i32 : i32 + %has_chunk = arith.cmpi sgt, %remaining_cols, %c0_i32 : i32 + scf.if %has_chunk { + %chunk_mask, %chunk_rest = pto.plt_b32 %remaining_cols : i32 -> !pto.mask, i32 + %chunk_base = arith.addi %row_qk, %chunk : index + %vec = pto.vlds %ub_qk[%chunk_base] : !pto.ptr -> !pto.vreg<64xf32> + %exp = pto.vexpdif %vec, %final_max, %chunk_mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out = pto.vdiv %exp, %final_sum, %chunk_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%chunk_base], %chunk_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.mte_ub_gm %ub_newmax, %gm_newmax, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_newsum, %gm_newsum, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_expmax, %gm_expmax, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %gm_out, %c512_i64 + nburst(%row_count_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/kernels/online-softmax-update/launch.cpp b/test/vpto/cases/kernels/online-softmax-update/launch.cpp new file mode 100644 index 000000000..836937ba3 --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.mte_gm_ub, pto.mte_ub_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void online_softmax_update_kernel_2d( + __gm__ float *v1, __gm__ float *v2, __gm__ float *v3, + __gm__ float *v4, __gm__ float *v5, __gm__ float *v6, + __gm__ float *v7, int32_t v8, int32_t v9); + +void LaunchOnline_softmax_update_kernel_2d(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream) { + const int32_t blockRows = 8; + const int32_t blocks = (v9 + blockRows - 1) / blockRows; + online_softmax_update_kernel_2d<<>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ float *)v3, + (__gm__ float *)v4, (__gm__ float *)v5, (__gm__ float *)v6, + (__gm__ float *)v7, v8, v9); +} diff --git a/test/vpto/cases/kernels/online-softmax-update/main.cpp b/test/vpto/cases/kernels/online-softmax-update/main.cpp new file mode 100644 index 000000000..00cf74b2c --- /dev/null +++ b/test/vpto/cases/kernels/online-softmax-update/main.cpp @@ -0,0 +1,150 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: kernels/online-softmax-update +// family: kernels +// target_ops: pto.get_block_idx, pto.mte_gm_ub, pto.mte_ub_gm, pto.vlds, pto.vcmax, pto.vdup, pto.vmax, pto.vexpdif, pto.vcadd, pto.vadd, pto.vmul, pto.vdiv, pto.vsts +// scenarios: online-softmax-update, dynamic-rows-and-seq, max-seq-128, block-rows-8, oldmax-oldsum-qk-to-newmax-newsum-expmax-out +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchOnline_softmax_update_kernel_2d(float *v1, float *v2, float *v3, + float *v4, float *v5, float *v6, + float *v7, int32_t v8, int32_t v9, + void *stream); + +int main() { + constexpr size_t elemCountSeq = 1; + constexpr size_t elemCountRows = 1; + size_t fileSizeSeq = elemCountSeq * sizeof(int32_t); + size_t fileSizeRows = elemCountRows * sizeof(int32_t); + size_t elemCountState = 0; + size_t elemCountOut = 0; + size_t fileSizeState = 0; + size_t fileSizeOut = 0; + float *v1Host = nullptr, *v2Host = nullptr, *v3Host = nullptr; + float *v4Host = nullptr, *v5Host = nullptr, *v6Host = nullptr; + float *v7Host = nullptr; + float *v1Device = nullptr, *v2Device = nullptr, *v3Device = nullptr; + float *v4Device = nullptr, *v5Device = nullptr, *v6Device = nullptr; + float *v7Device = nullptr; + int32_t v8Host = 0, v9Host = 0; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ReadFile("./v8.bin", fileSizeSeq, &v8Host, fileSizeSeq); + ReadFile("./v9.bin", fileSizeRows, &v9Host, fileSizeRows); + + elemCountState = static_cast(v9Host); + elemCountOut = static_cast(v9Host) * 128; + fileSizeState = elemCountState * sizeof(float); + fileSizeOut = elemCountOut * sizeof(float); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v5Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v6Host), fileSizeState)); + ACL_CHECK(aclrtMallocHost((void **)(&v7Host), fileSizeOut)); + + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v5Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v6Device, fileSizeState, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v7Device, fileSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSizeState, v1Host, fileSizeState); + ReadFile("./v2.bin", fileSizeState, v2Host, fileSizeState); + ReadFile("./v3.bin", fileSizeOut, v3Host, fileSizeOut); + ReadFile("./v4.bin", fileSizeState, v4Host, fileSizeState); + ReadFile("./v5.bin", fileSizeState, v5Host, fileSizeState); + ReadFile("./v6.bin", fileSizeState, v6Host, fileSizeState); + ReadFile("./v7.bin", fileSizeOut, v7Host, fileSizeOut); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSizeState, v1Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSizeState, v2Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSizeOut, v3Host, fileSizeOut, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSizeState, v4Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v5Device, fileSizeState, v5Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v6Device, fileSizeState, v6Host, fileSizeState, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v7Device, fileSizeOut, v7Host, fileSizeOut, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchOnline_softmax_update_kernel_2d(v1Device, v2Device, v3Device, + v4Device, v5Device, v6Device, + v7Device, v8Host, v9Host, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSizeState, v4Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v5Host, fileSizeState, v5Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v6Host, fileSizeState, v6Device, fileSizeState, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v7Host, fileSizeOut, v7Device, fileSizeOut, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", v4Host, fileSizeState); + WriteFile("./v5.bin", v5Host, fileSizeState); + WriteFile("./v6.bin", v6Host, fileSizeState); + WriteFile("./v7.bin", v7Host, fileSizeOut); + +cleanup: + aclrtFree(v1Device); aclrtFree(v2Device); aclrtFree(v3Device); + aclrtFree(v4Device); aclrtFree(v5Device); aclrtFree(v6Device); aclrtFree(v7Device); + aclrtFreeHost(v1Host); aclrtFreeHost(v2Host); aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); aclrtFreeHost(v5Host); aclrtFreeHost(v6Host); aclrtFreeHost(v7Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/compare.py new file mode 100755 index 000000000..68ceff820 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-bf16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-bf16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/golden.py new file mode 100755 index 000000000..c1d417ba0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-bf16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-bf16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + wide = values.astype(np.float32, copy=False).view(np.uint32) + rounding = np.uint32(0x7FFF) + ((wide >> 16) & np.uint32(1)) + return ((wide + rounding) >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray: + return (bits.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v2_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v1 = f32_to_bf16_bits(v1_f32) + v2 = f32_to_bf16_bits(v2_f32) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = f32_to_bf16_bits(bf16_bits_to_f32(v1) + bf16_bits_to_f32(v2)) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto new file mode 100644 index 000000000..653d327fd --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-bf16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_bf16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/launch.cpp new file mode 100644 index 000000000..13e50fe0b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-bf16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2, + __gm__ bfloat16_t *v3); + +void LaunchVadd_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_bf16_kernel<<<1, nullptr, stream>>>((__gm__ bfloat16_t *)v1, + (__gm__ bfloat16_t *)v2, + (__gm__ bfloat16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-bf16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/main.cpp new file mode 100644 index 000000000..e130fecc0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-bf16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-bf16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_bf16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-f16/compare.py new file mode 100755 index 000000000..1254044fb --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-f16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-f16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float16, 5e-3, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-f16/golden.py new file mode 100755 index 000000000..442cc35e7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-f16 +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-f16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3 = (v1.astype(np.float32) + v2.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto new file mode 100644 index 000000000..705cb81ef --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-f16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f16/launch.cpp new file mode 100644 index 000000000..8beb1a003 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-f16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVadd_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f16/main.cpp new file mode 100644 index 000000000..621cf398a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f16/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-f16 +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/golden.py new file mode 100644 index 000000000..802880fdc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials_a = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + specials_b = np.array( + [np.inf, 2.5, 0.0, -0.0, -1.0, -np.inf, 1.0, np.nan], + dtype=np.float32, + ) + v1 = np.resize(specials_a, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.resize(specials_b, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = (v1 + v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto new file mode 100644 index 000000000..a3fbb8ef7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/launch.cpp new file mode 100644 index 000000000..fbb0031f2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vadd_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/main.cpp new file mode 100644 index 000000000..781e0d000 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-f32-exceptional/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_f32_exceptional_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/compare.py new file mode 100644 index 000000000..fe6bc69c3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/golden.py new file mode 100644 index 000000000..960e6d163 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def wrap_add_i16(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray: + bits = lhs.view(np.uint16).astype(np.uint32) + rhs.view(np.uint16).astype(np.uint32) + return (bits & 0xFFFF).astype(np.uint16).view(np.int16) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs_pattern = np.array( + [32767, 32760, -32768, -32760, 1000, -1000, 12345, -12345], + dtype=np.int16, + ) + rhs_pattern = np.array( + [1, 100, -1, -100, 30000, -30000, 23456, -23456], + dtype=np.int16, + ) + repeats = ELEMS // lhs_pattern.size + v1 = np.tile(lhs_pattern, repeats) + v2 = np.tile(rhs_pattern, repeats) + v3 = np.zeros(ELEMS, dtype=np.int16) + golden_v3 = wrap_add_i16(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto new file mode 100644 index 000000000..dafac0a22 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed-overflow +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_i16_signed_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/launch.cpp new file mode 100644 index 000000000..7e3c8bb76 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/launch.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_signed_overflow_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVadd_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vadd_i16_signed_overflow_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/main.cpp new file mode 100644 index 000000000..26f3895f8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed-overflow/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_signed_overflow_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/compare.py new file mode 100755 index 000000000..2f49f90a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.int16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/golden.py new file mode 100755 index 000000000..38079c47f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-signed +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-signed, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-1000, 1001, size=(ROWS, COLS), dtype=np.int16) + v2 = rng.integers(-1000, 1001, size=(ROWS, COLS), dtype=np.int16) + v3 = np.zeros((ROWS, COLS), dtype=np.int16) + golden_v3 = (v1.astype(np.int32) + v2.astype(np.int32)).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto new file mode 100644 index 000000000..83ecf0cfa --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/launch.cpp new file mode 100644 index 000000000..5f2e4f059 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVadd_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vadd_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/main.cpp new file mode 100644 index 000000000..4d22c989d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-signed/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-signed +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_signed_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/compare.py new file mode 100644 index 000000000..4e992a275 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/golden.py new file mode 100644 index 000000000..4673e2501 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned-overflow +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def wrap_add_u16(lhs: np.ndarray, rhs: np.ndarray) -> np.ndarray: + wide = lhs.astype(np.uint32) + rhs.astype(np.uint32) + return (wide & 0xFFFF).astype(np.uint16) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs_pattern = np.array( + [65535, 65530, 65500, 60000, 100, 0, 32768, 12345], + dtype=np.uint16, + ) + rhs_pattern = np.array( + [1, 10, 1000, 10000, 65535, 5, 40000, 60000], + dtype=np.uint16, + ) + repeats = ELEMS // lhs_pattern.size + v1 = np.tile(lhs_pattern, repeats) + v2 = np.tile(rhs_pattern, repeats) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = wrap_add_u16(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto new file mode 100644 index 000000000..9888837f9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned-overflow +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_i16_unsigned_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/launch.cpp new file mode 100644 index 000000000..bfd0fbe37 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/launch.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_unsigned_overflow_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVadd_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_i16_unsigned_overflow_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/main.cpp new file mode 100644 index 000000000..fb6fa53b2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned-overflow/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_unsigned_overflow_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/compare.py new file mode 100755 index 000000000..29c833e93 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/golden.py new file mode 100755 index 000000000..fa3e8e0c1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vadd-i16-unsigned +# family: binary-vector +# target_ops: pto.vadd +# scenarios: core-i16-unsigned, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 2001, size=(ROWS, COLS), dtype=np.uint16) + v2 = rng.integers(0, 2001, size=(ROWS, COLS), dtype=np.uint16) + v3 = np.zeros((ROWS, COLS), dtype=np.uint16) + golden_v3 = (v1.astype(np.uint32) + v2.astype(np.uint32)).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto new file mode 100644 index 000000000..ddbeecdf3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/launch.cpp new file mode 100644 index 000000000..c4198a017 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVadd_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vadd_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/main.cpp new file mode 100644 index 000000000..dd05d5051 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-i16-unsigned/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vadd-i16-unsigned +// family: binary-vector +// target_ops: pto.vadd +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd-tail/golden.py new file mode 100644 index 000000000..e967b1153 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] + v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto new file mode 100644 index 000000000..599f6da84 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadd_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-tail/launch.cpp new file mode 100644 index 000000000..3d1578331 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vadd_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/compare.py b/test/vpto/cases/micro-op/binary-vector/vadd/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/golden.py b/test/vpto/cases/micro-op/binary-vector/vadd/golden.py new file mode 100644 index 000000000..fbf37245e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = (v1 + v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto new file mode 100644 index 000000000..2aa6dbfe0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/kernel.pto @@ -0,0 +1,58 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @add_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.get_buf "PIPE_MTE2", 0, 0 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.rls_buf "PIPE_MTE2", 0, 0 + pto.get_buf "PIPE_MTE2", 1, 0 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.rls_buf "PIPE_MTE2", 1, 0 + pto.get_buf "PIPE_V", 0, 0 + pto.get_buf "PIPE_V", 1, 0 + pto.get_buf "PIPE_V", 2, 0 + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadd %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + pto.rls_buf "PIPE_V", 0, 0 + pto.rls_buf "PIPE_V", 1, 0 + pto.rls_buf "PIPE_V", 2, 0 + + pto.get_buf "PIPE_MTE3", 2, 0 + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.rls_buf "PIPE_MTE3", 2, 0 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vadd/launch.cpp new file mode 100644 index 000000000..7e1cfc9e0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void add_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchAdd_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + add_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vadd/main.cpp b/test/vpto/cases/micro-op/binary-vector/vadd/main.cpp new file mode 100644 index 000000000..78517d1a3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vadd/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchAdd_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchAdd_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/compare.py new file mode 100755 index 000000000..df15d65e4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc-carry-boundary +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/golden.py new file mode 100644 index 000000000..253c44d2c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/golden.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc-carry-boundary +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros(LANES, dtype=np.uint32) + v2 = np.zeros(LANES, dtype=np.uint32) + pattern_lhs = np.array([0xFFFFFFFF, 0xFFFFFFFE, 0x80000000, 0x7FFFFFFF], dtype=np.uint32) + pattern_rhs = np.array([0x00000001, 0x00000002, 0x80000000, 0x00000001], dtype=np.uint32) + reps = LANES // pattern_lhs.size + v1[:] = np.tile(pattern_lhs, reps) + v2[:] = np.tile(pattern_rhs, reps) + total = v1.astype(np.uint64) + v2.astype(np.uint64) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto new file mode 100644 index 000000000..13933a0ba --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc-carry-boundary +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddc_carry_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_carry, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/launch.cpp new file mode 100644 index 000000000..9c3f6d2c9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc-carry-boundary +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vaddc_carry_boundary_kernel_2d(__gm__ uint32_t *v1, __gm__ uint32_t *v2, + __gm__ uint32_t *v3, __gm__ uint8_t *v4); + +void LaunchVaddc_carry_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream) { + vaddc_carry_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/main.cpp new file mode 100644 index 000000000..486a7314a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc-carry-boundary/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc-carry-boundary +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddc_carry_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddc_carry_boundary_kernel_2d(v1Device, v2Device, v3Device, v4Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/compare.py b/test/vpto/cases/micro-op/binary-vector/vaddc/compare.py new file mode 100755 index 000000000..af24f7730 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/golden.py b/test/vpto/cases/micro-op/binary-vector/vaddc/golden.py new file mode 100644 index 000000000..3fff82cbc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vaddc +# family: binary-vector +# target_ops: pto.vaddc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + v2 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + total = v1.astype(np.uint64) + v2.astype(np.uint64) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto new file mode 100644 index 000000000..74554c755 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddc_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %c128_i64 = arith.constant 128 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_carry, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc/launch.cpp new file mode 100644 index 000000000..21640d20f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vaddc_kernel_2d(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVaddc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vaddc_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vaddc/main.cpp b/test/vpto/cases/micro-op/binary-vector/vaddc/main.cpp new file mode 100644 index 000000000..8cf91e047 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vaddc/main.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vaddc +// family: binary-vector +// target_ops: pto.vaddc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddc_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/compare.py b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/compare.py new file mode 100755 index 000000000..f42233bb4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand-mask-edge +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/golden.py b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/golden.py new file mode 100755 index 000000000..27a700901 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand-mask-edge +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + idx = np.arange(ELEMS, dtype=np.uint16) + v1 = np.where((idx & 1) == 0, np.uint16(0xAAAA), np.uint16(0x0F0F)).astype(np.uint16, copy=False) + v2 = np.where((idx & 2) == 0, np.uint16(0x5555), np.uint16(0x3333)).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_and(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto new file mode 100644 index 000000000..908bcb94c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand-mask-edge +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vand_mask_edge_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vand %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/launch.cpp new file mode 100644 index 000000000..3924e63d3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand-mask-edge +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vand_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVand_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vand_mask_edge_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/main.cpp b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/main.cpp new file mode 100644 index 000000000..eae6df992 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand-mask-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand-mask-edge +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVand_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVand_mask_edge_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand/compare.py b/test/vpto/cases/micro-op/binary-vector/vand/compare.py new file mode 100755 index 000000000..28c2a232c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand/golden.py b/test/vpto/cases/micro-op/binary-vector/vand/golden.py new file mode 100755 index 000000000..a67709b57 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vand +# family: binary-vector +# target_ops: pto.vand +# scenarios: core-i16-unsigned, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_and(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto new file mode 100644 index 000000000..8042c9550 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vand_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vand %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vand/launch.cpp new file mode 100644 index 000000000..3008d46bc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vand_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVand_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vand_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vand/main.cpp b/test/vpto/cases/micro-op/binary-vector/vand/main.cpp new file mode 100644 index 000000000..958b4422e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vand/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vand +// family: binary-vector +// target_ops: pto.vand +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVand_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVand_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/compare.py new file mode 100755 index 000000000..1de4f17b7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vdiv-f16 +# family: binary-vector +# target_ops: pto.vdiv +# scenarios: core-f16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float16, 5e-3, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/golden.py new file mode 100755 index 000000000..627221d7a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vdiv-f16 +# family: binary-vector +# target_ops: pto.vdiv +# scenarios: core-f16, full-mask +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2_mag = rng.uniform(0.5, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2_sign = np.where(rng.integers(0, 2, size=(ROWS, COLS), dtype=np.int32) == 0, + np.float32(-1.0), np.float32(1.0)) + v2 = (v2_mag * v2_sign).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS].astype(np.float32) + / v2.reshape(-1)[:LOGICAL_ELEMS].astype(np.float32) + ).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto new file mode 100644 index 000000000..ffca6d1e3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vdiv-f16 +// family: binary-vector +// target_ops: pto.vdiv +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdiv_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/launch.cpp new file mode 100644 index 000000000..1abdc6f6d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vdiv-f16 +// family: binary-vector +// target_ops: pto.vdiv +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdiv_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVdiv_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vdiv_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/main.cpp new file mode 100644 index 000000000..d0b9cff9a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f16/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vdiv-f16 +// family: binary-vector +// target_ops: pto.vdiv +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdiv_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdiv_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/golden.py new file mode 100644 index 000000000..9caa514c7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + numer = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + denom = np.array( + [2.0, -2.0, 0.0, -0.0, np.inf, 1.0, 1.0, np.nan], + dtype=np.float32, + ) + v1 = np.resize(numer, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.resize(denom, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.divide(v1, v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto new file mode 100644 index 000000000..ddb95efd3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdiv_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/launch.cpp new file mode 100644 index 000000000..a68d4e95b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdiv_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVdiv_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vdiv_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/main.cpp new file mode 100644 index 000000000..d048e2faf --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-f32-exceptional/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdiv_f32_exceptional_kernel_2d(float *v1, float *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdiv_f32_exceptional_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/golden.py new file mode 100644 index 000000000..c010ada1f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + np.float32(0.5) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + np.float32(0.5) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] / v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto new file mode 100644 index 000000000..774b7f31f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdiv_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/launch.cpp new file mode 100644 index 000000000..85826b5f8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdiv_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVdiv_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vdiv_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/main.cpp new file mode 100644 index 000000000..3f3dcf515 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdiv_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdiv_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/compare.py b/test/vpto/cases/micro-op/binary-vector/vdiv/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/golden.py b/test/vpto/cases/micro-op/binary-vector/vdiv/golden.py new file mode 100644 index 000000000..ea43aa613 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2_mag = rng.uniform(0.5, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2_sign = np.where(rng.integers(0, 2, size=(ROWS, COLS), dtype=np.int32) == 0, + np.float32(-1.0), np.float32(1.0)) + v2 = (v2_mag * v2_sign).astype(np.float32, copy=False) + golden_v3 = (v1 / v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto new file mode 100644 index 000000000..2b6849a8b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @div_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %quot = pto.vdiv %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %quot, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv/launch.cpp new file mode 100644 index 000000000..fd82d2921 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void div_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchDiv_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + div_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vdiv/main.cpp b/test/vpto/cases/micro-op/binary-vector/vdiv/main.cpp new file mode 100644 index 000000000..3972f99b3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vdiv/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchDiv_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchDiv_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vmax-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vmax-tail/golden.py new file mode 100644 index 000000000..82d3beb41 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = np.maximum( + v1.reshape(-1)[:LOGICAL_ELEMS], v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto new file mode 100644 index 000000000..8bfae2fd0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmax_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmax %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmax-tail/launch.cpp new file mode 100644 index 000000000..bab607bfa --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmax_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vmax_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmax-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/compare.py b/test/vpto/cases/micro-op/binary-vector/vmax/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/golden.py b/test/vpto/cases/micro-op/binary-vector/vmax/golden.py new file mode 100644 index 000000000..aca780439 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = np.maximum(v1, v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto new file mode 100644 index 000000000..8c7835c64 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @max_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmax %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmax/launch.cpp new file mode 100644 index 000000000..44e917951 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void max_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMax_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + max_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmax/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmax/main.cpp new file mode 100644 index 000000000..9713fd509 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmax/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMax_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMax_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/compare.py new file mode 100755 index 000000000..8e84eda9e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-bf16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-bf16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/golden.py new file mode 100755 index 000000000..a399eeb9d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-bf16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-bf16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + wide = values.astype(np.float32, copy=False).view(np.uint32) + rounding = np.uint32(0x7FFF) + ((wide >> 16) & np.uint32(1)) + return ((wide + rounding) >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray: + return (bits.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v2_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v1 = f32_to_bf16_bits(v1_f32) + v2 = f32_to_bf16_bits(v2_f32) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = f32_to_bf16_bits(np.minimum(bf16_bits_to_f32(v1), bf16_bits_to_f32(v2))) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto new file mode 100644 index 000000000..046c88d23 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-bf16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmin_bf16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/launch.cpp new file mode 100644 index 000000000..1f374fb36 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-bf16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2, + __gm__ bfloat16_t *v3); + +void LaunchVmin_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, void *stream) { + vmin_bf16_kernel<<<1, nullptr, stream>>>((__gm__ bfloat16_t *)v1, + (__gm__ bfloat16_t *)v2, + (__gm__ bfloat16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-bf16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/main.cpp new file mode 100644 index 000000000..01fb803f3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-bf16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-bf16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-bf16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_bf16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_bf16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-f16/compare.py new file mode 100755 index 000000000..d4fe300db --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-f16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-f16, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float16, 5e-3, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-f16/golden.py new file mode 100755 index 000000000..4ee39a4b3 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-f16 +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-f16, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v3 = np.minimum(v1.astype(np.float32), v2.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto new file mode 100644 index 000000000..05368522d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-f16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmin_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f16/launch.cpp new file mode 100644 index 000000000..5151be743 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-f16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVmin_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, void *stream) { + vmin_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f16/main.cpp new file mode 100644 index 000000000..374d8a322 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-f16 +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-f16, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/golden.py new file mode 100644 index 000000000..4d8d2f34a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + rhs = np.array( + [np.inf, -2.5, 0.0, -0.0, -1.0, 1.0, 1.0, np.nan], + dtype=np.float32, + ) + v1 = np.resize(lhs, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.resize(rhs, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.minimum(v1, v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto new file mode 100644 index 000000000..cfbd2a17e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @min_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/launch.cpp new file mode 100644 index 000000000..f2c64c6a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + min_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/main.cpp new file mode 100644 index 000000000..b952b76a0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-f32-exceptional/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMin_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/compare.py new file mode 100755 index 000000000..2afc3f8ec --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-signed +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-signed, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.int16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/golden.py new file mode 100755 index 000000000..48ce71042 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-signed +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-signed, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-1000, 1001, size=ELEMS, dtype=np.int16) + v2 = rng.integers(-1000, 1001, size=ELEMS, dtype=np.int16) + v3 = np.zeros(ELEMS, dtype=np.int16) + golden_v3 = np.minimum(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto new file mode 100644 index 000000000..3a37e8baa --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-signed +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmin_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/launch.cpp new file mode 100644 index 000000000..923e415d4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-signed +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVmin_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vmin_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/main.cpp new file mode 100644 index 000000000..029455a99 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-signed/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-signed +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-signed, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_i16_signed_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/compare.py new file mode 100755 index 000000000..f87d0f17d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-unsigned +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-unsigned, full-mask +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/golden.py new file mode 100755 index 000000000..7ac5b68a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vmin-i16-unsigned +# family: binary-vector +# target_ops: pto.vmin +# scenarios: core-i16-unsigned, full-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 2001, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 2001, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.minimum(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto new file mode 100644 index 000000000..0dce887df --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-unsigned +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmin_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/launch.cpp new file mode 100644 index 000000000..6cc3d692c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-unsigned +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVmin_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vmin_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/main.cpp new file mode 100644 index 000000000..885dea67a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-i16-unsigned/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vmin-i16-unsigned +// family: binary-vector +// target_ops: pto.vmin +// scenarios: core-i16-unsigned, full-mask +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin-tail/golden.py new file mode 100644 index 000000000..29bbdcd28 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = np.minimum( + v1.reshape(-1)[:LOGICAL_ELEMS], v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto new file mode 100644 index 000000000..e73d73821 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmin_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vmin %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-tail/launch.cpp new file mode 100644 index 000000000..f2a890c47 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmin_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVmin_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vmin_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin-tail/main.cpp new file mode 100644 index 000000000..5a418b3da --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmin_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmin_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/compare.py b/test/vpto/cases/micro-op/binary-vector/vmin/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/golden.py b/test/vpto/cases/micro-op/binary-vector/vmin/golden.py new file mode 100644 index 000000000..6d18ab792 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + golden_v3 = np.minimum(v1, v2) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.astype(np.float32, copy=False).reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto new file mode 100644 index 000000000..4d1789609 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @min_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %minv = pto.vmin %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %minv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmin/launch.cpp new file mode 100644 index 000000000..f2c64c6a6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void min_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + min_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmin/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmin/main.cpp new file mode 100644 index 000000000..b952b76a0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmin/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMin_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMin_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vmul-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vmul-tail/golden.py new file mode 100644 index 000000000..553faae15 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] * v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto new file mode 100644 index 000000000..68675fb11 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmul_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmul %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmul-tail/launch.cpp new file mode 100644 index 000000000..6e1ab54ae --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmul_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vmul_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmul-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/compare.py b/test/vpto/cases/micro-op/binary-vector/vmul/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/golden.py b/test/vpto/cases/micro-op/binary-vector/vmul/golden.py new file mode 100644 index 000000000..23e4731e7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = (v1 * v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto new file mode 100644 index 000000000..f42d8a759 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mul_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmul %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vmul/launch.cpp new file mode 100644 index 000000000..21ee4384c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mul_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchMul_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + mul_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vmul/main.cpp b/test/vpto/cases/micro-op/binary-vector/vmul/main.cpp new file mode 100644 index 000000000..711269796 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vmul/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMul_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchMul_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/compare.py b/test/vpto/cases/micro-op/binary-vector/vor-f16/compare.py new file mode 100755 index 000000000..78bd43ef6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-f16 +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/golden.py b/test/vpto/cases/micro-op/binary-vector/vor-f16/golden.py new file mode 100755 index 000000000..471da5094 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-f16 +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + bits1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + bits2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + bits1[:8] = np.array( + [0x0000, 0x8000, 0x3c00, 0xbc00, 0x7c00, 0xfc00, 0x7e00, 0x3555], + dtype=np.uint16, + ) + bits2[:8] = np.array( + [0x0001, 0x0001, 0x4000, 0x2000, 0x0001, 0x0001, 0x0100, 0x0aaa], + dtype=np.uint16, + ) + v1 = bits1.view(np.float16) + v2 = bits2.view(np.float16) + v3 = np.zeros(ELEMS, dtype=np.float16) + golden_v3 = np.bitwise_or(bits1, bits2).view(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto new file mode 100644 index 000000000..d5b6351d9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-f16 +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vor_f16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vor %lhs, %rhs, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vor-f16/launch.cpp new file mode 100644 index 000000000..45b41406e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-f16 +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vor_f16_kernel(__gm__ half *v1, + __gm__ half *v2, + __gm__ half *v3); + +void LaunchVor_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vor_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, (__gm__ half *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-f16/main.cpp b/test/vpto/cases/micro-op/binary-vector/vor-f16/main.cpp new file mode 100644 index 000000000..826735922 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-f16/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-f16 +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVor_f16_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVor_f16_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/compare.py b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/compare.py new file mode 100755 index 000000000..58ac20c66 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-mask-edge +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/golden.py b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/golden.py new file mode 100755 index 000000000..3c28fa036 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor-mask-edge +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + idx = np.arange(ELEMS, dtype=np.uint16) + v1 = np.where((idx & 1) == 0, np.uint16(0xAAAA), np.uint16(0x0F0F)).astype(np.uint16, copy=False) + v2 = np.where((idx & 2) == 0, np.uint16(0x5555), np.uint16(0x3333)).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_or(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto new file mode 100644 index 000000000..c964b2940 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-mask-edge +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vor_mask_edge_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/launch.cpp new file mode 100644 index 000000000..c9e9411c6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-mask-edge +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vor_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vor_mask_edge_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/main.cpp b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/main.cpp new file mode 100644 index 000000000..634ad7664 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor-mask-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor-mask-edge +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVor_mask_edge_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor/compare.py b/test/vpto/cases/micro-op/binary-vector/vor/compare.py new file mode 100755 index 000000000..8b38a30b8 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor/golden.py b/test/vpto/cases/micro-op/binary-vector/vor/golden.py new file mode 100755 index 000000000..c0d7ce117 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vor +# family: binary-vector +# target_ops: pto.vor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_or(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto new file mode 100644 index 000000000..3ce40ada7 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vor_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vor/launch.cpp new file mode 100644 index 000000000..416e5354f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vor_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vor_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vor/main.cpp b/test/vpto/cases/micro-op/binary-vector/vor/main.cpp new file mode 100644 index 000000000..0ebb0d781 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vor/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vor +// family: binary-vector +// target_ops: pto.vor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVor_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/compare.py b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/compare.py new file mode 100755 index 000000000..fcf304f6f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-i32-unsigned +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i32-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint32, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/golden.py b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/golden.py new file mode 100755 index 000000000..cefd36ee1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-i32-unsigned +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i32-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 32, size=ELEMS, dtype=np.uint32) + v2 = rng.integers(0, 32, size=ELEMS, dtype=np.uint32) + v3 = np.zeros(ELEMS, dtype=np.uint32) + golden_v3 = np.left_shift(v1, v2 & np.uint32(31)).astype(np.uint32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto new file mode 100644 index 000000000..50ef19cc1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-i32-unsigned +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i32-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshl_i32_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xui32> + %out = pto.vshl %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/launch.cpp new file mode 100644 index 000000000..ba6ef9dca --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-i32-unsigned +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i32-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshl_i32_unsigned_kernel(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3); + +void LaunchVshl_i32_unsigned_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + void *stream) { + vshl_i32_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/main.cpp new file mode 100644 index 000000000..df4bc2e9a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-i32-unsigned/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-i32-unsigned +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i32-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshl_i32_unsigned_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshl_i32_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/compare.py new file mode 100755 index 000000000..2ef28c2cf --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-shift-boundary +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/golden.py new file mode 100755 index 000000000..15261e271 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl-shift-boundary +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 16, size=ELEMS, dtype=np.uint16) + shift_cycle = np.array([0, 1, 14, 15, 15, 14, 1, 0], dtype=np.uint16) + v2 = np.resize(shift_cycle, ELEMS).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.left_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto new file mode 100644 index 000000000..d10c6a05b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-shift-boundary +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshl_shift_boundary_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshl %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/launch.cpp new file mode 100644 index 000000000..bbf9c75f5 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-shift-boundary +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshl_shift_boundary_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshl_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshl_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/main.cpp new file mode 100644 index 000000000..13d114f38 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl-shift-boundary/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl-shift-boundary +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshl_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshl_shift_boundary_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/compare.py b/test/vpto/cases/micro-op/binary-vector/vshl/compare.py new file mode 100755 index 000000000..7006bca77 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/golden.py b/test/vpto/cases/micro-op/binary-vector/vshl/golden.py new file mode 100755 index 000000000..ed6ca93bc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshl +# family: binary-vector +# target_ops: pto.vshl +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 16, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.left_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto new file mode 100644 index 000000000..ebe4448bd --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshl_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshl %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshl/launch.cpp new file mode 100644 index 000000000..3abb0a4bf --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshl_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshl_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshl_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshl/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshl/main.cpp new file mode 100644 index 000000000..ce45e4cf1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshl/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshl +// family: binary-vector +// target_ops: pto.vshl +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshl_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshl_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/compare.py b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/compare.py new file mode 100755 index 000000000..a8ee34d1b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-i16-signed +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.int16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/golden.py b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/golden.py new file mode 100755 index 000000000..4262e419f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-i16-signed +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-0x8000, 0x8000, size=ELEMS, dtype=np.int16) + v2 = rng.integers(0, 16, size=ELEMS, dtype=np.int16) + v3 = np.zeros(ELEMS, dtype=np.int16) + golden_v3 = np.right_shift(v1, v2 & np.int16(15)).astype(np.int16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto new file mode 100644 index 000000000..d4e53b05f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-i16-signed +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshr_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xsi16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xsi16> + %out = pto.vshr %lhs, %rhs, %mask : !pto.vreg<128xsi16>, !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xsi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xsi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/launch.cpp new file mode 100644 index 000000000..6b62ae139 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-i16-signed +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshr_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVshr_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream) { + vshr_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/main.cpp new file mode 100644 index 000000000..d1e3520e4 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-i16-signed/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-i16-signed +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshr_i16_signed_kernel(int16_t *v1, int16_t *v2, int16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshr_i16_signed_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/compare.py new file mode 100755 index 000000000..f5e74791e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-shift-boundary +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/golden.py new file mode 100755 index 000000000..6c70bedba --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr-shift-boundary +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 16, size=ELEMS, dtype=np.uint16) + shift_cycle = np.array([0, 1, 14, 15, 15, 14, 1, 0], dtype=np.uint16) + v2 = np.resize(shift_cycle, ELEMS).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.right_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto new file mode 100644 index 000000000..baa1e93c1 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-shift-boundary +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshr_shift_boundary_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshr %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/launch.cpp new file mode 100644 index 000000000..a24d8e08c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-shift-boundary +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshr_shift_boundary_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshr_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshr_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/main.cpp new file mode 100644 index 000000000..e8d1e1459 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr-shift-boundary/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr-shift-boundary +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshr_shift_boundary_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshr_shift_boundary_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/compare.py b/test/vpto/cases/micro-op/binary-vector/vshr/compare.py new file mode 100755 index 000000000..a2429ec9b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/golden.py b/test/vpto/cases/micro-op/binary-vector/vshr/golden.py new file mode 100755 index 000000000..bd0cda8b9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vshr +# family: binary-vector +# target_ops: pto.vshr +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 16, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.right_shift(v1, v2 & np.uint16(15)).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto new file mode 100644 index 000000000..cd7db0019 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshr_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vshr %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vshr/launch.cpp new file mode 100644 index 000000000..08208c24c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshr_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVshr_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vshr_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vshr/main.cpp b/test/vpto/cases/micro-op/binary-vector/vshr/main.cpp new file mode 100644 index 000000000..fcdf7cbc0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vshr/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vshr +// family: binary-vector +// target_ops: pto.vshr +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshr_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshr_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/compare.py b/test/vpto/cases/micro-op/binary-vector/vsub-tail/compare.py new file mode 100644 index 000000000..c95419953 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/golden.py b/test/vpto/cases/micro-op/binary-vector/vsub-tail/golden.py new file mode 100644 index 000000000..954e00c9b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = rng.random((ROWS, COLS), dtype=np.float32) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v3.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] - v2.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto new file mode 100644 index 000000000..6d839ec58 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/kernel.pto @@ -0,0 +1,49 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsub_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %diff = pto.vsub %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %diff, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsub-tail/launch.cpp new file mode 100644 index 000000000..01f113a97 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsub_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vsub_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub-tail/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsub-tail/main.cpp new file mode 100644 index 000000000..40a9881d6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub-tail/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadd_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadd_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/compare.py b/test/vpto/cases/micro-op/binary-vector/vsub/compare.py new file mode 100644 index 000000000..a5f14dabc --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/golden.py b/test/vpto/cases/micro-op/binary-vector/vsub/golden.py new file mode 100644 index 000000000..2f3f82fe6 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + golden_v3 = (v1 - v2).astype(np.float32, copy=False) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto new file mode 100644 index 000000000..2d3847fc0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @sub_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %diff = pto.vsub %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %diff, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsub/launch.cpp new file mode 100644 index 000000000..daeaeb5de --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void sub_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchSub_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + sub_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsub/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsub/main.cpp new file mode 100644 index 000000000..0c7c8359a --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsub/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchSub_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchSub_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/compare.py b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/compare.py new file mode 100755 index 000000000..67df8a750 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/compare.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc-borrow-boundary +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/golden.py b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/golden.py new file mode 100755 index 000000000..cf6db0014 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc-borrow-boundary +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros(LANES, dtype=np.uint32) + v2 = np.zeros(LANES, dtype=np.uint32) + pattern_lhs = np.array([0x00000000, 0x00000001, 0x7FFFFFFF, 0x80000000], dtype=np.uint32) + pattern_rhs = np.array([0x00000001, 0x00000002, 0x80000000, 0xFFFFFFFF], dtype=np.uint32) + reps = LANES // pattern_lhs.size + v1[:] = np.tile(pattern_lhs, reps) + v2[:] = np.tile(pattern_rhs, reps) + no_borrow = v1 >= v2 + result = (v1 - v2).astype(np.uint32, copy=False) + packed = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(no_borrow): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + packed[byte] |= np.uint8(0x1) + else: + packed[byte] |= np.uint8(0x10) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + packed.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto new file mode 100644 index 000000000..6c5586aa9 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc-borrow-boundary +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsubc_borrow_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_borrow, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/launch.cpp new file mode 100644 index 000000000..972a52f2d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc-borrow-boundary +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vsubc_borrow_boundary_kernel_2d(__gm__ uint32_t *v1, __gm__ uint32_t *v2, + __gm__ uint32_t *v3, __gm__ uint8_t *v4); + +void LaunchVsubc_borrow_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream) { + vsubc_borrow_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/main.cpp new file mode 100644 index 000000000..43b2eb292 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc-borrow-boundary/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc-borrow-boundary +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubc_borrow_boundary_kernel_2d(uint32_t *v1, uint32_t *v2, + uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubc_borrow_boundary_kernel_2d(v1Device, v2Device, v3Device, v4Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/compare.py b/test/vpto/cases/micro-op/binary-vector/vsubc/compare.py new file mode 100755 index 000000000..f68c8267e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/golden.py b/test/vpto/cases/micro-op/binary-vector/vsubc/golden.py new file mode 100755 index 000000000..1b647ac6c --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vsubc +# family: binary-vector +# target_ops: pto.vsubc +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + v2 = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + diff = (v1 - v2).astype(np.uint32, copy=False) + no_borrow = v1 >= v2 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + diff.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(no_borrow).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto new file mode 100644 index 000000000..d9eedc19b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-i16-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsubc_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_borrow, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc/launch.cpp new file mode 100644 index 000000000..4f47cec25 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsubc_kernel_2d(__gm__ uint32_t *v1, + __gm__ uint32_t *v2, + __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVsubc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vsubc_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vsubc/main.cpp b/test/vpto/cases/micro-op/binary-vector/vsubc/main.cpp new file mode 100644 index 000000000..a553603b2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vsubc/main.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vsubc +// family: binary-vector +// target_ops: pto.vsubc +// scenarios: core-u32-unsigned, full-mask, carry-chain +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubc_kernel_2d(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubc_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/compare.py b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/compare.py new file mode 100755 index 000000000..e8a187bb2 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor-mask-edge +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/golden.py b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/golden.py new file mode 100755 index 000000000..0da3cc44d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor-mask-edge +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + idx = np.arange(ELEMS, dtype=np.uint16) + v1 = np.where((idx & 1) == 0, np.uint16(0xAAAA), np.uint16(0x0F0F)).astype(np.uint16, copy=False) + v2 = np.where((idx & 2) == 0, np.uint16(0x5555), np.uint16(0x3333)).astype(np.uint16, copy=False) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_xor(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto new file mode 100644 index 000000000..4fbed53a0 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor-mask-edge +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vxor_mask_edge_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vxor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/launch.cpp new file mode 100644 index 000000000..309646298 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor-mask-edge +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vxor_mask_edge_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVxor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vxor_mask_edge_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/main.cpp b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/main.cpp new file mode 100644 index 000000000..95d478f5f --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor-mask-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor-mask-edge +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVxor_mask_edge_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVxor_mask_edge_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/compare.py b/test/vpto/cases/micro-op/binary-vector/vxor/compare.py new file mode 100755 index 000000000..cd8e1ce3e --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.uint16, 0, 1024) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/golden.py b/test/vpto/cases/micro-op/binary-vector/vxor/golden.py new file mode 100755 index 000000000..f5e328e08 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/binary-vector/vxor +# family: binary-vector +# target_ops: pto.vxor +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v2 = rng.integers(0, 0x10000, size=ELEMS, dtype=np.uint16) + v3 = np.zeros(ELEMS, dtype=np.uint16) + golden_v3 = np.bitwise_xor(v1, v2).astype(np.uint16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto new file mode 100644 index 000000000..6cd29199d --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vxor_i16_unsigned_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %lhs = pto.vlds %ub_lhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vxor %lhs, %rhs, %mask : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/launch.cpp b/test/vpto/cases/micro-op/binary-vector/vxor/launch.cpp new file mode 100644 index 000000000..59d1c049b --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vxor_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2, + __gm__ uint16_t *v3); + +void LaunchVxor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vxor_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2, + (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/binary-vector/vxor/main.cpp b/test/vpto/cases/micro-op/binary-vector/vxor/main.cpp new file mode 100644 index 000000000..99f4a9d98 --- /dev/null +++ b/test/vpto/cases/micro-op/binary-vector/vxor/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/binary-vector/vxor +// family: binary-vector +// target_ops: pto.vxor +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVxor_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVxor_i16_unsigned_kernel(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-eq/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-eq/golden.py new file mode 100644 index 000000000..4d075f357 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = v1.copy() + mismatch = (np.arange(LANES, dtype=np.int32) % 3) == 1 + v2[mismatch] = (v2[mismatch] + np.float32(1.25)).astype(np.float32) + + mask = np.equal(v1, v2) + golden = encode_b32_mask(mask) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-eq.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto new file mode 100644 index 000000000..b92a55f8d --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/kernel.pto @@ -0,0 +1,51 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmp_eq_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "eq" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-eq/launch.cpp new file mode 100644 index 000000000..aedecd7b5 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_eq_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-eq/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-eq/main.cpp new file mode 100644 index 000000000..15cae7173 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-eq/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_eq_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/golden.py new file mode 100644 index 000000000..3d3dbb4d8 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + lhs = np.array( + [-np.inf, -3.0, -0.0, 0.0, 0.5, np.inf, np.nan, 1.0], + dtype=np.float32, + ) + rhs = np.array( + [np.inf, -2.0, 0.0, -0.0, 0.5, np.nan, 1.0, -np.inf], + dtype=np.float32, + ) + v1 = np.resize(lhs, LANES).astype(np.float32) + v2 = np.resize(rhs, LANES).astype(np.float32) + mask = np.less(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-f32-exceptional.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto new file mode 100644 index 000000000..6f31f7cb8 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/kernel.pto @@ -0,0 +1,51 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmp_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/launch.cpp new file mode 100644 index 000000000..97d79d2fd --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_f32_exceptional_kernel_2d(float *v1, float *v2, + unsigned char *v3, void *stream) { + vcmp_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/main.cpp new file mode 100644 index 000000000..b8b92375f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-f32-exceptional/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_f32_exceptional_kernel_2d(float *v1, float *v2, + unsigned char *v3, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_f32_exceptional_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/golden.py new file mode 100644 index 000000000..08c48389e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + pair_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << pair_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-1200, 1200, size=(LANES,), dtype=np.int16) + v2 = v1.copy() + mismatch = (np.arange(LANES, dtype=np.int32) % 3) == 1 + v2[mismatch] = (v2[mismatch] + np.int16(7)).astype(np.int16) + mask = np.equal(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-i16-signed.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto new file mode 100644 index 000000000..7d10bdad3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/kernel.pto @@ -0,0 +1,52 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmp_eq_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i32 = arith.constant 128 : i32 + %c32_i64_data = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %pred = pto.vcmp %lhs, %rhs, %active, "eq" : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64_data + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/launch.cpp new file mode 100644 index 000000000..0af8b16de --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/launch.cpp @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-signed +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_eq_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/main.cpp new file mode 100644 index 000000000..0857d2e9e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-signed/main.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-signed +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(short); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(short); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + short *v1Host = nullptr; + short *v1Device = nullptr; + short *v2Host = nullptr; + short *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_eq_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/golden.py new file mode 100644 index 000000000..1742febf1 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + pair_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << pair_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 60000, size=(LANES,), dtype=np.uint16) + v2 = v1.copy() + mismatch = (np.arange(LANES, dtype=np.int32) % 3) == 1 + v2[mismatch] = (v2[mismatch] + np.uint16(7)).astype(np.uint16) + mask = np.equal(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-i16-unsigned.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto new file mode 100644 index 000000000..264ba844b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/kernel.pto @@ -0,0 +1,52 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmp_eq_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i32 = arith.constant 128 : i32 + %c32_i64_data = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<128xui16> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<128xui16> + %pred = pto.vcmp %lhs, %rhs, %active, "eq" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64_data + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/launch.cpp new file mode 100644 index 000000000..e1199bd24 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/launch.cpp @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-unsigned +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_eq_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_eq_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/main.cpp new file mode 100644 index 000000000..8e5b3e1d4 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-i16-unsigned/main.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vcmp-i16-unsigned +// family: compare-select +// target_ops: pto.vcmp +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_eq_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(unsigned short); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned short); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + unsigned short *v1Host = nullptr; + unsigned short *v1Device = nullptr; + unsigned short *v2Host = nullptr; + unsigned short *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_eq_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-lt/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-lt/golden.py new file mode 100644 index 000000000..6feb1da41 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + delta = rng.uniform(0.25, 1.25, size=(LANES,)).astype(np.float32) + choose_less = (np.arange(LANES, dtype=np.int32) % 2) == 0 + v2 = np.where(choose_less, v1 + delta, v1 - delta).astype(np.float32) + mask = np.less(v1, v2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-lt.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto new file mode 100644 index 000000000..266c88ce4 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/kernel.pto @@ -0,0 +1,51 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmp_lt_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-lt/launch.cpp new file mode 100644 index 000000000..2762499e2 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_lt_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_lt_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_lt_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-lt/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-lt/main.cpp new file mode 100644 index 000000000..fa06a715f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-lt/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_lt_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_lt_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/compare.py b/test/vpto/cases/micro-op/compare-select/vcmp-tail/compare.py new file mode 100644 index 000000000..a872552e3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/compare.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v3.bin", dtype=np.uint8) + output = np.fromfile("v3.bin", dtype=np.uint8) + ok = golden.size >= 32 and output.size >= 32 and np.array_equal(golden[:32], output[:32]) + if not ok: + if golden.size and output.size: + diff = np.nonzero(golden[:32] != output[:32])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/golden.py b/test/vpto/cases/micro-op/compare-select/vcmp-tail/golden.py new file mode 100644 index 000000000..f59aded57 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/golden.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LOGICAL_ELEMS = 53 +SEED = 19 +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-6.0, 6.0, size=(LANES,)).astype(np.float32) + delta = rng.uniform(0.1, 2.0, size=(LANES,)).astype(np.float32) + mode = np.arange(LANES, dtype=np.int32) % 5 + + v2 = np.empty((LANES,), dtype=np.float32) + v2[mode == 0] = v1[mode == 0] + delta[mode == 0] + v2[mode == 1] = v1[mode == 1] - delta[mode == 1] + v2[mode == 2] = v1[mode == 2] + v2[mode == 3] = np.nextafter(v1[mode == 3], np.float32(np.inf)) + v2[mode == 4] = np.nextafter(v1[mode == 4], np.float32(-np.inf)) + + v1[:10] = np.array([-3.0, -1.0, -0.0, 0.0, 0.25, 1.0, 2.0, 4.0, -4.0, 6.0], dtype=np.float32) + v2[:10] = np.array([ + -2.0, + -2.0, + 0.0, + np.nextafter(np.float32(0.0), np.float32(np.inf)), + 0.25, + np.nextafter(np.float32(1.0), np.float32(-np.inf)), + 3.0, + 3.0, + np.nextafter(np.float32(-4.0), np.float32(np.inf)), + 6.0, + ], dtype=np.float32) + + mask = np.less(v1, v2) + mask[LOGICAL_ELEMS:] = False + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmp-tail.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto new file mode 100644 index 000000000..fe2d00b22 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/kernel.pto @@ -0,0 +1,51 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmp_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c53_i32 = arith.constant 53 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c53_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-tail/launch.cpp new file mode 100644 index 000000000..c57830b58 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmp_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ unsigned char *v3); + +void LaunchVcmp_tail_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream) { + vcmp_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmp-tail/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmp-tail/main.cpp new file mode 100644 index 000000000..ee8661a62 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmp-tail/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmp_tail_kernel_2d(float *v1, float *v2, unsigned char *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmp_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/golden.py new file mode 100644 index 000000000..d2ca06dc2 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +THRESHOLD = np.float32(0.5) +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -1.0, -0.0, 0.0, 0.5, 0.75, np.inf, np.nan], + dtype=np.float32, + ) + v1 = np.resize(specials, LANES).astype(np.float32) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-f32-exceptional.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto new file mode 100644 index 000000000..f61d05eb9 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmps_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %threshold = arith.constant 5.000000e-01 : f32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/launch.cpp new file mode 100644 index 000000000..e96d87fec --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_f32_exceptional_kernel_2d(float *v1, unsigned char *v2, + void *stream) { + vcmps_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/main.cpp new file mode 100644 index 000000000..d8d7a33b6 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32-exceptional/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_f32_exceptional_kernel_2d(float *v1, unsigned char *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-f32/golden.py new file mode 100644 index 000000000..7224594ef --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +THRESHOLD = np.float32(0.5) +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-2.0, 2.0, size=(LANES,)).astype(np.float32) + v1[:8] = np.array([0.5, 0.5001, 0.4999, -0.5, 1.0, -1.0, 0.0, 2.0], + dtype=np.float32) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-f32.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto new file mode 100644 index 000000000..448ad347e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/kernel.pto @@ -0,0 +1,46 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmps_f32_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %threshold = arith.constant 5.000000e-01 : f32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32/launch.cpp new file mode 100644 index 000000000..49dfceec7 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_f32_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_f32_kernel_2d(float *v1, unsigned char *v2, void *stream) { + vcmps_f32_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-f32/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-f32/main.cpp new file mode 100644 index 000000000..e9c28d290 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-f32/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_f32_kernel_2d(float *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_f32_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/golden.py new file mode 100644 index 000000000..1a4865835 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +THRESHOLD = np.int16(5) +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + bit_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-32768, 32767, size=(LANES,), dtype=np.int16) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i16-signed.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto new file mode 100644 index 000000000..58f8628a3 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmps_i16_signed_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c128_i32 = arith.constant 128 : i32 + %threshold = arith.constant 5 : i16 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/launch.cpp new file mode 100644 index 000000000..6fd05c86e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i16_signed_kernel_2d(__gm__ int16_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i16_signed_kernel_2d(int16_t *v1, unsigned char *v2, void *stream) { + vcmps_i16_signed_kernel_2d<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/main.cpp new file mode 100644 index 000000000..2d318f173 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-signed/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i16_signed_kernel_2d(int16_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i16_signed_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/golden.py new file mode 100644 index 000000000..b7318cc9f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 +THRESHOLD = np.uint16(513) +OUTPUT_BYTES = 32 + + +def encode_b16_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 4 + bit_shift = 2 * (i % 4) + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 65535, size=(LANES,), dtype=np.uint16) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b16_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i16-unsigned.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto new file mode 100644 index 000000000..b7ae23f2b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmps_i16_unsigned_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c128_i32 = arith.constant 128 : i32 + %threshold = arith.constant 513 : i16 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<128xui16> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/launch.cpp new file mode 100644 index 000000000..2178b15fa --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i16_unsigned_kernel_2d(__gm__ uint16_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i16_unsigned_kernel_2d(uint16_t *v1, unsigned char *v2, void *stream) { + vcmps_i16_unsigned_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/main.cpp new file mode 100644 index 000000000..717dc9332 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i16-unsigned/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i16_unsigned_kernel_2d(uint16_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i16_unsigned_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/golden.py new file mode 100644 index 000000000..9b26d1c00 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 256 +SEED = 19 +THRESHOLD = np.int8(5) +OUTPUT_BYTES = 32 + + +def encode_b8_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 8 + bit_shift = i % 8 + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-128, 127, size=(LANES,), dtype=np.int8) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b8_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i8-signed.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto new file mode 100644 index 000000000..4cb82ed72 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmps_i8_signed_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c256_i32 = arith.constant 256 : i32 + %threshold = arith.constant 5 : i8 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c256_i32) -> (i32) { + %active, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<256xsi8> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<256xsi8>, i8, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/launch.cpp new file mode 100644 index 000000000..3423328de --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i8_signed_kernel_2d(__gm__ int8_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i8_signed_kernel_2d(int8_t *v1, unsigned char *v2, void *stream) { + vcmps_i8_signed_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/main.cpp new file mode 100644 index 000000000..bb2a47b15 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-signed/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i8_signed_kernel_2d(int8_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 256; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i8_signed_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/golden.py new file mode 100644 index 000000000..334f772c7 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 256 +SEED = 19 +THRESHOLD = np.uint8(129) +OUTPUT_BYTES = 32 + + +def encode_b8_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 8 + bit_shift = i % 8 + out[byte_index] |= np.uint8(1 << bit_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 255, size=(LANES,), dtype=np.uint8) + mask = np.greater(v1, THRESHOLD) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b8_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-i8-unsigned.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto new file mode 100644 index 000000000..a826fbc47 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmps_i8_unsigned_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c256_i32 = arith.constant 256 : i32 + %threshold = arith.constant -127 : i8 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c256_i32) -> (i32) { + %active, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<256xui8> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<256xui8>, i8, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/launch.cpp new file mode 100644 index 000000000..ce3b2eca9 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_i8_unsigned_kernel_2d(__gm__ uint8_t *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_i8_unsigned_kernel_2d(uint8_t *v1, unsigned char *v2, void *stream) { + vcmps_i8_unsigned_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/main.cpp new file mode 100644 index 000000000..f24020b81 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-i8-unsigned/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_i8_unsigned_kernel_2d(uint8_t *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 256; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_i8_unsigned_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/compare.py b/test/vpto/cases/micro-op/compare-select/vcmps-tail/compare.py new file mode 100644 index 000000000..bc2a4827f --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_mask("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/golden.py b/test/vpto/cases/micro-op/compare-select/vcmps-tail/golden.py new file mode 100644 index 000000000..e36631b9a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/golden.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LOGICAL_ELEMS = 40 +SEED = 19 +THRESHOLD = np.float32(0.5) +OUTPUT_BYTES = 32 + + +def encode_b32_mask(mask: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + for i, bit in enumerate(mask.astype(np.uint8, copy=False)): + if bit: + byte_index = i // 2 + nibble_shift = 4 * (i % 2) + out[byte_index] |= np.uint8(1 << nibble_shift) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-2.0, 2.0, size=(LANES,)).astype(np.float32) + + v1[:12] = np.array([ + THRESHOLD, + np.nextafter(THRESHOLD, np.float32(np.inf)), + np.nextafter(THRESHOLD, np.float32(-np.inf)), + 0.0, + -0.0, + -1.0, + 1.0, + 2.0, + -2.0, + THRESHOLD + np.float32(0.25), + THRESHOLD - np.float32(0.25), + THRESHOLD, + ], dtype=np.float32) + + mask = np.greater(v1, THRESHOLD) + mask[LOGICAL_ELEMS:] = False + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + encode_b32_mask(mask).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vcmps-tail.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto new file mode 100644 index 000000000..7ed9ffeed --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcmps_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c40_i32 = arith.constant 40 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %threshold = arith.constant 5.000000e-01 : f32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c40_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmps %src, %threshold, %active, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + pto.psts %pred, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/launch.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-tail/launch.cpp new file mode 100644 index 000000000..a210fc2fa --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcmps_tail_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2); + +void LaunchVcmps_tail_kernel_2d(float *v1, unsigned char *v2, void *stream) { + vcmps_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/compare-select/vcmps-tail/main.cpp b/test/vpto/cases/micro-op/compare-select/vcmps-tail/main.cpp new file mode 100644 index 000000000..941741c4b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vcmps-tail/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcmps_tail_kernel_2d(float *v1, unsigned char *v2, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVcmps_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/compare.py new file mode 100644 index 000000000..0823d7be6 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/golden.py new file mode 100644 index 000000000..c5f29198b --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/golden.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +MASK_BYTES = 32 +SEED = 19 + + +def load_us_source_bits(packed: np.ndarray) -> np.ndarray: + # `plds ..., "US"` on this path consumes only the first VL/16 bytes. + bits = np.unpackbits(packed[:16], bitorder="little") + return bits.astype(np.bool_, copy=False) + + +def plds_us_to_mask_b8(src_bits: np.ndarray) -> np.ndarray: + # "US": duplicate each loaded bit once. + return np.repeat(src_bits, 2).astype(np.bool_, copy=False) + + +def pbitcast_b8_to_b16(mask_b8: np.ndarray) -> np.ndarray: + # Reinterpret the same predicate image at b16 granularity. + # For the duplicated "US" image, each b16 lane observes the first bit of + # the corresponding 2-bit pair, which reconstructs the original 128 bits. + return mask_b8[::2].astype(np.bool_, copy=False) + + +def pintlv_b16_with_all(mask_b16: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + # Interleave `%mask_b16` with `all_b16`, then split into low/high outputs. + interleaved = np.empty((256,), dtype=np.bool_) + interleaved[0::2] = mask_b16 + interleaved[1::2] = True + return interleaved[:128], interleaved[128:] + + +def pbitcast_b16_to_b32(mask_b16_image: np.ndarray) -> np.ndarray: + # Reinterpret the same predicate image at b32 granularity. + # The b32 lanes read back the even-positioned b16 lanes. + return mask_b16_image[0::2][:64].astype(np.bool_, copy=False) + + +def build_vsel_lanes_from_mask_pipeline(packed: np.ndarray) -> np.ndarray: + src_bits = load_us_source_bits(packed) + mask_b8 = plds_us_to_mask_b8(src_bits) + mask_b16 = pbitcast_b8_to_b16(mask_b8) + mask0_b16, mask1_b16 = pintlv_b16_with_all(mask_b16) + mask0_b32 = pbitcast_b16_to_b32(mask0_b16) + mask1_b32 = pbitcast_b16_to_b32(mask1_b16) + return np.concatenate([mask0_b32, mask1_b32]).astype(np.bool_, copy=False) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + packed = rng.integers(0, 256, size=(MASK_BYTES,), dtype=np.uint8) + lanes = build_vsel_lanes_from_mask_pipeline(packed) + v4 = np.zeros((LANES,), dtype=np.float32) + golden_v4 = np.where(lanes, v1, v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + packed.tofile(output_dir / "v3.bin") + v4.tofile(output_dir / "v4.bin") + golden_v4.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for VPTO vsel-f32-plds-us-pintlv-pbitcast." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto new file mode 100644 index 000000000..567987b24 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/kernel.pto @@ -0,0 +1,77 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast +// family: compare-select +// target_ops: pto.plds, pto.pintlv_b16, pto.pbitcast, pto.vsel +// scenarios: packed-us-mask-expand-to-b32, f32-select-from-compressed-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsel_f32_plds_us_pintlv_pbitcast_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c1056_i64 = arith.constant 1056 : i64 + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c512_i64 : i64 -> !pto.ptr + %ub_mask = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c1056_i64 : i64 -> !pto.ptr + + pto.set_loop1_stride_outtoub %c128_i64, %c128_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c128_i64, %c128_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg2, %ub_mask, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask_b8 = pto.plds %ub_mask[%c0], "US" : !pto.ptr, index -> !pto.mask + %mask_b16 = pto.pbitcast %mask_b8 : !pto.mask -> !pto.mask + %all_b16 = pto.pset_b16 "PAT_ALL" : !pto.mask + %all_b32 = pto.pset_b32 "PAT_ALL" : !pto.mask + %mask0_b16, %mask1_b16 = pto.pintlv_b16 %mask_b16, %all_b16 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %mask0_b32 = pto.pbitcast %mask0_b16 : !pto.mask -> !pto.mask + %mask1_b32 = pto.pbitcast %mask1_b16 : !pto.mask -> !pto.mask + %lhs0 = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs0 = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %lhs1 = pto.vlds %ub_lhs[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %rhs1 = pto.vlds %ub_rhs[%c64] : !pto.ptr -> !pto.vreg<64xf32> + %out0 = pto.vsel %lhs0, %rhs0, %mask0_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out1 = pto.vsel %lhs1, %rhs1, %mask1_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out0, %ub_out[%c0], %all_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out1, %ub_out[%c64], %all_b32 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c128_i64, %c128_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c128_i64, %c128_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg3, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/launch.cpp new file mode 100644 index 000000000..73a736714 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vsel_f32_plds_us_pintlv_pbitcast_kernel_2d(__gm__ float *v1, __gm__ float *v2, + __gm__ unsigned char *v3, + __gm__ float *v4); + +void LaunchVsel_f32_plds_us_pintlv_pbitcast_kernel_2d(float *v1, float *v2, + unsigned char *v3, + float *v4, + void *stream) { + vsel_f32_plds_us_pintlv_pbitcast_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ unsigned char *)v3, + (__gm__ float *)v4); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/main.cpp new file mode 100644 index 000000000..4f36cd466 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast/main.cpp @@ -0,0 +1,116 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vsel-f32-plds-us-pintlv-pbitcast +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_f32_plds_us_pintlv_pbitcast_kernel_2d(float *v1, float *v2, + unsigned char *v3, + float *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 32; + size_t fileSize_v3 = elemCount_v3 * sizeof(unsigned char); + size_t elemCount_v4 = 128; + size_t fileSize_v4 = elemCount_v4 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + float *v4Host = nullptr; + float *v4Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_f32_plds_us_pintlv_pbitcast_kernel_2d(v1Device, v2Device, v3Device, + v4Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-i16/compare.py new file mode 100644 index 000000000..cb78833f5 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-i16/golden.py new file mode 100644 index 000000000..9a5b86185 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-200, 200, size=(LANES,), dtype=np.int16) + v2 = rng.integers(-200, 200, size=(LANES,), dtype=np.int16) + golden_v3 = np.where(v1 > v2, v1, v2).astype(np.int16, copy=False) + v3 = np.zeros((LANES,), dtype=np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel-i16.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto new file mode 100644 index 000000000..670b5a2bb --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/kernel.pto @@ -0,0 +1,51 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsel_i16_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i32 = arith.constant 128 : i32 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c128_i32) -> (i32) { + %active, %next = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<128xi16> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-i16/launch.cpp new file mode 100644 index 000000000..8ca125427 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_i16_kernel_2d(__gm__ int16_t *v1, + __gm__ int16_t *v2, + __gm__ int16_t *v3); + +void LaunchVsel_i16_kernel_2d(int16_t *v1, int16_t *v2, int16_t *v3, void *stream) { + vsel_i16_kernel_2d<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + (__gm__ int16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-i16/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-i16/main.cpp new file mode 100644 index 000000000..656e4dfec --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-i16/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_i16_kernel_2d(int16_t *v1, int16_t *v2, int16_t *v3, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + size_t elemCount_v3 = 128; + size_t fileSize_v3 = elemCount_v3 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Host = nullptr; + int16_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_i16_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/compare.py new file mode 100644 index 000000000..a861864de --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/golden.py new file mode 100644 index 000000000..e757d1089 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-8.0, 8.0, size=(LANES,)).astype(np.float32) + rhs = lhs.copy() + + lane_ids = np.arange(LANES, dtype=np.int32) + edge_mask = ((lane_ids < 4) | (lane_ids >= 60) | ((lane_ids % 17) == 0)) + rhs[edge_mask] = (rhs[edge_mask] + np.float32(3.5)).astype(np.float32) + rhs[~edge_mask] = (rhs[~edge_mask] - np.float32(2.0)).astype(np.float32) + + out = np.zeros((LANES,), dtype=np.float32) + golden = np.where(lhs > rhs, lhs, rhs).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel-predicate-edge.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto new file mode 100644 index 000000000..7d73c5bfa --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/kernel.pto @@ -0,0 +1,52 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsel_predicate_edge_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/launch.cpp new file mode 100644 index 000000000..c23a21ddd --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_predicate_edge_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVsel_predicate_edge_kernel_2d(float *v1, float *v2, float *v3, + void *stream) { + vsel_predicate_edge_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/main.cpp new file mode 100644 index 000000000..6fea6f0e8 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-predicate-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_predicate_edge_kernel_2d(float *v1, float *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_predicate_edge_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/compare.py b/test/vpto/cases/micro-op/compare-select/vsel-tail/compare.py new file mode 100755 index 000000000..9965cdb63 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vsel-tail +# family: compare-select +# target_ops: pto.vsel +# scenarios: core-f32, tail-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + +LOGICAL_ELEMS = 40 + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + golden = golden[:LOGICAL_ELEMS] + output = output[:LOGICAL_ELEMS] + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/golden.py b/test/vpto/cases/micro-op/compare-select/vsel-tail/golden.py new file mode 100644 index 000000000..a2d6807fa --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 +LOGICAL_ELEMS = 40 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + golden_v3 = np.full((LANES,), OUT_SENTINEL, dtype=np.float32) + flat = np.where(v1 > v2, v1, v2).astype(np.float32, copy=False) + golden_v3[:LOGICAL_ELEMS] = flat[:LOGICAL_ELEMS] + v3 = np.full((LANES,), OUT_SENTINEL, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel-tail.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto new file mode 100644 index 000000000..2237e5305 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/kernel.pto @@ -0,0 +1,54 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsel_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i32 = arith.constant 40 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg2, %ub_out, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c40_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel-tail/launch.cpp new file mode 100644 index 000000000..b4e0598e0 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVsel_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vsel_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel-tail/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel-tail/main.cpp new file mode 100644 index 000000000..323131056 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel-tail/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel/compare.py b/test/vpto/cases/micro-op/compare-select/vsel/compare.py new file mode 100755 index 000000000..6cd01c922 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vsel +# family: compare-select +# target_ops: pto.vsel +# scenarios: core-f32, full-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-6) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel/golden.py b/test/vpto/cases/micro-op/compare-select/vsel/golden.py new file mode 100644 index 000000000..fef5f7d27 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + v2 = rng.uniform(-3.0, 3.0, size=(LANES,)).astype(np.float32) + golden_v3 = np.where(v1 > v2, v1, v2).astype(np.float32, copy=False) + v3 = np.zeros((LANES,), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for VPTO vsel.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto new file mode 100644 index 000000000..c8ee34b29 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/kernel.pto @@ -0,0 +1,51 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsel_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c256_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c512_i64 : i64 -> !pto.ptr + pto.set_loop1_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_outtoub %c256_i64, %c256_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c64_i32) -> (i32) { + %active, %next = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %pred = pto.vcmp %lhs, %rhs, %active, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + %out = pto.vsel %lhs, %rhs, %pred : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%c0], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c64_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel/launch.cpp b/test/vpto/cases/micro-op/compare-select/vsel/launch.cpp new file mode 100644 index 000000000..269405dee --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsel_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVsel_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vsel_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vsel/main.cpp b/test/vpto/cases/micro-op/compare-select/vsel/main.cpp new file mode 100644 index 000000000..cf71eb295 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vsel/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsel_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVsel_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/compare.py b/test/vpto/cases/micro-op/compare-select/vselr-f16/compare.py new file mode 100644 index 000000000..b961a3713 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-f16 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f16, full-mask, explicit-lane-index + +import os +import sys + +import numpy as np + + +def compare_tensor(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + return False + if not np.allclose(golden, output, rtol=0.0, atol=0.0, equal_nan=True): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={golden[idx]} out={output[idx]}") + return False + return True + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_tensor("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/golden.py b/test/vpto/cases/micro-op/compare-select/vselr-f16/golden.py new file mode 100644 index 000000000..ae0513ecf --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-f16 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f16, full-mask, explicit-lane-index + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 8 +COLS = 128 +SEED = 23 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-6.0, 6.0, size=(ROWS, COLS)).astype(np.float16, copy=False) + lane_ids = np.arange(COLS, dtype=np.uint16) + idx = np.empty((ROWS, COLS), dtype=np.uint16) + for row in range(ROWS): + idx[row] = (lane_ids[::-1] + row * 11 + (lane_ids % 7) * 3) % COLS + golden = np.take_along_axis(src, idx.astype(np.int64, copy=False), axis=1).astype(np.float16, copy=False) + out = np.zeros((ROWS, COLS), dtype=np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).reshape(-1).tofile(output_dir / "v1.bin") + idx.reshape(-1).tofile(output_dir / "v2.bin") + out.view(np.uint16).reshape(-1).tofile(output_dir / "v3.bin") + golden.view(np.uint16).reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for vselr-f16.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto new file mode 100644 index 000000000..0c9bf3d51 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-f16 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f16, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vselr_f16_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_idx = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_idx, %c0_i64, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %idx = pto.vlds %ub_idx[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vselr %src, %idx : !pto.vreg<128xf16>, !pto.vreg<128xui16> -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c8_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/launch.cpp b/test/vpto/cases/micro-op/compare-select/vselr-f16/launch.cpp new file mode 100644 index 000000000..f00e5672d --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-f16 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f16, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vselr_f16_kernel_2d(__gm__ half *v1, + __gm__ uint16_t *v2, + __gm__ half *v3); + +void LaunchVselr_f16_kernel_2d(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream) { + vselr_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ uint16_t *)v2, + (__gm__ half *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-f16/main.cpp b/test/vpto/cases/micro-op/compare-select/vselr-f16/main.cpp new file mode 100644 index 000000000..2002d9ee5 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-f16/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-f16 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f16, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVselr_f16_kernel_2d(uint16_t *v1, uint16_t *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVselr_f16_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/compare.py b/test/vpto/cases/micro-op/compare-select/vselr-u8/compare.py new file mode 100644 index 000000000..d48e1a42e --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-u8 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-u8, full-mask, explicit-lane-index + +import os +import sys + +import numpy as np + + +def compare_tensor(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_tensor("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/golden.py b/test/vpto/cases/micro-op/compare-select/vselr-u8/golden.py new file mode 100644 index 000000000..2cb03404a --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr-u8 +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-u8, full-mask, explicit-lane-index + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 4 +COLS = 256 +SEED = 29 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.integers(0, 256, size=(ROWS, COLS), dtype=np.uint8) + lane_ids = np.arange(COLS, dtype=np.uint16) + idx = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + row_idx = (lane_ids[::-1] + row * 19 + (lane_ids % 13) * 5) % COLS + idx[row] = row_idx.astype(np.uint8, copy=False) + golden = np.take_along_axis(src, idx.astype(np.int64, copy=False), axis=1).astype(np.uint8, copy=False) + out = np.zeros((ROWS, COLS), dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + idx.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "v3.bin") + golden.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for vselr-u8.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto new file mode 100644 index 000000000..3b33248ac --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-u8 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-u8, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vselr_u8_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %false = arith.constant false + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_idx = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_idx, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<256xui8> + %idx = pto.vlds %ub_idx[%offset] : !pto.ptr -> !pto.vreg<256xui8> + %out = pto.vselr %src, %idx : !pto.vreg<256xui8>, !pto.vreg<256xui8> -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/launch.cpp b/test/vpto/cases/micro-op/compare-select/vselr-u8/launch.cpp new file mode 100644 index 000000000..a8d38e8ea --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-u8 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-u8, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vselr_u8_kernel_2d(__gm__ uint8_t *v1, + __gm__ uint8_t *v2, + __gm__ uint8_t *v3); + +void LaunchVselr_u8_kernel_2d(uint8_t *v1, uint8_t *v2, uint8_t *v3, + void *stream) { + vselr_u8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1, + (__gm__ uint8_t *)v2, + (__gm__ uint8_t *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr-u8/main.cpp b/test/vpto/cases/micro-op/compare-select/vselr-u8/main.cpp new file mode 100644 index 000000000..78f0a8d16 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr-u8/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr-u8 +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-u8, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVselr_u8_kernel_2d(uint8_t *v1, uint8_t *v2, uint8_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint8_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint8_t); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + uint8_t *v2Host = nullptr; + uint8_t *v2Device = nullptr; + uint8_t *v3Host = nullptr; + uint8_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVselr_u8_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr/compare.py b/test/vpto/cases/micro-op/compare-select/vselr/compare.py new file mode 100755 index 000000000..b0da91196 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f32, full-mask, explicit-lane-index + +import os +import sys +import numpy as np + +def compare_tensor(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + return False + if not np.allclose(golden, output, rtol=0.0, atol=0.0): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch: idx={idx} golden={float(golden[idx])} out={float(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_tensor("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr/golden.py b/test/vpto/cases/micro-op/compare-select/vselr/golden.py new file mode 100755 index 000000000..6362369bd --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/golden.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/compare-select/vselr +# family: compare-select +# target_ops: pto.vselr +# scenarios: core-f32, full-mask, explicit-lane-index +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32, copy=False) + src = v1.reshape(16, 64) + lane_ids = np.arange(64, dtype=np.int32) + idx = np.empty((16, 64), dtype=np.int32) + for row in range(16): + idx[row] = (lane_ids[::-1] + row * 3 + (lane_ids // 8) * 3) % 64 + golden_v3 = np.take_along_axis(src, idx, axis=1).astype(np.float32, copy=False).reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + idx.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vselr validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto new file mode 100644 index 000000000..fb8b47ff0 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/kernel.pto @@ -0,0 +1,74 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f32, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vselr_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) attributes {pto.kernel} { + %c8192_i64 = arith.constant 8192 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32 = arith.constant 32 : index + %0 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %1 = arith.index_castui %c32 : index to i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c4_i64 = arith.constant 4 : i64 + %2 = arith.muli %1, %c4_i64 : i64 + %c128_i64 = arith.constant 128 : i64 + %3 = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %4 = arith.index_castui %c0_i64 : i64 to index + %5 = pto.addptr %3, %4 : -> + pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + %6 = pto.castptr %5 : !pto.ptr -> !pto.ptr + %false = arith.constant false + pto.mte_gm_ub %6, %0, %c0_i64, %2 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + %7 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %8 = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + %9 = pto.addptr %8, %4 : -> + pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + %10 = pto.castptr %9 : !pto.ptr -> !pto.ptr + pto.mte_gm_ub %10, %7, %c0_i64, %2 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + %11 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024_i32 = arith.constant 1024 : i32 + pto.vecscope { + %16 = scf.for %arg4 = %c0 to %c16 step %c1 iter_args(%arg5 = %c1024_i32) -> (i32) { + %17 = arith.muli %arg4, %c64 : index + %mask, %scalar_out = pto.plt_b32 %arg5 : i32 -> !pto.mask, i32 + %25 = pto.vlds %0[%17] : !pto.ptr -> !pto.vreg<64xf32> + %26 = pto.vlds %7[%17] : !pto.ptr -> !pto.vreg<64xi32> + %27 = pto.vselr %25, %26 : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + pto.vsts %27, %11[%17], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } + } + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + %c1024_i64 = arith.constant 1024 : i64 + %12 = arith.muli %1, %c4_i64 : i64 + %13 = pto.castptr %arg2 : !pto.ptr -> !pto.ptr + %14 = pto.addptr %13, %4 : -> + pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 + %15 = pto.castptr %14 : !pto.ptr -> !pto.ptr + pto.mte_ub_gm %11, %15, %12 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr/launch.cpp b/test/vpto/cases/micro-op/compare-select/vselr/launch.cpp new file mode 100644 index 000000000..68e4c6169 --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f32, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vselr_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVselr_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vselr_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/compare-select/vselr/main.cpp b/test/vpto/cases/micro-op/compare-select/vselr/main.cpp new file mode 100644 index 000000000..62fd4bebf --- /dev/null +++ b/test/vpto/cases/micro-op/compare-select/vselr/main.cpp @@ -0,0 +1,109 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/compare-select/vselr +// family: compare-select +// target_ops: pto.vselr +// scenarios: core-f32, full-mask, explicit-lane-index +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVselr_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVselr_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/compare.py new file mode 100644 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/golden.py new file mode 100755 index 000000000..66921951a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/golden.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-special +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + special = np.array( + [ + np.float16(0.0), + np.float16(-0.0), + np.float16(1.0), + np.float16(-1.0), + np.float16(np.inf), + np.float16(-np.inf), + np.float16(np.nan), + np.float16(65504.0), + np.float16(-65504.0), + np.float16(6.1035e-05), + np.float16(-6.1035e-05), + np.float16(5.9605e-08), + np.float16(-5.9605e-08), + np.float16(123.75), + np.float16(-123.75), + np.float16(0.33325), + ], + dtype=np.float16, + ) + v1 = np.resize(special, ROWS * COLS).reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto new file mode 100644 index 000000000..f147b0a99 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_f16_special_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %loaded = pto.vlds %ub_in[%offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded, %cvt_mask {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/launch.cpp new file mode 100644 index 000000000..214c9b5c8 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_special_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_special_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-special/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/main.cpp new file mode 100644 index 000000000..8f83d1edd --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-special/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_special_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_special_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/compare.py new file mode 100755 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/golden.py new file mode 100644 index 000000000..e071074e7 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-to-f32-part-even +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, full-mask, part-even + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float16) + # Kernel writes 8 chunks (offset 0..448, step 64), each chunk converts the + # lower 16-bit half (PART_EVEN) from packed f16 pairs in a 128-lane load. + out_elems = 512 + v2 = np.zeros(out_elems, dtype=np.float32) + golden_v2 = np.empty(out_elems, dtype=np.float32) + for block in range(0, out_elems, 64): + src = v1[block : block + 128 : 2].astype(np.float32, copy=False) + golden_v2[block : block + 64] = src + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-to-f32 part-even validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto new file mode 100644 index 000000000..05742cdee --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/kernel.pto @@ -0,0 +1,47 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_f16_to_f32_part_even_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + // Use packed f16 load (no UNPK): PART_EVEN selects the lower 16-bit + // element from each f16 pair inside a b32 lane. + scf.for %offset = %c0 to %c512 step %c64 { + %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded, %cvt_mask {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c16_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/launch.cpp new file mode 100644 index 000000000..8e321d1a5 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-even +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-even +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_to_f32_part_even_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_to_f32_part_even_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_to_f32_part_even_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/main.cpp new file mode 100644 index 000000000..1925124c1 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-even/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-even +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-even +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_to_f32_part_even_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_to_f32_part_even_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/compare.py new file mode 100755 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/golden.py new file mode 100644 index 000000000..2b7822bc8 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-to-f32-part-odd +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, full-mask, part-odd + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float16) + # Kernel writes 8 chunks (offset 0..448, step 64), each chunk converts the + # upper 16-bit half (PART_ODD) from packed f16 pairs in a 128-lane load. + out_elems = 512 + v2 = np.zeros(out_elems, dtype=np.float32) + golden_v2 = np.empty(out_elems, dtype=np.float32) + for block in range(0, out_elems, 64): + src = v1[block + 1 : block + 128 : 2].astype(np.float32, copy=False) + golden_v2[block : block + 64] = src + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-to-f32 part-odd validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto new file mode 100644 index 000000000..29b2d7065 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/kernel.pto @@ -0,0 +1,47 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_f16_to_f32_part_odd_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + // Use packed f16 load (no UNPK): PART_ODD then selects the upper 16-bit + // element from each f16 pair inside a b32 lane. + scf.for %offset = %c0 to %c512 step %c64 { + %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded, %cvt_mask {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c16_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/launch.cpp new file mode 100644 index 000000000..db23cbbf4 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-odd +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-odd +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_to_f32_part_odd_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_to_f32_part_odd_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_to_f32_part_odd_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/main.cpp new file mode 100644 index 000000000..567aafa0a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32-part-odd/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32-part-odd +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask, part-odd +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_to_f32_part_odd_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_to_f32_part_odd_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/compare.py new file mode 100755 index 000000000..751000b6f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/golden.py new file mode 100755 index 000000000..903c2385d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f16-to-f32 +# family: conversion +# target_ops: pto.vcvt +# scenarios: f16-to-f32, full-mask + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float16) + v2 = np.zeros(ELEMS, dtype=np.float32) + golden_v2 = v1.astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f16-to-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto new file mode 100644 index 000000000..5d57a648e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_f16_to_f32_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %loaded = pto.vlds %ub_in[%offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vcvt %loaded, %cvt_mask {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/launch.cpp new file mode 100644 index 000000000..4998ce110 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f16_to_f32_kernel_2d(__gm__ half *v1, + __gm__ float *v2); + +void LaunchVcvt_f16_to_f32_kernel_2d(uint16_t *v1, float *v2, void *stream) { + vcvt_f16_to_f32_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/main.cpp new file mode 100644 index 000000000..17f92c862 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f16-to-f32/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f16-to-f32 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f16-to-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f16_to_f32_kernel_2d(uint16_t *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f16_to_f32_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/compare.py new file mode 100644 index 000000000..d2d022505 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/golden.py new file mode 100755 index 000000000..413875307 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/golden.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f32-special +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +F16_MAX_FINITE = np.float32(65504.0) + + +def sat_cast_f32_to_f16(values: np.ndarray) -> np.ndarray: + values = np.where(np.isnan(values), np.float32(0.0), values) + values = np.clip(values, -F16_MAX_FINITE, F16_MAX_FINITE) + return values.astype(np.float16) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + special = np.array( + [ + 0.0, + -0.0, + 1.0, + -1.0, + np.inf, + -np.inf, + np.nan, + 65504.0, + -65504.0, + 1.0e-8, + -1.0e-8, + 1.0e-4, + -1.0e-4, + 123.75, + -123.75, + 0.33333334, + ], + dtype=np.float32, + ) + flat = np.resize(special, ROWS * COLS).astype(np.float32) + v1 = flat.reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_flat = np.zeros(ROWS * COLS, dtype=np.float16) + + for offset in range(0, ROWS * COLS, 128): + lower = sat_cast_f32_to_f16(flat[offset : offset + 64]) + upper = sat_cast_f32_to_f16(flat[offset + 64 : offset + 128]) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + golden_flat[offset : offset + 128] = merged + + golden_v2 = golden_flat.reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f32-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto new file mode 100644 index 000000000..583b208c1 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/kernel.pto @@ -0,0 +1,52 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_f32_special_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/launch.cpp new file mode 100644 index 000000000..64c50ea0d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f32_special_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_f32_special_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_f32_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-special/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/main.cpp new file mode 100644 index 000000000..73f29e4ca --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-special/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f32_special_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f32_special_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/compare.py new file mode 100644 index 000000000..d2d022505 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/golden.py new file mode 100644 index 000000000..ee8fd3890 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f32-to-f16-pk-b32 +# family: conversion +# target_ops: pto.vcvt, pto.vsts +# scenarios: f32-to-f16, pk-b32-store, full-mask + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + v2 = np.zeros(ELEMS, dtype=np.float16) + golden_v2 = v1.astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f32-to-f16-pk-b32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto new file mode 100644 index 000000000..2e12d41be --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-to-f16-pk-b32 +// family: conversion +// target_ops: pto.vcvt, pto.vsts +// scenarios: f32-to-f16, pk-b32-store, full-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_f32_to_f16_pk_b32_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %loaded = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %converted = pto.vcvt %loaded, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %converted, %ub_out[%offset], %mask {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/launch.cpp new file mode 100644 index 000000000..77055836f --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f32_to_f16_pk_b32_kernel(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_f32_to_f16_pk_b32_kernel(float *v1, aclFloat16 *v2, void *stream) { + vcvt_f32_to_f16_pk_b32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/main.cpp new file mode 100644 index 000000000..8b7886671 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16-pk-b32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f32_to_f16_pk_b32_kernel(float *v1, aclFloat16 *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(aclFloat16); + float *v1Host = nullptr; + float *v1Device = nullptr; + aclFloat16 *v2Host = nullptr; + aclFloat16 *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f32_to_f16_pk_b32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/compare.py new file mode 100644 index 000000000..d2d022505 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/golden.py new file mode 100755 index 000000000..55a924e97 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-f32-to-f16 +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, full-mask + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + v2 = np.zeros(ELEMS, dtype=np.float16) + golden_v2 = np.zeros(ELEMS, dtype=np.float16) + + # Width-changing f32->f16 lowering uses two 64-lane f32 vectors, converts + # them into EVEN/ODD halves, then merges them into one 128-lane f16 vector. + for offset in range(0, ELEMS, 128): + lower = v1[offset : offset + 64].astype(np.float16) + upper = v1[offset + 64 : offset + 128].astype(np.float16) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + golden_v2[offset : offset + 128] = merged + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-f32-to-f16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto new file mode 100644 index 000000000..01ab9c588 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/kernel.pto @@ -0,0 +1,52 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_f32_to_f16_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/launch.cpp new file mode 100644 index 000000000..8dcc00348 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-to-f16 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_f32_to_f16_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_f32_to_f16_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_f32_to_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/main.cpp new file mode 100644 index 000000000..cf7a6c2de --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-f32-to-f16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-f32-to-f16 +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_f32_to_f16_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_f32_to_f16_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/compare.py new file mode 100644 index 000000000..fe3cc3abc --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-i32-to-i16-overflow +# family: conversion +# target_ops: pto.vcvt +# scenarios: i32-to-i16, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/golden.py new file mode 100644 index 000000000..6fbfc4834 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-i32-to-i16-overflow +# family: conversion +# target_ops: pto.vcvt +# scenarios: i32-to-i16, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +I16_MIN = np.iinfo(np.int16).min +I16_MAX = np.iinfo(np.int16).max + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + data = rng.integers(-200000, 200000, size=ELEMS, dtype=np.int32) + edge = np.array([ + -40000, -32769, -32768, -32767, -1, 0, 1, 32766, + 32767, 32768, 40000, 70000, -70000, 65535, -65535, 123456, + ], dtype=np.int32) + data[:edge.size] = edge + clipped = np.clip(data, I16_MIN, I16_MAX).astype(np.int16) + golden = np.zeros(ELEMS, dtype=np.int16) + for offset in range(0, ELEMS, 128): + lower = clipped[offset : offset + 64] + upper = clipped[offset + 64 : offset + 128] + merged = np.empty(128, dtype=np.int16) + merged[0::2] = lower + merged[1::2] = upper + golden[offset : offset + 128] = merged + + output_dir.mkdir(parents=True, exist_ok=True) + data.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto new file mode 100644 index 000000000..b4ef9032c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i32-to-i16-overflow +// family: conversion +// target_ops: pto.vcvt +// scenarios: i32-to-i16, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_i32_to_i16_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xi32> + %even = pto.vcvt %lower, %lower_mask {sat = "SAT", part = "EVEN"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xi16> + %odd = pto.vcvt %upper, %upper_mask {sat = "SAT", part = "ODD"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xi16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/launch.cpp new file mode 100644 index 000000000..0ac6bc67d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_i32_to_i16_overflow_kernel( + __gm__ int32_t *v1, __gm__ int16_t *v2); + +void LaunchVcvt_i32_to_i16_overflow_kernel(int32_t *v1, int16_t *v2, + void *stream) { + vcvt_i32_to_i16_overflow_kernel<<<1, nullptr, stream>>>( + (__gm__ int32_t *)v1, (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/main.cpp new file mode 100644 index 000000000..ab8500906 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i32-to-i16-overflow/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i32-to-i16-overflow +// family: conversion +// target_ops: pto.vcvt +// scenarios: i32-to-i16, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_i32_to_i16_overflow_kernel(int32_t *v1, int16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_i32_to_i16_overflow_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/compare.py new file mode 100755 index 000000000..696722bfe --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# case: micro-op/conversion/vcvt-i64-to-f32 +# family: conversion +# target_ops: pto.mte_gm_ub, pto.mte_ub_gm, pto.vcvt, pto.vsts +# scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32) + ok = ok and compare_bin("golden_v3.bin", "v3.bin", np.int64) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/golden.py new file mode 100755 index 000000000..d05262692 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/golden.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-i64-to-f32 +# family: conversion +# target_ops: pto.mte_gm_ub, pto.mte_ub_gm, pto.vcvt, pto.vsts +# scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 1024 +OUTPUT_ELEMS = 512 +ROUNDTRIP_ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + edge = np.array( + [ + -(1 << 31), + -(1 << 24) - 3, + -(1 << 24) - 1, + -(1 << 24), + -(1 << 24) + 1, + -65537, + -32769, + -32768, + -1, + 0, + 1, + 32767, + 32768, + 65537, + (1 << 24) - 1, + 1 << 24, + (1 << 24) + 1, + (1 << 24) + 3, + (1 << 31) - 2, + (1 << 31) - 1, + ], + dtype=np.int32, + ) + base = rng.integers(np.iinfo(np.int32).min, np.iinfo(np.int32).max, + size=INPUT_ELEMS, dtype=np.int32) + base[: edge.size] = edge + v1 = base.astype(np.int64) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.float32) + v3 = np.zeros(ROUNDTRIP_ELEMS, dtype=np.int64) + golden_v2 = np.concatenate( + [base[offset : offset + 16] for offset in range(0, INPUT_ELEMS, 32)] + ).astype(np.float32) + golden_v3 = v1.copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + v3.tofile(output_dir / "v3.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + golden_v3.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-i64-to-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto new file mode 100644 index 000000000..87af44adb --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i64-to-f32 +// family: conversion +// target_ops: pto.mte_gm_ub, pto.mte_ub_gm, pto.vcvt, pto.vsts +// scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_i64_to_f32_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c256_i64 + nburst(%c32_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.set_flag["PIPE_MTE2", "PIPE_MTE3", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE3", "EVENT_ID1"] + pto.mte_ub_gm %ub_in, %arg2, %c256_i64 + nburst(%c32_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %store_offset = %c0 to %c512 step %c16 { + %input_offset = arith.muli %store_offset, %c2 : index + %loaded = pto.vlds %ub_in[%input_offset] : !pto.ptr -> !pto.vreg<32xsi64> + %converted = pto.vcvt %loaded, %mask {rnd = "R", part = "EVEN"} : !pto.vreg<32xsi64>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %converted, %ub_out[%store_offset], %mask {dist = "PK_B64"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/launch.cpp new file mode 100644 index 000000000..deff20af5 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i64-to-f32 +// family: conversion +// target_ops: pto.mte_gm_ub, pto.mte_ub_gm, pto.vcvt, pto.vsts +// scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_i64_to_f32_kernel( + __gm__ int64_t *v1, __gm__ float *v2, __gm__ int64_t *v3); + +void LaunchVcvt_i64_to_f32_kernel(int64_t *v1, float *v2, int64_t *v3, + void *stream) { + vcvt_i64_to_f32_kernel<<<1, nullptr, stream>>>( + (__gm__ int64_t *)v1, (__gm__ float *)v2, (__gm__ int64_t *)v3); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/main.cpp new file mode 100644 index 000000000..d51d24e6e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-i64-to-f32/main.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-i64-to-f32 +// family: conversion +// target_ops: pto.mte_gm_ub, pto.mte_ub_gm, pto.vcvt, pto.vsts +// scenarios: i64-dma-roundtrip, i64-to-f32, signed-input, rounded, part-even-low-half +// ----------------------------------------------------------------------------- + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_i64_to_f32_kernel(int64_t *v1, float *v2, int64_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int64_t); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(int64_t); + int64_t *v1Host = nullptr; + int64_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int64_t *v3Host = nullptr; + int64_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_i64_to_f32_kernel(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/compare.py new file mode 100644 index 000000000..166196a8e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/compare.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +LOGICAL_ELEMS = 1000 + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float16, 1e-3, LOGICAL_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/golden.py new file mode 100755 index 000000000..0b4417baf --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/golden.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-tail-special +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, tail-mask, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +F16_MAX_FINITE = np.float32(65504.0) + + +def sat_cast_f32_to_f16(values: np.ndarray) -> np.ndarray: + values = np.where(np.isnan(values), np.float32(0.0), values) + values = np.clip(values, -F16_MAX_FINITE, F16_MAX_FINITE) + return values.astype(np.float16) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + special = np.array( + [ + 0.0, + -0.0, + 1.0, + -1.0, + np.inf, + -np.inf, + np.nan, + 65504.0, + -65504.0, + 1.0e-8, + -1.0e-8, + 1.0e-4, + -1.0e-4, + 123.75, + -123.75, + 0.33333334, + ], + dtype=np.float32, + ) + flat = np.resize(special, ROWS * COLS).astype(np.float32) + flat[LOGICAL_ELEMS:] = 0.0 + v1 = flat.reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_flat = np.zeros(ROWS * COLS, dtype=np.float16) + + remaining = LOGICAL_ELEMS + for offset in range(0, ROWS * COLS, 128): + lower = sat_cast_f32_to_f16(flat[offset : offset + 64]) + upper = sat_cast_f32_to_f16(flat[offset + 64 : offset + 128]) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + active = min(remaining, 128) + golden_flat[offset : offset + active] = merged[:active] + remaining = max(remaining - 128, 0) + + golden_v2 = golden_flat.reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-tail-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto new file mode 100644 index 000000000..2b5cebf1b --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/kernel.pto @@ -0,0 +1,52 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_tail_special_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/launch.cpp new file mode 100644 index 000000000..128254d29 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_tail_special_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_tail_special_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_tail_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail-special/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/main.cpp new file mode 100644 index 000000000..155e88b98 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail-special/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail-special +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_tail_special_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_tail_special_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-tail/compare.py new file mode 100644 index 000000000..166196a8e --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/compare.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +LOGICAL_ELEMS = 1000 + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float16, 1e-3, LOGICAL_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-tail/golden.py new file mode 100755 index 000000000..b121f1e29 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/golden.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-tail +# family: conversion +# target_ops: pto.vcvt +# scenarios: f32-to-f16, tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float16(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=ROWS * COLS).astype(np.float32) + flat[LOGICAL_ELEMS:] = 0.0 + v1 = flat.reshape(ROWS, COLS) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float16) + golden_flat = np.full(ROWS * COLS, OUT_SENTINEL, dtype=np.float16) + + remaining = LOGICAL_ELEMS + for offset in range(0, ROWS * COLS, 128): + lower = flat[offset : offset + 64].astype(np.float16) + upper = flat[offset + 64 : offset + 128].astype(np.float16) + merged = np.empty(128, dtype=np.float16) + merged[0::2] = lower + merged[1::2] = upper + active = min(remaining, 128) + golden_flat[offset : offset + active] = merged[:active] + remaining = max(remaining - 128, 0) + + golden_v2 = golden_flat.reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vcvt-tail validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto new file mode 100644 index 000000000..e724ee2b1 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/kernel.pto @@ -0,0 +1,52 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %lower_mask, %upper_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %upper_mask, %_next_b32 = pto.plt_b32 %upper_remaining : i32 -> !pto.mask, i32 + %upper_offset = arith.addi %offset, %c64 : index + %lower = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %upper = pto.vlds %ub_in[%upper_offset] : !pto.ptr -> !pto.vreg<64xf32> + %even = pto.vcvt %lower, %lower_mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %odd = pto.vcvt %upper, %upper_mask {rnd = "R", sat = "SAT", part = "ODD"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + %merged = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %merged, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail/launch.cpp new file mode 100644 index 000000000..5773d5044 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_tail_kernel_2d(__gm__ float *v1, + __gm__ half *v2); + +void LaunchVcvt_tail_kernel_2d(float *v1, uint16_t *v2, void *stream) { + vcvt_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-tail/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-tail/main.cpp new file mode 100644 index 000000000..9a0abf5cb --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-tail +// family: conversion +// target_ops: pto.vcvt +// scenarios: f32-to-f16, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_tail_kernel_2d(float *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/compare.py b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/compare.py new file mode 100644 index 000000000..918a4b0bc --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/golden.py b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/golden.py new file mode 100644 index 000000000..487f52bfb --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/golden.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/conversion/vcvt-u32-to-u8-part-p0123 +# family: conversion +# target_ops: pto.vcvt +# scenarios: u32-to-u8, sat, part-p0123 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 23 +CHUNK = 256 +SUBCHUNK = 64 +U8_MAX = np.iinfo(np.uint8).max + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + data = rng.integers(0, 2000, size=ELEMS, dtype=np.uint32) + edge = np.array( + [ + 0, + 1, + 2, + 3, + 4, + 7, + 15, + 31, + 63, + 127, + 128, + 129, + 254, + 255, + 256, + 257, + 511, + 512, + 1023, + 65535, + 0xFFFFFFFF, + ], + dtype=np.uint32, + ) + data[: edge.size] = edge + + clipped = np.clip(data, 0, U8_MAX).astype(np.uint8) + golden = np.empty(ELEMS, dtype=np.uint8) + for offset in range(0, ELEMS, CHUNK): + p0 = clipped[offset : offset + SUBCHUNK] + p1 = clipped[offset + SUBCHUNK : offset + 2 * SUBCHUNK] + p2 = clipped[offset + 2 * SUBCHUNK : offset + 3 * SUBCHUNK] + p3 = clipped[offset + 3 * SUBCHUNK : offset + 4 * SUBCHUNK] + merged = np.empty(CHUNK, dtype=np.uint8) + merged[0::4] = p0 + merged[1::4] = p1 + merged[2::4] = p2 + merged[3::4] = p3 + golden[offset : offset + CHUNK] = merged + + output_dir.mkdir(parents=True, exist_ok=True) + data.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.uint8).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto new file mode 100644 index 000000000..fbeeb81bd --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/kernel.pto @@ -0,0 +1,64 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-u32-to-u8-part-p0123 +// family: conversion +// target_ops: pto.vcvt +// scenarios: u32-to-u8, sat, part-p0123 +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcvt_u32_to_u8_part_p0123_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c128_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b8 "PAT_ALL" : !pto.mask + %cvt_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c256 { + %offset_p1 = arith.addi %offset, %c64 : index + %offset_p2 = arith.addi %offset, %c128 : index + %offset_p3 = arith.addi %offset, %c192 : index + %src_p0 = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xui32> + %src_p1 = pto.vlds %ub_in[%offset_p1] : !pto.ptr -> !pto.vreg<64xui32> + %src_p2 = pto.vlds %ub_in[%offset_p2] : !pto.ptr -> !pto.vreg<64xui32> + %src_p3 = pto.vlds %ub_in[%offset_p3] : !pto.ptr -> !pto.vreg<64xui32> + %part_p0 = pto.vcvt %src_p0, %cvt_mask {sat = "SAT", part = "P0"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %part_p1 = pto.vcvt %src_p1, %cvt_mask {sat = "SAT", part = "P1"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %part_p2 = pto.vcvt %src_p2, %cvt_mask {sat = "SAT", part = "P2"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %part_p3 = pto.vcvt %src_p3, %cvt_mask {sat = "SAT", part = "P3"} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8> + %merged01 = pto.vor %part_p0, %part_p1, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %merged23 = pto.vor %part_p2, %part_p3, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + %merged = pto.vor %merged01, %merged23, %full_mask : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %merged, %ub_out[%offset], %full_mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c4_i64, %c256_i64, %c0_i64, %c256_i64, %c256_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp new file mode 100644 index 000000000..a3b66083b --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcvt_u32_to_u8_part_p0123_kernel( + __gm__ uint32_t *v1, __gm__ uint8_t *v2); + +void LaunchVcvt_u32_to_u8_part_p0123_kernel(uint32_t *v1, uint8_t *v2, + void *stream) { + vcvt_u32_to_u8_part_p0123_kernel<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint8_t *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp new file mode 100644 index 000000000..ae417b2d7 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vcvt-u32-to-u8-part-p0123/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/conversion/vcvt-u32-to-u8-part-p0123 +// family: conversion +// target_ops: pto.vcvt +// scenarios: u32-to-u8, sat, part-p0123 +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcvt_u32_to_u8_part_p0123_kernel(uint32_t *v1, uint8_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint8_t *v2Host = nullptr; + uint8_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcvt_u32_to_u8_part_p0123_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/compare.py new file mode 100644 index 000000000..2aec0a573 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 1e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/golden.py new file mode 100644 index 000000000..62578c84a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 23 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + values = np.array( + [-7.5, -3.25, -0.5, -0.0, 0.0, 0.5, 1.5, 6.75], + dtype=np.float16, + ) + v1 = np.resize(values, ROWS * COLS).reshape(ROWS, COLS).astype(np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.trunc(v1.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f16-rounding validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto new file mode 100644 index 000000000..1edef85ea --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vtrc_f16_rounding_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vtrc %vec, %mask, "Z" : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/launch.cpp new file mode 100644 index 000000000..ad3a8682a --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f16_rounding_kernel_2d(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVtrc_f16_rounding_kernel_2d(void *v1, void *v2, void *stream) { + vtrc_f16_rounding_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/main.cpp new file mode 100644 index 000000000..604ec1198 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f16-rounding/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f16_rounding_kernel_2d(void *v1, void *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + void *v1Host = nullptr; + void *v1Device = nullptr; + void *v2Host = nullptr; + void *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost(&v1Host, fileSize_v1)); + ACL_CHECK(aclrtMallocHost(&v2Host, fileSize_v2)); + ACL_CHECK(aclrtMalloc(&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f16_rounding_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/compare.py new file mode 100644 index 000000000..848571069 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/compare.py @@ -0,0 +1,206 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/golden.py new file mode 100644 index 000000000..64260e9d6 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + v4 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.rint(v1).astype(np.float32, copy=False) + golden_v3 = np.trunc(v1).astype(np.float32, copy=False) + golden_v4 = np.floor(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + v4.reshape(-1).tofile(output_dir / "v4.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden_v4.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f32-rounding validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto new file mode 100644 index 000000000..e1756ace6 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/kernel.pto @@ -0,0 +1,79 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vtrc_f32_rounding_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_r = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %ub_z = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out_r = pto.vtrc %vec, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_z = pto.vtrc %vec, %mask, "Z" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_f = pto.vtrc %vec, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out_r, %ub_r[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_z, %ub_z[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_f, %ub_f[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_r, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_z, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_f, %arg3, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/launch.cpp new file mode 100644 index 000000000..6e4f1a142 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/launch.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f32_rounding_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3, + __gm__ float *v4); + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream) { + vtrc_f32_rounding_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3, + (__gm__ float *)v4); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/main.cpp new file mode 100644 index 000000000..b86de567c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-rounding/main.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + size_t elemCount_v4 = 1024; + size_t fileSize_v4 = elemCount_v4 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + float *v4Host = nullptr; + float *v4Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f32_rounding_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/compare.py new file mode 100644 index 000000000..38d1deb75 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/golden.py new file mode 100644 index 000000000..f6251171d --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.trunc(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f32-special validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto new file mode 100644 index 000000000..3d834d945 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vtrc_f32_special_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vtrc %vec, %mask, "Z" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/launch.cpp new file mode 100644 index 000000000..4d1ad9527 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f32_special_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVtrc_f32_special_kernel_2d(float *v1, float *v2, void *stream) { + vtrc_f32_special_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-f32-special/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/main.cpp new file mode 100644 index 000000000..40f3aa5ae --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-f32-special/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f32_special_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f32_special_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/compare.py b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/compare.py new file mode 100644 index 000000000..848571069 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/compare.py @@ -0,0 +1,206 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + ok = compare_bin("golden_v4.bin", "v4.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/golden.py b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/golden.py new file mode 100644 index 000000000..a39eaa122 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + boundary = np.array( + [-3.5, -3.0, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5], + dtype=np.float32, + ) + v1 = np.resize(boundary, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + v4 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.rint(v1).astype(np.float32, copy=False) + golden_v3 = np.trunc(v1).astype(np.float32, copy=False) + golden_v4 = np.floor(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + v4.reshape(-1).tofile(output_dir / "v4.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden_v4.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vtrc-f32-rounding validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto new file mode 100644 index 000000000..e1756ace6 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/kernel.pto @@ -0,0 +1,79 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vtrc_f32_rounding_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_r = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %ub_z = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_f = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out_r = pto.vtrc %vec, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_z = pto.vtrc %vec, %mask, "Z" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %out_f = pto.vtrc %vec, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out_r, %ub_r[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_z, %ub_z[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out_f, %ub_f[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_r, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_z, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_f, %arg3, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/launch.cpp b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/launch.cpp new file mode 100644 index 000000000..6e4f1a142 --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/launch.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vtrc_f32_rounding_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3, + __gm__ float *v4); + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream) { + vtrc_f32_rounding_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3, + (__gm__ float *)v4); +} diff --git a/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/main.cpp b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/main.cpp new file mode 100644 index 000000000..b86de567c --- /dev/null +++ b/test/vpto/cases/micro-op/conversion/vtrc-rounding-boundary/main.cpp @@ -0,0 +1,147 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVtrc_f32_rounding_kernel_2d(float *v1, float *v2, float *v3, + float *v4, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + size_t elemCount_v4 = 1024; + size_t fileSize_v4 = elemCount_v4 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + float *v4Host = nullptr; + float *v4Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVtrc_f32_rounding_kernel_2d(v1Device, v2Device, v3Device, v4Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/compare.py b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/golden.py b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/golden.py new file mode 100644 index 000000000..26c2a9a4f --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + a = np.eye(M, K, dtype=np.float16) + b = np.arange(K * N, dtype=np.float16).reshape(K, N) + c = np.zeros((M, N), dtype=np.float32) + a_f32 = a.astype(np.float32, copy=False) + b_f32 = b.astype(np.float32, copy=False) + golden_c = a_f32 @ b_f32 + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/kernel.pto new file mode 100644 index 000000000..2a2e6812a --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/kernel.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cube-matmul/cube-bridge-matmul +// family: micro-op/cube-matmul +// target_ops: pto.mte_gm_l1, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mte_l0c_gm +// scenarios: f16xf16->f32 cube bridge wrappers +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @cube_bridge_matmul_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c512_i64 = arith.constant 512 : i64 + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1 %a_gm, %l1_a, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + loop(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, loop i64, i64, i64 + pto.mte_gm_l1 %b_gm, %l1_b, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + loop(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, loop i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/launch.cpp new file mode 100644 index 000000000..0fc1ed2b1 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void cube_bridge_matmul_kernel(__gm__ __fp16 *a, + __gm__ __fp16 *b, + __gm__ float *c); + +void LaunchCube_bridge_matmul_kernel(__fp16 *a, __fp16 *b, float *c, void *stream) { + cube_bridge_matmul_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)a, + (__gm__ __fp16 *)b, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/main.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/main.cpp new file mode 100644 index 000000000..ba054595d --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-matmul/main.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchCube_bridge_matmul_kernel(__fp16 *a, __fp16 *b, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(__fp16); + constexpr size_t bSize = bElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + + __fp16 *aHost = nullptr; + __fp16 *bHost = nullptr; + float *cHost = nullptr; + __fp16 *aDevice = nullptr; + __fp16 *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchCube_bridge_matmul_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/compare.py b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/compare.py new file mode 100644 index 000000000..7643407a2 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/golden.py b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/golden.py new file mode 100644 index 000000000..a139ef92c --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +K = 16 +N = 16 +DST_STRIDE = 48 + + +def generate(output_dir: Path) -> None: + a = np.eye(M, K, dtype=np.float16) + b = np.arange(K * N, dtype=np.float16).reshape(K, N) + c = np.zeros((M, DST_STRIDE), dtype=np.float32) + golden_c = np.zeros((M, DST_STRIDE), dtype=np.float32) + a_f32 = a.astype(np.float32, copy=False) + b_f32 = b.astype(np.float32, copy=False) + golden_c[:, :N] = (a_f32 @ b_f32).T + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/kernel.pto new file mode 100644 index 000000000..87616427e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/kernel.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw +// family: micro-op/cube-matmul +// target_ops: pto.mte_gm_l1, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mte_l0c_gm +// scenarios: canonical nz2dn write-back configured as an NCDHW-style store +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @cube_bridge_store_nz2dn_ncdhw_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c48_i64 = arith.constant 48 : i64 + %c512_i64 = arith.constant 512 : i64 + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1 %a_gm, %l1_a, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + loop(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, loop i64, i64, i64 + pto.mte_gm_l1 %b_gm, %l1_b, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + loop(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, loop i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c48_i64, %c0_i64, %c0_i64, + nz2dn(%c1_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/launch.cpp new file mode 100644 index 000000000..abf97258d --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void cube_bridge_store_nz2dn_ncdhw_kernel( + __gm__ __fp16 *a, __gm__ __fp16 *b, __gm__ float *c); + +void LaunchCube_bridge_store_nz2dn_ncdhw_kernel(__fp16 *a, __fp16 *b, float *c, + void *stream) { + cube_bridge_store_nz2dn_ncdhw_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)a, (__gm__ __fp16 *)b, (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/main.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/main.cpp new file mode 100644 index 000000000..49f2579a7 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-ncdhw/main.cpp @@ -0,0 +1,129 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchCube_bridge_store_nz2dn_ncdhw_kernel(__fp16 *a, __fp16 *b, float *c, + void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t kDstStride = 48; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kDstStride; + + constexpr size_t aSize = aElem * sizeof(__fp16); + constexpr size_t bSize = bElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + + __fp16 *aHost = nullptr; + __fp16 *bHost = nullptr; + float *cHost = nullptr; + __fp16 *aDevice = nullptr; + __fp16 *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchCube_bridge_store_nz2dn_ncdhw_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/compare.py b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/compare.py new file mode 100644 index 000000000..7643407a2 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/golden.py b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/golden.py new file mode 100644 index 000000000..b804f7ae5 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +K = 16 +N = 16 +DST_STRIDE = 32 + + +def generate(output_dir: Path) -> None: + a = np.eye(M, K, dtype=np.float16) + b = np.arange(K * N, dtype=np.float16).reshape(K, N) + c = np.zeros((M, DST_STRIDE), dtype=np.float32) + golden_c = np.zeros((M, DST_STRIDE), dtype=np.float32) + a_f32 = a.astype(np.float32, copy=False) + b_f32 = b.astype(np.float32, copy=False) + golden_c[:, :N] = (a_f32 @ b_f32).T + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/kernel.pto new file mode 100644 index 000000000..8ce576474 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/kernel.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw +// family: micro-op/cube-matmul +// target_ops: pto.mte_gm_l1, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mte_l0c_gm +// scenarios: canonical nz2dn write-back configured as an NCHW-style store +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @cube_bridge_store_nz2dn_nchw_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1 %a_gm, %l1_a, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + loop(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, loop i64, i64, i64 + pto.mte_gm_l1 %b_gm, %l1_b, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + loop(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, loop i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c32_i64, %c0_i64, %c0_i64, + nz2dn(%c1_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/launch.cpp new file mode 100644 index 000000000..991cdc024 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void cube_bridge_store_nz2dn_nchw_kernel( + __gm__ __fp16 *a, __gm__ __fp16 *b, __gm__ float *c); + +void LaunchCube_bridge_store_nz2dn_nchw_kernel(__fp16 *a, __fp16 *b, float *c, + void *stream) { + cube_bridge_store_nz2dn_nchw_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)a, (__gm__ __fp16 *)b, (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/main.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/main.cpp new file mode 100644 index 000000000..a959f8136 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-bridge-store-nz2dn-nchw/main.cpp @@ -0,0 +1,129 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchCube_bridge_store_nz2dn_nchw_kernel(__fp16 *a, __fp16 *b, float *c, + void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t kDstStride = 32; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kDstStride; + + constexpr size_t aSize = aElem * sizeof(__fp16); + constexpr size_t bSize = bElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + + __fp16 *aHost = nullptr; + __fp16 *bHost = nullptr; + float *cHost = nullptr; + __fp16 *aDevice = nullptr; + __fp16 *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchCube_bridge_store_nz2dn_nchw_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/compare.py b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/compare.py new file mode 100644 index 000000000..ae57f3060 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +CASES = ( + ("golden_nd2nz_case1.bin", "out_nd2nz_case1.bin"), +) + + +def compare_one(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] missing file: {golden_path} or {output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch for {output_path}: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.nonzero(~np.isclose(golden, output, atol=1e-3, rtol=1e-3))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] {output_path} mismatch at idx={idx}: golden={golden[idx]}, out={output[idx]}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = all(compare_one(golden_path, output_path) for golden_path, output_path in CASES) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/golden.py b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/golden.py new file mode 100644 index 000000000..aa764fb3d --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/golden.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +def nchw_to_nc1hwc0(nchw_tensor: np.ndarray, c0: int = 16) -> np.ndarray: + n, c, h, w = nchw_tensor.shape + c1 = (c + c0 - 1) // c0 + padded = np.pad(nchw_tensor, ((0, 0), (0, c1 * c0 - c), (0, 0), (0, 0))) + return np.transpose(padded.reshape(n, c1, c0, h, w), (0, 1, 3, 4, 2)) + + +def nchw_to_c1hw_n16_16_c0(nchw_tensor: np.ndarray, c0: int = 16) -> np.ndarray: + n, c, h, w = nchw_tensor.shape + n_pad = ((n + 15) // 16) * 16 + c_pad = ((c + c0 - 1) // c0) * c0 + c1 = c_pad // c0 + padded = np.pad(nchw_tensor, ((0, n_pad - n), (0, c_pad - c), (0, 0), (0, 0))) + nc1c0hw = padded.reshape(n_pad, c1, c0, h, w) + n16 = nc1c0hw.reshape(n_pad // 16, 16, c1, c0, h, w) + return np.transpose(n16, (2, 4, 5, 0, 1, 3)).reshape(c1 * h * w, n_pad // 16, 16, c0) + + +def ncdhw_to_ndc1hwc0(ncdhw_tensor: np.ndarray, c0: int = 16) -> np.ndarray: + n, c, d, h, w = ncdhw_tensor.shape + c1 = (c + c0 - 1) // c0 + padded = np.pad(ncdhw_tensor, ((0, 0), (0, c1 * c0 - c), (0, 0), (0, 0), (0, 0))) + nc1c0dhw = padded.reshape(n, c1, c0, d, h, w) + return np.transpose(nc1c0dhw, (0, 3, 1, 4, 5, 2)) + + +def ncdhw_to_c1dhw_n16_16_c0(ncdhw_tensor: np.ndarray, c0: int = 16) -> np.ndarray: + n, c, d, h, w = ncdhw_tensor.shape + n_pad = ((n + 15) // 16) * 16 + c_pad = ((c + c0 - 1) // c0) * c0 + c1 = c_pad // c0 + padded = np.pad(ncdhw_tensor, ((0, n_pad - n), (0, c_pad - c), (0, 0), (0, 0), (0, 0))) + nc1c0dhw = padded.reshape(n_pad, c1, c0, d, h, w) + n16 = nc1c0dhw.reshape(n_pad // 16, 16, c1, c0, d, h, w) + return np.transpose(n16, (2, 4, 5, 6, 0, 1, 3)).reshape(c1 * d * h * w, n_pad // 16, 16, c0) + + +def write(path: Path, array: np.ndarray) -> None: + array.reshape(-1).tofile(path) + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + lhs_nd2nz_case1 = (np.arange(40 * 50, dtype=np.float16).reshape(40, 50) * np.float16(0.5) + + np.float16(17)).astype(np.float16) + nd2nz_case1 = (np.arange(50 * 60, dtype=np.float16).reshape(50, 60) * np.float16(0.25) + + np.float16(3)).astype(np.float16) + lhs = (np.arange(16 * 16, dtype=np.float16).reshape(16, 16) + 2001).astype(np.float16) + dn2nz = (np.arange(16 * 16, dtype=np.float16).reshape(16, 16) + 301).astype(np.float16) + nchw = (np.arange(1 * 10 * 1 * 16, dtype=np.float16).reshape(1, 10, 1, 16) + 601).astype(np.float16) + nchw_fz4d = (np.arange(5 * 16, dtype=np.float16).reshape(5, 16, 1, 1) + 901).astype(np.float16) + ncdhw = (np.arange(1 * 7 * 1 * 1 * 16, dtype=np.float16).reshape(1, 7, 1, 1, 16) + 1201).astype(np.float16) + ncdhw_fz3d = (np.arange(3 * 16, dtype=np.float16).reshape(3, 16, 1, 1, 1) + 1501).astype(np.float16) + + lhs_nd2nz_case1_f32 = lhs_nd2nz_case1.astype(np.float32) + lhs_f32 = lhs.astype(np.float32) + golden_nd2nz = lhs_nd2nz_case1_f32 @ nd2nz_case1.astype(np.float32) + golden_dn2nz = lhs_f32 @ dn2nz.astype(np.float32) + golden_nchw_nc1hwc0 = lhs_f32 @ nchw_to_nc1hwc0(nchw).reshape(16, 16).astype(np.float32) + golden_nchw_fz4d = lhs_f32 @ nchw_to_c1hw_n16_16_c0(nchw_fz4d).reshape(16, 16).astype(np.float32) + golden_ncdhw_ndc1hwc0 = lhs_f32 @ ncdhw_to_ndc1hwc0(ncdhw).reshape(16, 16).astype(np.float32) + golden_ncdhw_fz3d = lhs_f32 @ ncdhw_to_c1dhw_n16_16_c0(ncdhw_fz3d).reshape(16, 16).astype(np.float32) + + zeros_nd2nz = np.zeros((40, 60), dtype=np.float32) + zeros = np.zeros((16, 16), dtype=np.float32) + + write(output_dir / "lhs_nd2nz_case1.bin", lhs_nd2nz_case1) + write(output_dir / "src_nd2nz_case1.bin", nd2nz_case1) + write(output_dir / "identity.bin", lhs) + write(output_dir / "src_dn2nz.bin", dn2nz) + write(output_dir / "src_nchw_nc1hwc0.bin", nchw) + write(output_dir / "src_nchw_fz4d.bin", nchw_fz4d) + write(output_dir / "src_ncdhw_ndc1hwc0.bin", ncdhw) + write(output_dir / "src_ncdhw_fz3d.bin", ncdhw_fz3d) + + write(output_dir / "out_nd2nz_case1.bin", zeros_nd2nz) + write(output_dir / "out_dn2nz.bin", zeros) + write(output_dir / "out_nchw_nc1hwc0.bin", zeros) + write(output_dir / "out_nchw_fz4d.bin", zeros) + write(output_dir / "out_ncdhw_ndc1hwc0.bin", zeros) + write(output_dir / "out_ncdhw_fz3d.bin", zeros) + + write(output_dir / "golden_nd2nz_case1.bin", golden_nd2nz) + write(output_dir / "golden_dn2nz.bin", golden_dn2nz) + write(output_dir / "golden_nchw_nc1hwc0.bin", golden_nchw_nc1hwc0) + write(output_dir / "golden_nchw_fz4d.bin", golden_nchw_fz4d) + write(output_dir / "golden_ncdhw_ndc1hwc0.bin", golden_ncdhw_ndc1hwc0) + write(output_dir / "golden_ncdhw_fz3d.bin", golden_ncdhw_fz3d) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/kernel.pto new file mode 100644 index 000000000..0a3d16ee9 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/kernel.pto @@ -0,0 +1,337 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cube-matmul/cube-load-frac-layouts +// family: micro-op/cube-matmul +// target_ops: pto.mte_gm_l1_frac, pto.copy_gm_to_cbuf, pto.mte_l1_l0a, +// pto.mte_l1_l0b, pto.mad, pto.mte_l0c_gm +// scenarios: nd2nz, dn2nz, nchw->nc1hwc0, nchw->fractalz4d, +// ncdhw->ndc1hwc0, ncdhw->fractalz3d +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @cube_load_frac_nd2nz_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c60_i64 = arith.constant 60 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c120_i64 = arith.constant 120 : i64 + %c2400_i64 = arith.constant 2400 : i64 + %c46080_i64 = arith.constant 46080 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c60_i64), + src_layout(%c120_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c60_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c60_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %out_gm, %c40_i64, %c60_i64, %c48_i64, %c60_i64, %c0_i64, %c0_i64, + nz2nd, + loop3(%c1_i64, %c46080_i64, %c2400_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @cube_load_frac_dn2nz_kernel(%id_gm: !pto.ptr, + %src_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %mat_id = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_src = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %id_gm, %mat_src, %c16_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %src_gm, %mat_id, dn2nz, + shape(%c8_i64, %c16_i64), + src_layout(%c2_i64), + dst_group(%c1_i64, %c1_i64, %c8_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, dn2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %mat_src, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %out_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @cube_load_frac_nchw_nc1hwc0_kernel(%id_gm: !pto.ptr, + %src_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c10_i64 = arith.constant 10 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %mat_id = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_src = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %id_gm, %mat_src, %c16_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %src_gm, %mat_id, dn2nz, + shape(%c16_i64, %c10_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c16_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, dn2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %mat_src, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c16_i64, %c10_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %out_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @cube_load_frac_nchw_fz4d_kernel(%id_gm: !pto.ptr, + %src_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c5_i64 = arith.constant 5 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c80_i64 = arith.constant 80 : i64 + %c160_i64 = arith.constant 160 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %mat_id = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_src = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %id_gm, %mat_src, %c16_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %src_gm, %mat_id, dn2nz, + shape(%c16_i64, %c5_i64), + src_layout(%c32_i64, %c160_i64), + dst_group(%c1_i64, %c16_i64, %c80_i64, %c1_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, dn2nz, + shape i64, i64, src_layout(i64, i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %mat_src, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %out_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @cube_load_frac_ncdhw_ndc1hwc0_kernel(%id_gm: !pto.ptr, + %src_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c7_i64 = arith.constant 7 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %mat_id = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_src = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %id_gm, %mat_src, %c16_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %src_gm, %mat_id, dn2nz, + shape(%c16_i64, %c7_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c16_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, dn2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %mat_src, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c16_i64, %c7_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %out_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @cube_load_frac_ncdhw_fz3d_kernel(%id_gm: !pto.ptr, + %src_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c3_i64 = arith.constant 3 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c68720525568_i64 = arith.constant 68720525568 : i64 + %false = arith.constant false + + %mat_id = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_src = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_cbuf %id_gm, %mat_src, %c16_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %src_gm, %mat_id, dn2nz, + shape(%c1_i64, %c16_i64), + src_layout(%c2_i64, %c32_i64), + dst_group(%c3_i64, %c16_i64, %c16_i64, %c1_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, dn2nz, + shape i64, i64, src_layout(i64, i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %mat_src, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %out_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/launch.cpp new file mode 100644 index 000000000..5af5bfd7a --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/launch.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void cube_load_frac_nd2nz_kernel(__gm__ __fp16 *src, __gm__ __fp16 *id, + __gm__ float *out); +extern "C" __global__ [aicore] void cube_load_frac_dn2nz_kernel(__gm__ __fp16 *id, __gm__ __fp16 *src, + __gm__ float *out); +extern "C" __global__ [aicore] void cube_load_frac_nchw_nc1hwc0_kernel(__gm__ __fp16 *id, __gm__ __fp16 *src, + __gm__ float *out); +extern "C" __global__ [aicore] void cube_load_frac_nchw_fz4d_kernel(__gm__ __fp16 *id, __gm__ __fp16 *src, + __gm__ float *out); +extern "C" __global__ [aicore] void cube_load_frac_ncdhw_ndc1hwc0_kernel(__gm__ __fp16 *id, __gm__ __fp16 *src, + __gm__ float *out); +extern "C" __global__ [aicore] void cube_load_frac_ncdhw_fz3d_kernel(__gm__ __fp16 *id, __gm__ __fp16 *src, + __gm__ float *out); + +void LaunchCube_load_frac_nd2nz_kernel(__fp16 *src, __fp16 *id, float *out, void *stream) { + cube_load_frac_nd2nz_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ float *)out); +} + +void LaunchCube_load_frac_dn2nz_kernel(__fp16 *id, __fp16 *src, float *out, void *stream) { + cube_load_frac_dn2nz_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)id, (__gm__ __fp16 *)src, (__gm__ float *)out); +} + +void LaunchCube_load_frac_nchw_nc1hwc0_kernel(__fp16 *id, __fp16 *src, float *out, void *stream) { + cube_load_frac_nchw_nc1hwc0_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)id, (__gm__ __fp16 *)src, + (__gm__ float *)out); +} + +void LaunchCube_load_frac_nchw_fz4d_kernel(__fp16 *id, __fp16 *src, float *out, void *stream) { + cube_load_frac_nchw_fz4d_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)id, (__gm__ __fp16 *)src, + (__gm__ float *)out); +} + +void LaunchCube_load_frac_ncdhw_ndc1hwc0_kernel(__fp16 *id, __fp16 *src, float *out, void *stream) { + cube_load_frac_ncdhw_ndc1hwc0_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)id, (__gm__ __fp16 *)src, + (__gm__ float *)out); +} + +void LaunchCube_load_frac_ncdhw_fz3d_kernel(__fp16 *id, __fp16 *src, float *out, void *stream) { + cube_load_frac_ncdhw_fz3d_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)id, (__gm__ __fp16 *)src, + (__gm__ float *)out); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/main.cpp b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/main.cpp new file mode 100644 index 000000000..301c4d053 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/cube-load-frac-layouts/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchCube_load_frac_nd2nz_kernel(__fp16 *src, __fp16 *id, float *out, void *stream); + +static bool readExact(const char *path, void *dst, size_t size) { + size_t inputSize = size; + return ReadFile(path, inputSize, dst, size) && inputSize == size; +} + +static bool writeExact(const char *path, void *src, size_t size) { + return WriteFile(path, src, size); +} + +int main() { + constexpr size_t kNd2NzCase1LhsElem = 40 * 50; + constexpr size_t kNd2NzCase1RhsElem = 50 * 60; + constexpr size_t kNd2NzCase1OutElem = 40 * 60; + + constexpr size_t kNd2NzCase1LhsSize = kNd2NzCase1LhsElem * sizeof(__fp16); + constexpr size_t kNd2NzCase1RhsSize = kNd2NzCase1RhsElem * sizeof(__fp16); + constexpr size_t kNd2NzCase1OutSize = kNd2NzCase1OutElem * sizeof(float); + + __fp16 *nd2nzCase1LhsHost = nullptr; + __fp16 *nd2nzCase1RhsHost = nullptr; + float *outNd2nzCase1Host = nullptr; + + __fp16 *nd2nzCase1LhsDevice = nullptr; + __fp16 *nd2nzCase1RhsDevice = nullptr; + float *outNd2nzCase1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&nd2nzCase1LhsHost), kNd2NzCase1LhsSize)); + ACL_CHECK(aclrtMallocHost((void **)(&nd2nzCase1RhsHost), kNd2NzCase1RhsSize)); + ACL_CHECK(aclrtMallocHost((void **)(&outNd2nzCase1Host), kNd2NzCase1OutSize)); + + ACL_CHECK(aclrtMalloc((void **)&nd2nzCase1LhsDevice, kNd2NzCase1LhsSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&nd2nzCase1RhsDevice, kNd2NzCase1RhsSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outNd2nzCase1Device, kNd2NzCase1OutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + + FILE_CHECK(readExact("./lhs_nd2nz_case1.bin", nd2nzCase1LhsHost, kNd2NzCase1LhsSize), + "./lhs_nd2nz_case1.bin"); + FILE_CHECK(readExact("./src_nd2nz_case1.bin", nd2nzCase1RhsHost, kNd2NzCase1RhsSize), + "./src_nd2nz_case1.bin"); + FILE_CHECK(readExact("./out_nd2nz_case1.bin", outNd2nzCase1Host, kNd2NzCase1OutSize), + "./out_nd2nz_case1.bin"); + + ACL_CHECK(aclrtMemcpy(nd2nzCase1LhsDevice, kNd2NzCase1LhsSize, nd2nzCase1LhsHost, + kNd2NzCase1LhsSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(nd2nzCase1RhsDevice, kNd2NzCase1RhsSize, nd2nzCase1RhsHost, + kNd2NzCase1RhsSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outNd2nzCase1Device, kNd2NzCase1OutSize, outNd2nzCase1Host, + kNd2NzCase1OutSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchCube_load_frac_nd2nz_kernel(nd2nzCase1RhsDevice, nd2nzCase1LhsDevice, + outNd2nzCase1Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outNd2nzCase1Host, kNd2NzCase1OutSize, outNd2nzCase1Device, + kNd2NzCase1OutSize, ACL_MEMCPY_DEVICE_TO_HOST)); + + FILE_CHECK(writeExact("./out_nd2nz_case1.bin", outNd2nzCase1Host, kNd2NzCase1OutSize), + "./out_nd2nz_case1.bin"); + +cleanup: + aclrtFree(nd2nzCase1LhsDevice); + aclrtFree(nd2nzCase1RhsDevice); + aclrtFree(outNd2nzCase1Device); + aclrtFreeHost(nd2nzCase1LhsHost); + aclrtFreeHost(nd2nzCase1RhsHost); + aclrtFreeHost(outNd2nzCase1Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/compare.py b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/compare.py new file mode 100644 index 000000000..44cb435d1 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] missing file: {golden_path} or {output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden - output) > (1e-3 + 1e-3 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + ok = compare_bin("golden_v4.bin", "v4.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/golden.py b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/golden.py new file mode 100644 index 000000000..d42962ee7 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + lhs = (np.arange(40 * 50, dtype=np.float16).reshape(40, 50) * np.float16(0.5) + + np.float16(17)).astype(np.float16) + rhs = (np.arange(50 * 64, dtype=np.float16).reshape(50, 64) * np.float16(0.25) + + np.float16(3)).astype(np.float16) + out = np.zeros((40, 64), dtype=np.float32) + out_cbuf = np.zeros((40, 64), dtype=np.float32) + golden = lhs.astype(np.float32) @ rhs.astype(np.float32) + + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "v3.bin") + out_cbuf.reshape(-1).tofile(output_dir / "v4.bin") + golden.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/kernel.pto new file mode 100644 index 000000000..7226aa1d1 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/kernel.pto @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cube-matmul/fixpipe-cc-gm +// family: cube-matmul +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_gm, pto.mte_l0c_l1, pto.mte_l1_ub, pto.mte_ub_gm, +// pto.sync.set, pto.sync.wait +// scenarios: fixpipe, cc-to-gm, acc-to-cbuf-to-ub-to-gm, strict-matmul-golden +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_cc_gm_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %out_gm: !pto.ptr, + %out_cbuf_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_out = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 unit_flag(check_and_set) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %out_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + unit_flag(check_only), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64 + + pto.mte_l0c_l1 %l0c, %l1_out, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + unit_flag(check_only), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_out, %ub_out, %c64_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_out, %out_cbuf_gm, %c256_i64 + nburst(%c40_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/launch.cpp new file mode 100644 index 000000000..4e9989a80 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_cc_gm_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ float *out, + __gm__ float *out_cbuf); + +void LaunchFixpipe_cc_gm_kernel(__fp16 *src, __fp16 *id, float *out, + float *outCbuf, + void *stream) { + fixpipe_cc_gm_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ float *)out, + (__gm__ float *)outCbuf); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/main.cpp b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/main.cpp new file mode 100644 index 000000000..ed76ebbf1 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/fixpipe-cc-gm/main.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_cc_gm_kernel(__fp16 *src, __fp16 *id, float *out, + float *outCbuf, + void *stream); + +int main() { + constexpr size_t kSrcElem = 50 * 64; + constexpr size_t kIdElem = 40 * 50; + constexpr size_t kOutElem = 40 * 64; + constexpr size_t kSrcSize = kSrcElem * sizeof(__fp16); + constexpr size_t kIdSize = kIdElem * sizeof(__fp16); + constexpr size_t kOutSize = kOutElem * sizeof(float); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + float *outHost = nullptr; + float *outCbufHost = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + float *outDevice = nullptr; + float *outCbufDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSrcSize)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kIdSize)); + ACL_CHECK(aclrtMallocHost((void **)&outHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outCbufHost, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSrcSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kIdSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outCbufDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kIdSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kIdSize) && inputSize == kIdSize, + "./v1.bin"); + inputSize = kSrcSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSrcSize) && inputSize == kSrcSize, + "./v2.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, outHost, kOutSize) && inputSize == kOutSize, + "./v3.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, outCbufHost, kOutSize) && inputSize == kOutSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcSize, srcHost, kSrcSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kIdSize, idHost, kIdSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, kOutSize, outHost, kOutSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outCbufDevice, kOutSize, outCbufHost, kOutSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_cc_gm_kernel(srcDevice, idDevice, outDevice, outCbufDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outHost, kOutSize, outDevice, kOutSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outCbufHost, kOutSize, outCbufDevice, kOutSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", outHost, kOutSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", outCbufHost, kOutSize), "./v4.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(outDevice); + aclrtFree(outCbufDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(outHost); + aclrtFreeHost(outCbufHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_acc/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_acc/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_acc/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_acc/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_acc/golden.py new file mode 100644 index 000000000..4e0fae3b8 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_acc/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + row = np.arange(M, dtype=np.float32).reshape(M, 1) + col = np.arange(K, dtype=np.float32).reshape(1, K) + a = (((row * 3 + col * 5) % 17) - 8).astype(np.float16) / np.float16(4.0) + k_idx = np.arange(K, dtype=np.float32).reshape(K, 1) + n_idx = np.arange(N, dtype=np.float32).reshape(1, N) + b = (((k_idx * 7 - n_idx * 2) % 19) - 9).astype(np.float16) / np.float16(5.0) + a_acc = (((row * 2 - col * 7) % 13) - 6).astype(np.float16) / np.float16(3.0) + b_acc = (((k_idx * 11 + n_idx * 3) % 17) - 8).astype(np.float16) / np.float16(4.0) + c = np.zeros((M, N), dtype=np.float32) + golden_c = a.astype(np.float32) @ b.astype(np.float32) + golden_c += a_acc.astype(np.float32) @ b_acc.astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + a_acc.reshape(-1).tofile(output_dir / "v4.bin") + b_acc.reshape(-1).tofile(output_dir / "v5.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_acc/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_acc/kernel.pto new file mode 100644 index 000000000..6a786a73d --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_acc/kernel.pto @@ -0,0 +1,98 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_acc_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %a_acc_gm: !pto.ptr, + %b_acc_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c1536_i64 = arith.constant 1536 : i64 + %false = arith.constant false + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l1_a_acc = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l1_b_acc = pto.castptr %c1536_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %a_acc_gm, %l1_a_acc, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_acc_gm, %l1_b_acc, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 unit_flag(check_only) disable_gemv n_dir + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_MTE1", "EVENT_ID1"] + pto.mte_l1_l0a %l1_a_acc, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b_acc, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.mad_acc %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 unit_flag(check_and_set) disable_gemv n_dir + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_acc/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_acc/launch.cpp new file mode 100644 index 000000000..9097e433f --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_acc/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_acc_kernel(__gm__ __fp16 *a, + __gm__ __fp16 *b, + __gm__ __fp16 *a_acc, + __gm__ __fp16 *b_acc, + __gm__ float *c); + +void LaunchMad_acc_kernel(__fp16 *a, __fp16 *b, __fp16 *a_acc, __fp16 *b_acc, + float *c, void *stream) { + mad_acc_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)a, + (__gm__ __fp16 *)b, + (__gm__ __fp16 *)a_acc, + (__gm__ __fp16 *)b_acc, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_acc/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_acc/main.cpp new file mode 100644 index 000000000..4025b1cd8 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_acc/main.cpp @@ -0,0 +1,148 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_acc_kernel(__fp16 *a, __fp16 *b, __fp16 *a_acc, + __fp16 *b_acc, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(__fp16); + constexpr size_t bSize = bElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + + __fp16 *aHost = nullptr; + __fp16 *bHost = nullptr; + __fp16 *aAccHost = nullptr; + __fp16 *bAccHost = nullptr; + float *cHost = nullptr; + __fp16 *aDevice = nullptr; + __fp16 *bDevice = nullptr; + __fp16 *aAccDevice = nullptr; + __fp16 *bAccDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&aAccHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bAccHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&aAccDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bAccDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = aSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, aAccHost, aSize) && inputSize == aSize, + "./v4.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, bAccHost, bSize) && inputSize == bSize, + "./v5.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(aAccDevice, aSize, aAccHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bAccDevice, bSize, bAccHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_acc_kernel(aDevice, bDevice, aAccDevice, bAccDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(aAccDevice); + aclrtFree(bAccDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(aAccHost); + aclrtFreeHost(bAccHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/golden.py new file mode 100644 index 000000000..a45e27d92 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def to_bf16_bits(values: np.ndarray) -> np.ndarray: + f32 = values.astype(np.float32, copy=False) + return (f32.view(np.uint32) >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray: + return (bits.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + row = np.arange(M, dtype=np.float32).reshape(M, 1) + col = np.arange(K, dtype=np.float32).reshape(1, K) + a_f32 = (((row * 5 + col * 3) % 23) - 11) / 8.0 + k_idx = np.arange(K, dtype=np.float32).reshape(K, 1) + n_idx = np.arange(N, dtype=np.float32).reshape(1, N) + b_f32 = (((k_idx * 2 - n_idx * 7) % 29) - 14) / 9.0 + a = to_bf16_bits(a_f32) + b = to_bf16_bits(b_f32) + c = np.zeros((M, N), dtype=np.float32) + golden_c = bf16_bits_to_f32(a).astype(np.float32) @ bf16_bits_to_f32(b).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/kernel.pto new file mode 100644 index 000000000..7a7edd638 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/kernel.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_bf16bf16f32_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %false = arith.constant false + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/launch.cpp new file mode 100644 index 000000000..1a2013f26 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +using bf16_storage_t = uint16_t; + +extern "C" __global__ [aicore] void mad_bf16bf16f32_kernel(__gm__ bf16_storage_t *a, + __gm__ bf16_storage_t *b, + __gm__ float *c); + +void LaunchMad_bf16bf16f32_kernel(bf16_storage_t *a, bf16_storage_t *b, float *c, + void *stream) { + mad_bf16bf16f32_kernel<<<1, nullptr, stream>>>((__gm__ bf16_storage_t *)a, + (__gm__ bf16_storage_t *)b, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/main.cpp new file mode 100644 index 000000000..a7c6844ad --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bf16bf16f32/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +using bf16_storage_t = uint16_t; + +void LaunchMad_bf16bf16f32_kernel(bf16_storage_t *a, bf16_storage_t *b, float *c, + void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(bf16_storage_t); + constexpr size_t bSize = bElem * sizeof(bf16_storage_t); + constexpr size_t cSize = cElem * sizeof(float); + + bf16_storage_t *aHost = nullptr; + bf16_storage_t *bHost = nullptr; + float *cHost = nullptr; + bf16_storage_t *aDevice = nullptr; + bf16_storage_t *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_bf16bf16f32_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bias/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_bias/compare.py new file mode 100644 index 000000000..7643407a2 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bias/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bias/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_bias/golden.py new file mode 100644 index 000000000..b68457cdb --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bias/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + row = np.arange(M, dtype=np.float32).reshape(M, 1) + col = np.arange(K, dtype=np.float32).reshape(1, K) + a = (((row * 3 - col * 2) % 17) - 8).astype(np.float16) / np.float16(4.0) + k_idx = np.arange(K, dtype=np.float32).reshape(K, 1) + n_idx = np.arange(N, dtype=np.float32).reshape(1, N) + b = (((k_idx * 5 + n_idx * 7) % 23) - 11).astype(np.float16) / np.float16(6.0) + c = np.zeros((M, N), dtype=np.float32) + bias = (((np.arange(N, dtype=np.float32) * 3) % 19) - 9).astype(np.float16) / np.float16(3.0) + golden_c = a.astype(np.float32) @ b.astype(np.float32) + golden_c += bias.astype(np.float32)[None, :] + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + bias.reshape(-1).tofile(output_dir / "v4.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bias/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_bias/kernel.pto new file mode 100644 index 000000000..f65993369 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bias/kernel.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_bias_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr, + %bias_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %false = arith.constant false + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l1_bias = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + %bt = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c32_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_bt %l1_bias, %bt, %c16_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + + pto.mad_bias %l0a, %l0b, %l0c, %bt, %c16_i64, %c16_i64, %c16_i64 unit_flag(check_only) disable_gemv n_dir + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bias/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_bias/launch.cpp new file mode 100644 index 000000000..4e7100a74 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bias/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_bias_kernel(__gm__ __fp16 *a, + __gm__ __fp16 *b, + __gm__ float *c, + __gm__ __fp16 *bias); + +void LaunchMad_bias_kernel(__fp16 *a, __fp16 *b, float *c, __fp16 *bias, + void *stream) { + mad_bias_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)a, + (__gm__ __fp16 *)b, + (__gm__ float *)c, + (__gm__ __fp16 *)bias); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_bias/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_bias/main.cpp new file mode 100644 index 000000000..e9556d915 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_bias/main.cpp @@ -0,0 +1,142 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_bias_kernel(__fp16 *a, __fp16 *b, float *c, __fp16 *bias, + void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + constexpr size_t biasElem = kN; + + constexpr size_t aSize = aElem * sizeof(__fp16); + constexpr size_t bSize = bElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + constexpr size_t biasSize = biasElem * sizeof(__fp16); + + __fp16 *aHost = nullptr; + __fp16 *bHost = nullptr; + float *cHost = nullptr; + __fp16 *biasHost = nullptr; + __fp16 *aDevice = nullptr; + __fp16 *bDevice = nullptr; + float *cDevice = nullptr; + __fp16 *biasDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMallocHost((void **)(&biasHost), biasSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&biasDevice, biasSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + inputSize = biasSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, biasHost, biasSize) && + inputSize == biasSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(biasDevice, biasSize, biasHost, biasSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_bias_kernel(aDevice, bDevice, cDevice, biasDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFree(biasDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + aclrtFreeHost(biasHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/compare.py new file mode 100644 index 000000000..21669b86e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + ok = compare_bin("golden_v4.bin", "v4.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/golden.py new file mode 100644 index 000000000..79d8ff308 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 1 +N = 16 +K = 32 + + +def generate(output_dir: Path) -> None: + row = np.arange(M, dtype=np.float32).reshape(M, 1) + col = np.arange(K, dtype=np.float32).reshape(1, K) + a = (((row * 3 + col * 5) % 17) - 8).astype(np.float16) / np.float16(4.0) + k_idx = np.arange(K, dtype=np.float32).reshape(K, 1) + n_idx = np.arange(N, dtype=np.float32).reshape(1, N) + b = (((k_idx * 7 - n_idx * 2) % 19) - 9).astype(np.float16) / np.float16(5.0) + c_default = np.zeros((M, N), dtype=np.float32) + c_disable = np.zeros((M, N), dtype=np.float32) + golden = a.astype(np.float32) @ b.astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c_default.reshape(-1).tofile(output_dir / "v3.bin") + c_disable.reshape(-1).tofile(output_dir / "v4.bin") + golden.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/kernel.pto new file mode 100644 index 000000000..6753db54c --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/kernel.pto @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_f16f16f32_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_default_gm: !pto.ptr, + %c_disable_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %false = arith.constant false + + %l1_a_gemv = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_a_normal = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %l0a_gemv = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0a_normal = pto.castptr %c512_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + // GEMV consumes A as a contiguous 1xK vector. + pto.mte_gm_l1 %a_gm, %l1_a_gemv, %c64_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1_frac %a_gm, %l1_a_normal, nd2nz, + shape(%c1_i64, %c32_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c32_i64, %c16_i64), + src_layout(%c32_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.load_cbuf_to_ca %l1_a_gemv, %l0a_gemv, %c0_i64, %c0_i64, %c1_i64, %c1_i64, %c1_i64, %c1_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.mte_l1_l0a %l1_a_normal, %l0a_normal, %c1_i64, %c32_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c32_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + + pto.mad %l0a_gemv, %l0b, %l0c, %c1_i64, %c16_i64, %c32_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_default_gm, %c1_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_FIX", "PIPE_M", "EVENT_ID1"] + + pto.mad %l0a_normal, %l0b, %l0c, %c1_i64, %c16_i64, %c32_i64 disable_gemv + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_disable_gm, %c1_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/launch.cpp new file mode 100644 index 000000000..5b23b2604 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_f16f16f32_kernel(__gm__ __fp16 *a, + __gm__ __fp16 *b, + __gm__ float *c_default, + __gm__ float *c_disable); + +void LaunchMad_f16f16f32_kernel(__fp16 *a, __fp16 *b, float *cDefault, + float *cDisable, void *stream) { + mad_f16f16f32_kernel<<<1, nullptr, stream>>>((__gm__ __fp16 *)a, + (__gm__ __fp16 *)b, + (__gm__ float *)cDefault, + (__gm__ float *)cDisable); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/main.cpp new file mode 100644 index 000000000..eb701a600 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f16f16f32/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_f16f16f32_kernel(__fp16 *a, __fp16 *b, float *cDefault, + float *cDisable, void *stream); + +int main() { + constexpr size_t kM = 1; + constexpr size_t kN = 16; + constexpr size_t kK = 32; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(__fp16); + constexpr size_t bSize = bElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + + __fp16 *aHost = nullptr; + __fp16 *bHost = nullptr; + float *cDefaultHost = nullptr; + float *cDisableHost = nullptr; + __fp16 *aDevice = nullptr; + __fp16 *bDevice = nullptr; + float *cDefaultDevice = nullptr; + float *cDisableDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cDefaultHost), cSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cDisableHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDefaultDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDisableDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cDefaultHost, cSize) && inputSize == cSize, + "./v3.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, cDisableHost, cSize) && inputSize == cSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDefaultDevice, cSize, cDefaultHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDisableDevice, cSize, cDisableHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_f16f16f32_kernel(aDevice, bDevice, cDefaultDevice, cDisableDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cDefaultHost, cSize, cDefaultDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(cDisableHost, cSize, cDisableDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cDefaultHost, cSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", cDisableHost, cSize), "./v4.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDefaultDevice); + aclrtFree(cDisableDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cDefaultHost); + aclrtFreeHost(cDisableHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/compare.py new file mode 100644 index 000000000..02a0ecdf7 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def is_close(golden: np.ndarray, output: np.ndarray) -> np.ndarray: + return np.isclose(golden, output, atol=1e-2, rtol=1e-2, equal_nan=True) + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] missing file: {golden_path} or {output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.all(is_close(golden, output)): + return True + diff = np.where(~is_close(golden, output))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + ok = compare_bin("golden_v4.bin", "v4.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/golden.py new file mode 100644 index 000000000..e7483cd45 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/golden.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def generate(output_dir: Path) -> None: + row = np.arange(M, dtype=np.float32).reshape(M, 1) + col = np.arange(K, dtype=np.float32).reshape(1, K) + a = (((row * 11 + col * 3) % 31) - 15).astype(np.float32) / 16.0 + k_idx = np.arange(K, dtype=np.float32).reshape(K, 1) + n_idx = np.arange(N, dtype=np.float32).reshape(1, N) + b = (((k_idx * 5 - n_idx * 13) % 37) - 18).astype(np.float32) / 17.0 + + a[0, 0] = np.float32(np.inf) + b[0, 0] = np.float32(1.0) + a[1, 1] = np.float32(np.nan) + b[1, 1] = np.float32(1.0) + a[2, :] = np.float32(0.0) + b[:, 2] = np.float32(0.0) + a[2, 2] = np.float32(2.0e38) + b[2, 2] = np.float32(2.0) + + c_sat = np.zeros((M, N), dtype=np.float32) + c_nosat = np.zeros((M, N), dtype=np.float32) + saturated_a = np.nan_to_num( + a, + nan=np.float32(0.0), + posinf=np.finfo(np.float32).max, + neginf=-np.finfo(np.float32).max, + ).astype(np.float32) + saturated_b = np.nan_to_num( + b, + nan=np.float32(0.0), + posinf=np.finfo(np.float32).max, + neginf=-np.finfo(np.float32).max, + ).astype(np.float32) + with np.errstate(invalid="ignore", over="ignore"): + f32_max = np.finfo(np.float32).max + golden_sat = (saturated_a.astype(np.float64) @ saturated_b.astype(np.float64)) + golden_sat = np.nan_to_num( + np.clip(golden_sat, -f32_max, f32_max), + nan=0.0, + posinf=f32_max, + neginf=-f32_max, + ).astype(np.float32) + golden_nosat = (a @ b).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c_sat.reshape(-1).tofile(output_dir / "v3.bin") + c_nosat.reshape(-1).tofile(output_dir / "v4.bin") + golden_sat.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden_nosat.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/kernel.pto new file mode 100644 index 000000000..cbd0190a8 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/kernel.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_f32f32f32_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_sat_gm: !pto.ptr, + %c_nosat_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %false = arith.constant false + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 sat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_sat_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_FIX", "PIPE_M", "EVENT_ID1"] + + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 nosat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_nosat_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/launch.cpp new file mode 100644 index 000000000..ad5f9fa44 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_f32f32f32_kernel(__gm__ float *a, + __gm__ float *b, + __gm__ float *c_sat, + __gm__ float *c_nosat); + +void LaunchMad_f32f32f32_kernel(float *a, float *b, float *cSat, + float *cNoSat, void *stream) { + mad_f32f32f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)a, + (__gm__ float *)b, + (__gm__ float *)cSat, + (__gm__ float *)cNoSat); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/main.cpp new file mode 100644 index 000000000..5393c4254 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_f32f32f32/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_f32f32f32_kernel(float *a, float *b, float *cSat, + float *cNoSat, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(float); + constexpr size_t bSize = bElem * sizeof(float); + constexpr size_t cSize = cElem * sizeof(float); + + float *aHost = nullptr; + float *bHost = nullptr; + float *cSatHost = nullptr; + float *cNoSatHost = nullptr; + float *aDevice = nullptr; + float *bDevice = nullptr; + float *cSatDevice = nullptr; + float *cNoSatDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cSatHost), cSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cNoSatHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cSatDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cNoSatDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cSatHost, cSize) && inputSize == cSize, + "./v3.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, cNoSatHost, cSize) && inputSize == cSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cSatDevice, cSize, cSatHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cNoSatDevice, cSize, cNoSatHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_f32f32f32_kernel(aDevice, bDevice, cSatDevice, cNoSatDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cSatHost, cSize, cSatDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(cNoSatHost, cSize, cNoSatDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cSatHost, cSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", cNoSatHost, cSize), "./v4.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cSatDevice); + aclrtFree(cNoSatDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cSatHost); + aclrtFreeHost(cNoSatHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_hif8/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/compare.py new file mode 100644 index 000000000..83a0ed35d --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, " + f"out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + ok = compare_bin("golden_v4.bin", "v4.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_hif8/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/golden.py new file mode 100644 index 000000000..eb6b8e3e9 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/golden.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +M = 16 +N = 16 +K = 64 +M0 = 16 +K0 = 32 +N0 = 16 + + +def pack_lhs_cube_fractal(matrix: np.ndarray) -> np.ndarray: + m, k = matrix.shape + assert m % M0 == 0 and k % K0 == 0 + return matrix.reshape(m // M0, M0, k // K0, K0).transpose( + 2, 0, 1, 3 + ) + + +def pack_rhs_cube_fractal(matrix: np.ndarray) -> np.ndarray: + k, n = matrix.shape + assert k % K0 == 0 and n % N0 == 0 + return matrix.reshape(k // K0, K0, n // N0, N0).transpose( + 0, 2, 1, 3 + ) + + +def fp8_e4m3_to_f32(bits: np.ndarray) -> np.ndarray: + raw = bits.astype(np.uint8) + sign = np.where((raw & 0x80) != 0, -1.0, 1.0).astype(np.float32) + exponent = ((raw >> 3) & 0x0F).astype(np.int32) + mantissa = (raw & 0x07).astype(np.float32) + normal = exponent != 0 + value = np.where( + normal, + (1.0 + mantissa / 8.0) * np.exp2(exponent - 7), + (mantissa / 8.0) * np.exp2(-6), + ).astype(np.float32) + return sign * value + + +def generate(output_dir: Path) -> None: + # This kernel stages GM->L1 with raw byte copies, so v1/v2.bin must already + # be in the cube-fractal layout expected by mte_l1_l0a/l0b. + codes = np.array([0x40, 0xC0, 0x00, 0x40, 0xC0], dtype=np.uint8) + signed_units = np.array([1.0, -1.0, 0.0, 1.0, -1.0], dtype=np.float32) + + m_idx = np.arange(M).reshape(M, 1) + k_idx = np.arange(K).reshape(1, K) + a_index = (m_idx * 3 + k_idx * 0) % codes.size + a_logical = codes[a_index].astype(np.uint8) + a_unit = signed_units[a_index].astype(np.float32) + + k_idx = np.arange(K).reshape(K, 1) + n_idx = np.arange(N).reshape(1, N) + b_index = (k_idx * 0 + n_idx * 2 + 1) % codes.size + b_logical = codes[b_index].astype(np.uint8) + b_unit = signed_units[b_index].astype(np.float32) + + a = pack_lhs_cube_fractal(a_logical).reshape(-1).astype(np.uint8) + b = pack_rhs_cube_fractal(b_logical).reshape(-1).astype(np.uint8) + c_hif8 = np.zeros((M, N), dtype=np.float32) + c_fp8 = np.zeros((M, N), dtype=np.float32) + + golden_hif8 = (a_unit @ b_unit).astype(np.float32) * np.float32(128.0) + a_fp8 = fp8_e4m3_to_f32(a_logical) + b_fp8 = fp8_e4m3_to_f32(b_logical) + golden_fp8 = (a_fp8 @ b_fp8).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c_hif8.reshape(-1).tofile(output_dir / "v3.bin") + c_fp8.reshape(-1).tofile(output_dir / "v4.bin") + golden_hif8.reshape(-1).tofile(output_dir / "golden_v3.bin") + golden_fp8.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_hif8/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/kernel.pto new file mode 100644 index 000000000..2881461a8 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/kernel.pto @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_hif8_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_hif8_gm: !pto.ptr, + %a_fp8_gm: !pto.ptr, + %b_fp8_gm: !pto.ptr, + %c_fp8_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %false = arith.constant false + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c_fp8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %l1_a_fp8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b_fp8 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l0a_fp8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b_fp8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1 %a_gm, %l1_a, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_gm, %l1_b, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c64_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_hif8_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + pto.barrier #pto.pipe + + pto.mte_gm_l1 %a_fp8_gm, %l1_a_fp8, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_fp8_gm, %l1_b_fp8, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID1"] + + pto.mte_l1_l0a %l1_a_fp8, %l0a_fp8, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b_fp8, %l0b_fp8, %c64_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID1"] + pto.mad %l0a_fp8, %l0b_fp8, %l0c_fp8, %c16_i64, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID0"] + + pto.mte_l0c_gm %l0c_fp8, %c_fp8_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_hif8/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/launch.cpp new file mode 100644 index 000000000..6d447b046 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_hif8_kernel(__gm__ uint8_t *a, + __gm__ uint8_t *b, + __gm__ float *cHif8, + __gm__ uint8_t *aFp8, + __gm__ uint8_t *bFp8, + __gm__ float *cFp8); + +void LaunchMad_hif8_kernel(uint8_t *a, uint8_t *b, float *cHif8, + uint8_t *aFp8, uint8_t *bFp8, float *cFp8, + void *stream) { + mad_hif8_kernel<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, + (__gm__ uint8_t *)b, + (__gm__ float *)cHif8, + (__gm__ uint8_t *)aFp8, + (__gm__ uint8_t *)bFp8, + (__gm__ float *)cFp8); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_hif8/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/main.cpp new file mode 100644 index 000000000..b3987cf5a --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_hif8/main.cpp @@ -0,0 +1,159 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_hif8_kernel(uint8_t *a, uint8_t *b, float *cHif8, + uint8_t *aFp8, uint8_t *bFp8, float *cFp8, + void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 64; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(uint8_t); + constexpr size_t bSize = bElem * sizeof(uint8_t); + constexpr size_t cSize = cElem * sizeof(float); + + uint8_t *aHost = nullptr; + uint8_t *bHost = nullptr; + uint8_t *aFp8Host = nullptr; + uint8_t *bFp8Host = nullptr; + float *cHif8Host = nullptr; + float *cFp8Host = nullptr; + uint8_t *aDevice = nullptr; + uint8_t *bDevice = nullptr; + uint8_t *aFp8Device = nullptr; + uint8_t *bFp8Device = nullptr; + float *cHif8Device = nullptr; + float *cFp8Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&aFp8Host), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bFp8Host), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHif8Host), cSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cFp8Host), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&aFp8Device, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bFp8Device, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cHif8Device, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cFp8Device, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + std::memcpy(aFp8Host, aHost, aSize); + std::memcpy(bFp8Host, bHost, bSize); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHif8Host, cSize) && inputSize == cSize, + "./v3.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, cFp8Host, cSize) && inputSize == cSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(aFp8Device, aSize, aFp8Host, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bFp8Device, bSize, bFp8Host, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cHif8Device, cSize, cHif8Host, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cFp8Device, cSize, cFp8Host, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_hif8_kernel(aDevice, bDevice, cHif8Device, aFp8Device, bFp8Device, + cFp8Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHif8Host, cSize, cHif8Device, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(cFp8Host, cSize, cFp8Device, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHif8Host, cSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", cFp8Host, cSize), "./v4.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(aFp8Device); + aclrtFree(bFp8Device); + aclrtFree(cHif8Device); + aclrtFree(cFp8Device); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(aFp8Host); + aclrtFreeHost(bFp8Host); + aclrtFreeHost(cHif8Host); + aclrtFreeHost(cFp8Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx/golden.py new file mode 100644 index 000000000..ed32a73c7 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/golden.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + +M = 16 +N = 16 +K = 64 +SCALE_BYTES = 64 + + +def fp8_e4m3_to_f32(bits: np.ndarray) -> np.ndarray: + raw = bits.astype(np.uint8) + sign = np.where((raw & 0x80) != 0, -1.0, 1.0).astype(np.float32) + exponent = ((raw >> 3) & 0x0F).astype(np.int32) + mantissa = (raw & 0x07).astype(np.float32) + normal = exponent != 0 + value = np.where( + normal, + (1.0 + mantissa / 8.0) * np.exp2(exponent - 7), + (mantissa / 8.0) * np.exp2(-6), + ).astype(np.float32) + return sign * value + + +def e8m0_to_f32(bits: np.ndarray) -> np.ndarray: + return np.exp2(bits.astype(np.int32) - 127).astype(np.float32) + + +def pack_a_scale(a_scale: np.ndarray) -> np.ndarray: + packed = np.zeros(SCALE_BYTES, dtype=np.uint8) + packed[0:32] = a_scale.reshape(-1) + return packed + + +def pack_b_scale(b_scale: np.ndarray) -> np.ndarray: + packed = np.zeros(SCALE_BYTES, dtype=np.uint8) + packed[0:32] = b_scale.T.reshape(-1) + return packed + + +def generate(output_dir: Path) -> None: + # Values are exactly representable in FP8 E4M3: 0.5, 1.0, 2.0 and -1.0. + a_codes = np.array([0x30, 0x38, 0x40, 0xB8], dtype=np.uint8) + m_idx = np.arange(M).reshape(M, 1) + k_idx = np.arange(K).reshape(1, K) + a_matrix = a_codes[(m_idx * 3 + k_idx * 5) % a_codes.size] + b_matrix = np.full((K, N), 0x38, dtype=np.uint8) + + # E8M0 scale is 2^(byte - 127). The two K/32 groups use different scales, + # and A scales vary by M so the test catches incorrect scale grouping. + a_scale_matrix = np.where( + (np.arange(M).reshape(M, 1) + np.arange(2)) % 2 == 0, 127, 128 + ).astype(np.uint8) + b_scale_matrix = np.array([[126], [127]], dtype=np.uint8).repeat(N, axis=1) + a = a_matrix.reshape(-1).astype(np.uint8) + b = b_matrix.reshape(-1).astype(np.uint8) + a_scale = pack_a_scale(a_scale_matrix) + b_scale = pack_b_scale(b_scale_matrix) + c = np.zeros((M, N), dtype=np.float32) + + a_f32 = fp8_e4m3_to_f32(a_matrix) + b_f32 = fp8_e4m3_to_f32(b_matrix) + golden_c = np.zeros((M, N), dtype=np.float32) + a_scale_f32 = e8m0_to_f32(a_scale_matrix) + b_scale_f32 = e8m0_to_f32(b_scale_matrix) + for group in range(K // 32): + k_slice = slice(group * 32, (group + 1) * 32) + scaled_a = a_f32[:, k_slice] * a_scale_f32[:, group : group + 1] + scaled_b = b_f32[k_slice, :] * b_scale_f32[group : group + 1, :] + golden_c += scaled_a @ scaled_b + + output_dir.mkdir(parents=True, exist_ok=True) + a.tofile(output_dir / "v1.bin") + b.tofile(output_dir / "v2.bin") + a_scale.tofile(output_dir / "v4.bin") + b_scale.tofile(output_dir / "v5.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_mx/kernel.pto new file mode 100644 index 000000000..ea8a5f081 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/kernel.pto @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_mx_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %a_scale_gm: !pto.ptr, + %b_scale_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c64_burst_i64 = arith.constant 64 : i64 + %c1088_i64 = arith.constant 1088 : i64 + %c2112_i64 = arith.constant 2112 : i64 + + %l1_a_data = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_a_scale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l1_b_data = pto.castptr %c1088_i64 : i64 -> !pto.ptr + %l1_b_scale = pto.castptr %c2112_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1 %a_gm, %l1_a_data, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %a_scale_gm, %l1_a_scale, %c64_burst_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_gm, %l1_b_data, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_scale_gm, %l1_b_scale, %c64_burst_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a_data, %l0a, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b_data, %l0b, %c64_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0a_mx %l1_a_scale, %l0a, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b_mx %l1_b_scale, %l0b, %c64_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad_mx %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c64_i64 unit_flag(check_only) disable_gemv sat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx/launch.cpp new file mode 100644 index 000000000..342078709 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_mx_kernel(__gm__ uint8_t *a, + __gm__ uint8_t *b, + __gm__ uint8_t *a_scale, + __gm__ uint8_t *b_scale, + __gm__ float *c); + +void LaunchMad_mx_kernel(uint8_t *a, uint8_t *b, uint8_t *a_scale, + uint8_t *b_scale, float *c, void *stream) { + mad_mx_kernel<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, + (__gm__ uint8_t *)b, + (__gm__ uint8_t *)a_scale, + (__gm__ uint8_t *)b_scale, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx/main.cpp new file mode 100644 index 000000000..7180a2303 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx/main.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_mx_kernel(uint8_t *a, uint8_t *b, uint8_t *a_scale, + uint8_t *b_scale, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 64; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t scaleElem = 64; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(uint8_t); + constexpr size_t bSize = bElem * sizeof(uint8_t); + constexpr size_t scaleSize = scaleElem * sizeof(uint8_t); + constexpr size_t cSize = cElem * sizeof(float); + + uint8_t *aHost = nullptr; + uint8_t *bHost = nullptr; + uint8_t *aScaleHost = nullptr; + uint8_t *bScaleHost = nullptr; + float *cHost = nullptr; + uint8_t *aDevice = nullptr; + uint8_t *bDevice = nullptr; + uint8_t *aScaleDevice = nullptr; + uint8_t *bScaleDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&aScaleHost), scaleSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bScaleHost), scaleSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&aScaleDevice, scaleSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&bScaleDevice, scaleSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = scaleSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, aScaleHost, scaleSize) && + inputSize == scaleSize, + "./v4.bin"); + inputSize = scaleSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, bScaleHost, scaleSize) && + inputSize == scaleSize, + "./v5.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(aScaleDevice, scaleSize, aScaleHost, scaleSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bScaleDevice, scaleSize, bScaleHost, scaleSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_mx_kernel(aDevice, bDevice, aScaleDevice, bScaleDevice, cDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(aScaleDevice); + aclrtFree(bScaleDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(aScaleHost); + aclrtFreeHost(bScaleHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/golden.py new file mode 100644 index 000000000..f64876252 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/golden.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + +M = 16 +N = 16 +K = 64 +SCALE_BYTES = 64 + + +def fp8_e4m3_to_f32(bits: np.ndarray) -> np.ndarray: + raw = bits.astype(np.uint8) + sign = np.where((raw & 0x80) != 0, -1.0, 1.0).astype(np.float32) + exponent = ((raw >> 3) & 0x0F).astype(np.int32) + mantissa = (raw & 0x07).astype(np.float32) + normal = exponent != 0 + value = np.where( + normal, + (1.0 + mantissa / 8.0) * np.exp2(exponent - 7), + (mantissa / 8.0) * np.exp2(-6), + ).astype(np.float32) + return sign * value + + +def e8m0_to_f32(bits: np.ndarray) -> np.ndarray: + return np.exp2(bits.astype(np.int32) - 127).astype(np.float32) + + +def pack_a_scale(a_scale: np.ndarray) -> np.ndarray: + packed = np.zeros(SCALE_BYTES, dtype=np.uint8) + packed[0:32] = a_scale.reshape(-1) + return packed + + +def pack_b_scale(b_scale: np.ndarray) -> np.ndarray: + packed = np.zeros(SCALE_BYTES, dtype=np.uint8) + packed[0:32] = b_scale.T.reshape(-1) + return packed + + +def generate(output_dir: Path) -> None: + # Values are exactly representable in FP8 E4M3: 0.5, 1.0, 2.0 and -1.0. + a_codes = np.array([0x30, 0x38, 0x40, 0xB8], dtype=np.uint8) + m_idx = np.arange(M).reshape(M, 1) + k_idx = np.arange(K).reshape(1, K) + a_matrix = a_codes[(m_idx * 3 + k_idx * 5) % a_codes.size] + b_matrix = np.full((K, N), 0x38, dtype=np.uint8) + + # E8M0 scale is 2^(byte - 127). The two K/32 groups use different scales, + # and A scales vary by M so the test catches incorrect scale grouping. + a_scale_matrix = np.where( + (np.arange(M).reshape(M, 1) + np.arange(2)) % 2 == 0, 127, 128 + ).astype(np.uint8) + b_scale_matrix = np.array([[126], [127]], dtype=np.uint8).repeat(N, axis=1) + a = a_matrix.reshape(-1).astype(np.uint8) + b = b_matrix.reshape(-1).astype(np.uint8) + a_scale = pack_a_scale(a_scale_matrix) + b_scale = pack_b_scale(b_scale_matrix) + c = np.zeros((M, N), dtype=np.float32) + + a_f32 = fp8_e4m3_to_f32(a_matrix) + b_f32 = fp8_e4m3_to_f32(b_matrix) + golden_c = np.zeros((M, N), dtype=np.float32) + a_scale_f32 = e8m0_to_f32(a_scale_matrix) + b_scale_f32 = e8m0_to_f32(b_scale_matrix) + for group in range(K // 32): + k_slice = slice(group * 32, (group + 1) * 32) + scaled_a = a_f32[:, k_slice] * a_scale_f32[:, group : group + 1] + scaled_b = b_f32[k_slice, :] * b_scale_f32[group : group + 1, :] + golden_c += scaled_a @ scaled_b + golden_c *= np.float32(2.0) + + output_dir.mkdir(parents=True, exist_ok=True) + a.tofile(output_dir / "v1.bin") + b.tofile(output_dir / "v2.bin") + a_scale.tofile(output_dir / "v4.bin") + b_scale.tofile(output_dir / "v5.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/kernel.pto new file mode 100644 index 000000000..ddfef0485 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/kernel.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_mx_acc_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %a_scale_gm: !pto.ptr, + %b_scale_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c64_burst_i64 = arith.constant 64 : i64 + %c1088_i64 = arith.constant 1088 : i64 + %c2112_i64 = arith.constant 2112 : i64 + + %l1_a_data = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_a_scale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l1_b_data = pto.castptr %c1088_i64 : i64 -> !pto.ptr + %l1_b_scale = pto.castptr %c2112_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1 %a_gm, %l1_a_data, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %a_scale_gm, %l1_a_scale, %c64_burst_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_gm, %l1_b_data, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_scale_gm, %l1_b_scale, %c64_burst_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a_data, %l0a, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b_data, %l0b, %c64_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0a_mx %l1_a_scale, %l0a, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b_mx %l1_b_scale, %l0b, %c64_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad_mx %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c64_i64 unit_flag(check_only) disable_gemv sat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + pto.barrier #pto.pipe + pto.mad_mx_acc %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c64_i64 unit_flag(check_and_set) disable_gemv nosat + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/launch.cpp new file mode 100644 index 000000000..d179071ed --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_mx_acc_kernel(__gm__ uint8_t *a, + __gm__ uint8_t *b, + __gm__ uint8_t *a_scale, + __gm__ uint8_t *b_scale, + __gm__ float *c); + +void LaunchMad_mx_acc_kernel(uint8_t *a, uint8_t *b, uint8_t *a_scale, + uint8_t *b_scale, float *c, void *stream) { + mad_mx_acc_kernel<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, + (__gm__ uint8_t *)b, + (__gm__ uint8_t *)a_scale, + (__gm__ uint8_t *)b_scale, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/main.cpp new file mode 100644 index 000000000..c4772d9dc --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_acc/main.cpp @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_mx_acc_kernel(uint8_t *a, uint8_t *b, uint8_t *a_scale, + uint8_t *b_scale, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 64; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t scaleElem = 64; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(uint8_t); + constexpr size_t bSize = bElem * sizeof(uint8_t); + constexpr size_t scaleSize = scaleElem * sizeof(uint8_t); + constexpr size_t cSize = cElem * sizeof(float); + + uint8_t *aHost = nullptr; + uint8_t *bHost = nullptr; + uint8_t *aScaleHost = nullptr; + uint8_t *bScaleHost = nullptr; + float *cHost = nullptr; + uint8_t *aDevice = nullptr; + uint8_t *bDevice = nullptr; + uint8_t *aScaleDevice = nullptr; + uint8_t *bScaleDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&aScaleHost), scaleSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bScaleHost), scaleSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&aScaleDevice, scaleSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&bScaleDevice, scaleSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = scaleSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, aScaleHost, scaleSize) && + inputSize == scaleSize, + "./v4.bin"); + inputSize = scaleSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, bScaleHost, scaleSize) && + inputSize == scaleSize, + "./v5.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(aScaleDevice, scaleSize, aScaleHost, scaleSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bScaleDevice, scaleSize, bScaleHost, scaleSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_mx_acc_kernel(aDevice, bDevice, aScaleDevice, bScaleDevice, cDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(aScaleDevice); + aclrtFree(bScaleDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(aScaleHost); + aclrtFreeHost(bScaleHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/compare.py new file mode 100644 index 000000000..c1391455e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-2, rtol=1e-2): + return True + diff = np.where(np.abs(golden - output) > (1e-2 + 1e-2 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/golden.py new file mode 100644 index 000000000..e27f58fd0 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/golden.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path + +import numpy as np + +M = 16 +N = 16 +K = 64 +SCALE_BYTES = 64 + + +def fp8_e4m3_to_f32(bits: np.ndarray) -> np.ndarray: + raw = bits.astype(np.uint8) + sign = np.where((raw & 0x80) != 0, -1.0, 1.0).astype(np.float32) + exponent = ((raw >> 3) & 0x0F).astype(np.int32) + mantissa = (raw & 0x07).astype(np.float32) + normal = exponent != 0 + value = np.where( + normal, + (1.0 + mantissa / 8.0) * np.exp2(exponent - 7), + (mantissa / 8.0) * np.exp2(-6), + ).astype(np.float32) + return sign * value + + +def e8m0_to_f32(bits: np.ndarray) -> np.ndarray: + return np.exp2(bits.astype(np.int32) - 127).astype(np.float32) + + +def pack_a_scale(a_scale: np.ndarray) -> np.ndarray: + packed = np.zeros(SCALE_BYTES, dtype=np.uint8) + packed[0:32] = a_scale.reshape(-1) + return packed + + +def pack_b_scale(b_scale: np.ndarray) -> np.ndarray: + packed = np.zeros(SCALE_BYTES, dtype=np.uint8) + packed[0:32] = b_scale.T.reshape(-1) + return packed + + +def generate(output_dir: Path) -> None: + # Values are exactly representable in FP8 E4M3: 0.5, 1.0, 2.0 and -1.0. + a_codes = np.array([0x30, 0x38, 0x40, 0xB8], dtype=np.uint8) + m_idx = np.arange(M).reshape(M, 1) + k_idx = np.arange(K).reshape(1, K) + a_matrix = a_codes[(m_idx * 3 + k_idx * 5) % a_codes.size] + b_matrix = np.full((K, N), 0x38, dtype=np.uint8) + + # E8M0 scale is 2^(byte - 127). The two K/32 groups use different scales, + # and A scales vary by M so the test catches incorrect scale grouping. + a_scale_matrix = np.where( + (np.arange(M).reshape(M, 1) + np.arange(2)) % 2 == 0, 127, 128 + ).astype(np.uint8) + b_scale_matrix = np.array([[126], [127]], dtype=np.uint8).repeat(N, axis=1) + a = a_matrix.reshape(-1).astype(np.uint8) + b = b_matrix.reshape(-1).astype(np.uint8) + a_scale = pack_a_scale(a_scale_matrix) + b_scale = pack_b_scale(b_scale_matrix) + c = np.zeros((M, N), dtype=np.float32) + bias = ( + (((np.arange(N, dtype=np.float32) * 5) % 23) - 11).astype(np.float16) + / np.float16(4.0) + ) + + a_f32 = fp8_e4m3_to_f32(a_matrix) + b_f32 = fp8_e4m3_to_f32(b_matrix) + golden_c = np.zeros((M, N), dtype=np.float32) + a_scale_f32 = e8m0_to_f32(a_scale_matrix) + b_scale_f32 = e8m0_to_f32(b_scale_matrix) + for group in range(K // 32): + k_slice = slice(group * 32, (group + 1) * 32) + scaled_a = a_f32[:, k_slice] * a_scale_f32[:, group : group + 1] + scaled_b = b_f32[k_slice, :] * b_scale_f32[group : group + 1, :] + golden_c += scaled_a @ scaled_b + golden_c += bias.astype(np.float32)[None, :] + + output_dir.mkdir(parents=True, exist_ok=True) + a.tofile(output_dir / "v1.bin") + b.tofile(output_dir / "v2.bin") + a_scale.tofile(output_dir / "v4.bin") + b_scale.tofile(output_dir / "v5.bin") + bias.reshape(-1).tofile(output_dir / "v6.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/kernel.pto new file mode 100644 index 000000000..946ad2bb0 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/kernel.pto @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_mx_bias_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %a_scale_gm: !pto.ptr, + %b_scale_gm: !pto.ptr, + %bias_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c64_burst_i64 = arith.constant 64 : i64 + %c1088_i64 = arith.constant 1088 : i64 + %c2112_i64 = arith.constant 2112 : i64 + %c2176_i64 = arith.constant 2176 : i64 + %c32_i64 = arith.constant 32 : i64 + + %l1_a_data = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_a_scale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l1_b_data = pto.castptr %c1088_i64 : i64 -> !pto.ptr + %l1_b_scale = pto.castptr %c2112_i64 : i64 -> !pto.ptr + %l1_bias = pto.castptr %c2176_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + %bt = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1 %a_gm, %l1_a_data, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %a_scale_gm, %l1_a_scale, %c64_burst_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_gm, %l1_b_data, %c1024_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %b_scale_gm, %l1_b_scale, %c64_burst_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_gm_l1 %bias_gm, %l1_bias, %c32_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a_data, %l0a, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b_data, %l0b, %c64_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0a_mx %l1_a_scale, %l0a, %c16_i64, %c64_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b_mx %l1_b_scale, %l0b, %c64_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_bt %l1_bias, %bt, %c16_i64 nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad_mx_bias %l0a, %l0b, %l0c, %bt, %c16_i64, %c16_i64, %c64_i64 unit_flag(check_and_set) disable_gemv nosat + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/launch.cpp new file mode 100644 index 000000000..d37e689b2 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_mx_bias_kernel(__gm__ uint8_t *a, + __gm__ uint8_t *b, + __gm__ uint8_t *a_scale, + __gm__ uint8_t *b_scale, + __gm__ __fp16 *bias, + __gm__ float *c); + +void LaunchMad_mx_bias_kernel(uint8_t *a, uint8_t *b, uint8_t *a_scale, + uint8_t *b_scale, __fp16 *bias, float *c, + void *stream) { + mad_mx_bias_kernel<<<1, nullptr, stream>>>((__gm__ uint8_t *)a, + (__gm__ uint8_t *)b, + (__gm__ uint8_t *)a_scale, + (__gm__ uint8_t *)b_scale, + (__gm__ __fp16 *)bias, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/main.cpp new file mode 100644 index 000000000..8183ab58d --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_mx_bias/main.cpp @@ -0,0 +1,172 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_mx_bias_kernel(uint8_t *a, uint8_t *b, uint8_t *a_scale, + uint8_t *b_scale, __fp16 *bias, float *c, + void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 64; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t scaleElem = 64; + constexpr size_t biasElem = kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(uint8_t); + constexpr size_t bSize = bElem * sizeof(uint8_t); + constexpr size_t scaleSize = scaleElem * sizeof(uint8_t); + constexpr size_t biasSize = biasElem * sizeof(__fp16); + constexpr size_t cSize = cElem * sizeof(float); + + uint8_t *aHost = nullptr; + uint8_t *bHost = nullptr; + uint8_t *aScaleHost = nullptr; + uint8_t *bScaleHost = nullptr; + __fp16 *biasHost = nullptr; + float *cHost = nullptr; + uint8_t *aDevice = nullptr; + uint8_t *bDevice = nullptr; + uint8_t *aScaleDevice = nullptr; + uint8_t *bScaleDevice = nullptr; + __fp16 *biasDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&aScaleHost), scaleSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bScaleHost), scaleSize)); + ACL_CHECK(aclrtMallocHost((void **)(&biasHost), biasSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&aScaleDevice, scaleSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&bScaleDevice, scaleSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&biasDevice, biasSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = scaleSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, aScaleHost, scaleSize) && + inputSize == scaleSize, + "./v4.bin"); + inputSize = scaleSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, bScaleHost, scaleSize) && + inputSize == scaleSize, + "./v5.bin"); + inputSize = biasSize; + FILE_CHECK(ReadFile("./v6.bin", inputSize, biasHost, biasSize) && + inputSize == biasSize, + "./v6.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(aScaleDevice, scaleSize, aScaleHost, scaleSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bScaleDevice, scaleSize, bScaleHost, scaleSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(biasDevice, biasSize, biasHost, biasSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_mx_bias_kernel(aDevice, bDevice, aScaleDevice, bScaleDevice, + biasDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(aScaleDevice); + aclrtFree(bScaleDevice); + aclrtFree(biasDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(aScaleHost); + aclrtFreeHost(bScaleHost); + aclrtFreeHost(biasHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/compare.py new file mode 100644 index 000000000..d650ba261 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=2.5e-3, rtol=0.0): + return True + diff = np.where(np.abs(golden - output) > 2.5e-3)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + ok = compare_bin("golden_v4.bin", "v4.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/golden.py new file mode 100644 index 000000000..896b2c3b8 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def tf32_round_even(x: np.ndarray) -> np.ndarray: + bits = x.astype(np.float32).view(np.uint32) + lsb = (bits >> np.uint32(13)) & np.uint32(1) + rounded = bits + np.uint32(0x00000FFF) + lsb + return (rounded & np.uint32(0xFFFFE000)).view(np.float32) + + +def generate(output_dir: Path) -> None: + row = np.arange(M, dtype=np.float32).reshape(M, 1) + col = np.arange(K, dtype=np.float32).reshape(1, K) + a_base = (((row * 11 + col * 3) % 31) - 15).astype(np.float32) / 7.0 + a_perturb = (((row * 5 + col * 9) % 17) + 1).astype(np.float32) + a = a_base + a_perturb * np.float32(2.0 ** -13) + k_idx = np.arange(K, dtype=np.float32).reshape(K, 1) + n_idx = np.arange(N, dtype=np.float32).reshape(1, N) + b_base = (((k_idx * 5 - n_idx * 13) % 37) - 18).astype(np.float32) / 9.0 + b_perturb = (((k_idx * 7 + n_idx * 3) % 19) + 1).astype(np.float32) + b = b_base - b_perturb * np.float32(2.0 ** -13) + c_tf32 = np.zeros((M, N), dtype=np.float32) + c_plain = np.zeros((M, N), dtype=np.float32) + golden_tf32 = tf32_round_even(a) @ tf32_round_even(b) + plain_fp32_c = a @ b + max_tf32_delta = float(np.max(np.abs(plain_fp32_c - golden_tf32))) + print(f"[INFO] max plain-fp32-vs-tf32 delta: {max_tf32_delta}") + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c_tf32.reshape(-1).tofile(output_dir / "v3.bin") + c_plain.reshape(-1).tofile(output_dir / "v4.bin") + golden_tf32.reshape(-1).tofile(output_dir / "golden_v3.bin") + plain_fp32_c.reshape(-1).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/kernel.pto new file mode 100644 index 000000000..0df65db39 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/kernel.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_tf32_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_tf32_gm: !pto.ptr, + %c_plain_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %false = arith.constant false + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 tf32_mode(round_even) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_tf32_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_M", "EVENT_ID1"] + pto.wait_flag["PIPE_FIX", "PIPE_M", "EVENT_ID1"] + + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_plain_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/launch.cpp new file mode 100644 index 000000000..5b66d6cd2 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_tf32_kernel(__gm__ float *a, + __gm__ float *b, + __gm__ float *c_tf32, + __gm__ float *c_plain); + +void LaunchMad_tf32_kernel(float *a, float *b, float *cTf32, float *cPlain, + void *stream) { + mad_tf32_kernel<<<1, nullptr, stream>>>((__gm__ float *)a, + (__gm__ float *)b, + (__gm__ float *)cTf32, + (__gm__ float *)cPlain); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/main.cpp new file mode 100644 index 000000000..c8e57ef4e --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_tf32_kernel(float *a, float *b, float *cTf32, float *cPlain, + void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(float); + constexpr size_t bSize = bElem * sizeof(float); + constexpr size_t cSize = cElem * sizeof(float); + + float *aHost = nullptr; + float *bHost = nullptr; + float *cTf32Host = nullptr; + float *cPlainHost = nullptr; + float *aDevice = nullptr; + float *bDevice = nullptr; + float *cTf32Device = nullptr; + float *cPlainDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cTf32Host), cSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cPlainHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cTf32Device, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cPlainDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cTf32Host, cSize) && inputSize == cSize, + "./v3.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, cPlainHost, cSize) && inputSize == cSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cTf32Device, cSize, cTf32Host, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cPlainDevice, cSize, cPlainHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_tf32_kernel(aDevice, bDevice, cTf32Device, cPlainDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cTf32Host, cSize, cTf32Device, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(cPlainHost, cSize, cPlainDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cTf32Host, cSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", cPlainHost, cSize), "./v4.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cTf32Device); + aclrtFree(cPlainDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cTf32Host); + aclrtFreeHost(cPlainHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/compare.py b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/compare.py new file mode 100644 index 000000000..964a73c9b --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/compare.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-4, rtol=0.0): + return True + diff = np.where(np.abs(golden - output) > 1e-4)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/golden.py b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/golden.py new file mode 100644 index 000000000..7f7a1d9c7 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/golden.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import argparse +from pathlib import Path +import numpy as np + +M = 16 +N = 16 +K = 16 + + +def tf32_round_away(x: np.ndarray) -> np.ndarray: + bits = x.astype(np.float32).view(np.uint32) + rounded = bits + np.uint32(0x00001000) + return (rounded & np.uint32(0xFFFFE000)).view(np.float32) + + +def tf32_round_even(x: np.ndarray) -> np.ndarray: + bits = x.astype(np.float32).view(np.uint32) + lsb = (bits >> np.uint32(13)) & np.uint32(1) + rounded = bits + np.uint32(0x00000FFF) + lsb + return (rounded & np.uint32(0xFFFFE000)).view(np.float32) + + +def generate(output_dir: Path) -> None: + row = np.arange(M, dtype=np.uint32).reshape(M, 1) + col = np.arange(K, dtype=np.uint32).reshape(1, K) + a_sign = ((row + col) & np.uint32(1)) << np.uint32(31) + a_mant = ((row * np.uint32(29) + col * np.uint32(37)) % np.uint32(512)) + a_bits = a_sign | np.uint32(0x3F800000) | (a_mant << np.uint32(13)) | np.uint32(0x1000) + a = a_bits.astype(np.uint32).view(np.float32) + + k_idx = np.arange(K, dtype=np.uint32).reshape(K, 1) + n_idx = np.arange(N, dtype=np.uint32).reshape(1, N) + b_sign = ((k_idx * np.uint32(3) + n_idx) & np.uint32(1)) << np.uint32(31) + b_mant = ((k_idx * np.uint32(41) + n_idx * np.uint32(11)) % np.uint32(512)) + b_bits = b_sign | np.uint32(0x3F800000) | (b_mant << np.uint32(13)) | np.uint32(0x1000) + b = b_bits.astype(np.uint32).view(np.float32) + c = np.zeros((M, N), dtype=np.float32) + golden_c = tf32_round_away(a) @ tf32_round_away(b) + round_even_c = tf32_round_even(a) @ tf32_round_even(b) + plain_fp32_c = a @ b + max_tf32_delta = float(np.max(np.abs(plain_fp32_c - golden_c))) + max_round_mode_delta = float(np.max(np.abs(round_even_c - golden_c))) + print(f"[INFO] max plain-fp32-vs-tf32 delta: {max_tf32_delta}") + print(f"[INFO] max round-even-vs-round-away delta: {max_round_mode_delta}") + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + c.reshape(-1).tofile(output_dir / "v3.bin") + golden_c.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/kernel.pto b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/kernel.pto new file mode 100644 index 000000000..23f1d8d70 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/kernel.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @mad_tf32_round_away_kernel(%a_gm: !pto.ptr, + %b_gm: !pto.ptr, + %c_gm: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %false = arith.constant false + + %l1_a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l1_b = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.mte_gm_l1_frac %a_gm, %l1_a, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %b_gm, %l1_b, nd2nz, + shape(%c16_i64, %c16_i64), + src_layout(%c64_i64), + dst_group(%c1_i64, %c1_i64, %c16_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + + pto.mte_l1_l0a %l1_a, %l0a, %c16_i64, %c16_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %l1_b, %l0b, %c16_i64, %c16_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c16_i64, %c16_i64, %c16_i64 tf32_mode(round_away) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + + pto.mte_l0c_gm %l0c, %c_gm, %c16_i64, %c16_i64, %c16_i64, %c16_i64, + %c0_i64, %c0_i64, + nz2nd + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/launch.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/launch.cpp new file mode 100644 index 000000000..6216b11a4 --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void mad_tf32_round_away_kernel(__gm__ float *a, + __gm__ float *b, + __gm__ float *c); + +void LaunchMad_tf32_round_away_kernel(float *a, float *b, float *c, void *stream) { + mad_tf32_round_away_kernel<<<1, nullptr, stream>>>((__gm__ float *)a, + (__gm__ float *)b, + (__gm__ float *)c); +} diff --git a/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/main.cpp b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/main.cpp new file mode 100644 index 000000000..3f0122f5b --- /dev/null +++ b/test/vpto/cases/micro-op/cube-matmul/mad_tf32_round_away/main.cpp @@ -0,0 +1,127 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchMad_tf32_round_away_kernel(float *a, float *b, float *c, void *stream); + +int main() { + constexpr size_t kM = 16; + constexpr size_t kN = 16; + constexpr size_t kK = 16; + constexpr size_t aElem = kM * kK; + constexpr size_t bElem = kK * kN; + constexpr size_t cElem = kM * kN; + + constexpr size_t aSize = aElem * sizeof(float); + constexpr size_t bSize = bElem * sizeof(float); + constexpr size_t cSize = cElem * sizeof(float); + + float *aHost = nullptr; + float *bHost = nullptr; + float *cHost = nullptr; + float *aDevice = nullptr; + float *bDevice = nullptr; + float *cDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&aHost), aSize)); + ACL_CHECK(aclrtMallocHost((void **)(&bHost), bSize)); + ACL_CHECK(aclrtMallocHost((void **)(&cHost), cSize)); + ACL_CHECK(aclrtMalloc((void **)&aDevice, aSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&bDevice, bSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&cDevice, cSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = aSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, aHost, aSize) && inputSize == aSize, + "./v1.bin"); + inputSize = bSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, bHost, bSize) && inputSize == bSize, + "./v2.bin"); + inputSize = cSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, cHost, cSize) && inputSize == cSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(aDevice, aSize, aHost, aSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(bDevice, bSize, bHost, bSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(cDevice, cSize, cHost, cSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchMad_tf32_round_away_kernel(aDevice, bDevice, cDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(cHost, cSize, cDevice, cSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", cHost, cSize), "./v3.bin"); + +cleanup: + aclrtFree(aDevice); + aclrtFree(bDevice); + aclrtFree(cDevice); + aclrtFreeHost(aHost); + aclrtFreeHost(bHost); + aclrtFreeHost(cHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/compare.py new file mode 100644 index 000000000..d846b834d --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] missing file: {golden_path} or {output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return True + diff = np.where(np.abs(golden - output) > (1e-4 + 1e-4 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(3, 7): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/golden.py new file mode 100644 index 000000000..e8a8345e7 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +ATOMIC_ADD_INIT = np.float32(1.25) +ATOMIC_DELTA = np.float32(0.5) +SEED = 419 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float16) + rhs = rng.uniform(-1.5, 1.5, size=(K, N)).astype(np.float16) + lhs32 = lhs.astype(np.float32) + rhs32 = rhs.astype(np.float32) + matmul = np.zeros((M, N), dtype=np.float32) + for k_idx in range(K): + matmul += lhs32[:, k_idx:k_idx + 1] * rhs32[k_idx:k_idx + 1, :] + + plain_init = np.zeros((M, N), dtype=np.float32) + atomic_add_init = np.full((M, N), ATOMIC_ADD_INIT, dtype=np.float32) + parity = ((np.arange(M * N, dtype=np.int32).reshape(M, N) & 1) * 2 - 1).astype(np.float32) + atomic_max_init = matmul + parity * ATOMIC_DELTA + atomic_min_init = matmul - parity * ATOMIC_DELTA + atomic_add_golden = atomic_add_init + matmul + atomic_max_golden = np.maximum(atomic_max_init, matmul) + atomic_min_golden = np.minimum(atomic_min_init, matmul) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + plain_init.reshape(-1).tofile(output_dir / "v3.bin") + atomic_add_init.reshape(-1).tofile(output_dir / "v4.bin") + atomic_max_init.reshape(-1).tofile(output_dir / "v5.bin") + atomic_min_init.reshape(-1).tofile(output_dir / "v6.bin") + matmul.reshape(-1).tofile(output_dir / "golden_v3.bin") + atomic_add_golden.reshape(-1).tofile(output_dir / "golden_v4.bin") + atomic_max_golden.reshape(-1).tofile(output_dir / "golden_v5.bin") + atomic_min_golden.reshape(-1).tofile(output_dir / "golden_v6.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/kernel.pto new file mode 100644 index 000000000..bcb1d9b43 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/kernel.pto @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_gm, pto.sync.set, pto.sync.wait +// scenarios: acc-store-gm-atomic-add-max-min, f32, nz2nd, strict-matmul-golden +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_acc_store_atomic_f32_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %out_plain_gm: !pto.ptr, + %out_atomic_add_gm: !pto.ptr, + %out_atomic_max_gm: !pto.ptr, + %out_atomic_min_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.mte_l0c_gm %l0c, %out_plain_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_atomic_add_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + atomic(type = f32, op = add) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_atomic_max_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + atomic(type = f32, op = max) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_atomic_min_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + atomic(type = f32, op = min) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + i64, i64, i64 + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/launch.cpp new file mode 100644 index 000000000..d2f527f64 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_acc_store_atomic_f32_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ float *out_plain, + __gm__ float *out_atomic_add, __gm__ float *out_atomic_max, + __gm__ float *out_atomic_min); + +void LaunchFixpipe_acc_store_atomic_f32_cv_kernel(__fp16 *src, __fp16 *id, + float *outPlain, + float *outAtomicAdd, + float *outAtomicMax, + float *outAtomicMin, + void *stream) { + fixpipe_acc_store_atomic_f32_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ float *)outPlain, + (__gm__ float *)outAtomicAdd, (__gm__ float *)outAtomicMax, + (__gm__ float *)outAtomicMin); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/main.cpp new file mode 100644 index 000000000..cfe160bad --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-atomic-f32-cv/main.cpp @@ -0,0 +1,175 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_acc_store_atomic_f32_cv_kernel(__fp16 *src, __fp16 *id, + float *outPlain, + float *outAtomicAdd, + float *outAtomicMax, + float *outAtomicMin, + void *stream); + +int main() { + constexpr size_t kSrcElems = 50 * 64; + constexpr size_t kIdElems = 40 * 50; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSrcSize = kSrcElems * sizeof(__fp16); + constexpr size_t kIdSize = kIdElems * sizeof(__fp16); + constexpr size_t kOutSize = kOutElems * sizeof(float); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + float *outPlainHost = nullptr; + float *outAtomicAddHost = nullptr; + float *outAtomicMaxHost = nullptr; + float *outAtomicMinHost = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + float *outPlainDevice = nullptr; + float *outAtomicAddDevice = nullptr; + float *outAtomicMaxDevice = nullptr; + float *outAtomicMinDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSrcSize)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kIdSize)); + ACL_CHECK(aclrtMallocHost((void **)&outPlainHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outAtomicAddHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outAtomicMaxHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outAtomicMinHost, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSrcSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kIdSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outPlainDevice, kOutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outAtomicAddDevice, kOutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outAtomicMaxDevice, kOutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outAtomicMinDevice, kOutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kIdSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kIdSize) && inputSize == kIdSize, + "./v1.bin"); + inputSize = kSrcSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSrcSize) && inputSize == kSrcSize, + "./v2.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, outPlainHost, kOutSize) && + inputSize == kOutSize, + "./v3.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, outAtomicAddHost, kOutSize) && + inputSize == kOutSize, + "./v4.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, outAtomicMaxHost, kOutSize) && + inputSize == kOutSize, + "./v5.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v6.bin", inputSize, outAtomicMinHost, kOutSize) && + inputSize == kOutSize, + "./v6.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcSize, srcHost, kSrcSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kIdSize, idHost, kIdSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outPlainDevice, kOutSize, outPlainHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outAtomicAddDevice, kOutSize, outAtomicAddHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outAtomicMaxDevice, kOutSize, outAtomicMaxHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outAtomicMinDevice, kOutSize, outAtomicMinHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_acc_store_atomic_f32_cv_kernel(srcDevice, idDevice, outPlainDevice, + outAtomicAddDevice, outAtomicMaxDevice, + outAtomicMinDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outPlainHost, kOutSize, outPlainDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outAtomicAddHost, kOutSize, outAtomicAddDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outAtomicMaxHost, kOutSize, outAtomicMaxDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outAtomicMinHost, kOutSize, outAtomicMinDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + FILE_CHECK(WriteFile("./v3.bin", outPlainHost, kOutSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", outAtomicAddHost, kOutSize), "./v4.bin"); + FILE_CHECK(WriteFile("./v5.bin", outAtomicMaxHost, kOutSize), "./v5.bin"); + FILE_CHECK(WriteFile("./v6.bin", outAtomicMinHost, kOutSize), "./v6.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(outPlainDevice); + aclrtFree(outAtomicAddDevice); + aclrtFree(outAtomicMaxDevice); + aclrtFree(outAtomicMinDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(outPlainHost); + aclrtFreeHost(outAtomicAddHost); + aclrtFreeHost(outAtomicMaxHost); + aclrtFreeHost(outAtomicMinHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/compare.py new file mode 100644 index 000000000..0db95c8e4 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden - output) > (1e-3 + 1e-3 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(3, 7): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/golden.py new file mode 100644 index 000000000..4bed2e070 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + lhs = (np.arange(40 * 50, dtype=np.float16).reshape(40, 50) * np.float16(0.5) + + np.float16(17)).astype(np.float16) + rhs = (np.arange(50 * 64, dtype=np.float16).reshape(50, 64) * np.float16(0.25) + + np.float16(3)).astype(np.float16) + golden = lhs.astype(np.float32) @ rhs.astype(np.float32) + + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + outputs = { + 3: golden[:20, :], + 4: golden[20:, :], + 5: golden[:, :32], + 6: golden[:, 32:], + } + for index, value in outputs.items(): + np.zeros_like(value, dtype=np.float32).reshape(-1).tofile(output_dir / f"v{index}.bin") + value.astype(np.float32).reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/kernel.pto new file mode 100644 index 000000000..71bdcd5bd --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/kernel.pto @@ -0,0 +1,126 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_ub, pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: nd2nz-functional-load, mte_l0c_ub split-M, mte_l0c_ub split-N, +// per-subblock UB writeback +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_acc_store_dual_ub_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %out_split_m0_gm: !pto.ptr, + %out_split_m1_gm: !pto.ptr, + %out_split_n0_gm: !pto.ptr, + %out_split_n1_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c20_i64 = arith.constant 20 : i64 + %c32_i64 = arith.constant 32 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64b_i64 = arith.constant 64 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c65536_i64 = arith.constant 65536 : i64 + + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_split_m = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_split_n = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64b_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64b_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64b_i64, %c50_i64 unit_flag(check_and_set) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l0c_ub %l0c, %ub_split_m, %c40_i64, %c64b_i64, %c48_i64, %c64b_i64, dst_mode(split_m), + nz2nd, + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l0c_ub %l0c, %ub_split_n, %c40_i64, %c64b_i64, %c48_i64, %c64b_i64, dst_mode(split_n), + nz2nd, + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + %is_subblock1 = arith.cmpi eq, %subblock, %c1_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_split_m, %out_split_m0_gm, %c256_i64 + nburst(%c20_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_split_n, %out_split_n0_gm, %c128_i64 + nburst(%c40_i64, %c256_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + scf.if %is_subblock1 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_split_m, %out_split_m1_gm, %c256_i64 + nburst(%c20_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_split_n, %out_split_n1_gm, %c128_i64 + nburst(%c40_i64, %c256_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/launch.cpp new file mode 100644 index 000000000..f81e11aff --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_acc_store_dual_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ float *outSplitM0, + __gm__ float *outSplitM1, __gm__ float *outSplitN0, + __gm__ float *outSplitN1); + +void LaunchFixpipe_acc_store_dual_ub_cv_kernel( + __fp16 *src, __fp16 *id, float *outSplitM0, float *outSplitM1, + float *outSplitN0, float *outSplitN1, void *stream) { + fixpipe_acc_store_dual_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, + (__gm__ float *)outSplitM0, (__gm__ float *)outSplitM1, + (__gm__ float *)outSplitN0, (__gm__ float *)outSplitN1); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/main.cpp new file mode 100644 index 000000000..0948dc0fd --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-dual-ub-cv/main.cpp @@ -0,0 +1,165 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_acc_store_dual_ub_cv_kernel( + __fp16 *src, __fp16 *id, float *outSplitM0, float *outSplitM1, + float *outSplitN0, float *outSplitN1, void *stream); + +int main() { + constexpr size_t kSrcElem = 50 * 64; + constexpr size_t kIdElem = 40 * 50; + constexpr size_t kOutSplitMElem = 20 * 64; + constexpr size_t kOutSplitNElem = 40 * 32; + constexpr size_t kSrcSize = kSrcElem * sizeof(__fp16); + constexpr size_t kIdSize = kIdElem * sizeof(__fp16); + constexpr size_t kOutSplitMSize = kOutSplitMElem * sizeof(float); + constexpr size_t kOutSplitNSize = kOutSplitNElem * sizeof(float); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + float *outSplitM0Host = nullptr; + float *outSplitM1Host = nullptr; + float *outSplitN0Host = nullptr; + float *outSplitN1Host = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + float *outSplitM0Device = nullptr; + float *outSplitM1Device = nullptr; + float *outSplitN0Device = nullptr; + float *outSplitN1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSrcSize)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kIdSize)); + ACL_CHECK(aclrtMallocHost((void **)&outSplitM0Host, kOutSplitMSize)); + ACL_CHECK(aclrtMallocHost((void **)&outSplitM1Host, kOutSplitMSize)); + ACL_CHECK(aclrtMallocHost((void **)&outSplitN0Host, kOutSplitNSize)); + ACL_CHECK(aclrtMallocHost((void **)&outSplitN1Host, kOutSplitNSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSrcSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kIdSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outSplitM0Device, kOutSplitMSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outSplitM1Device, kOutSplitMSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outSplitN0Device, kOutSplitNSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outSplitN1Device, kOutSplitNSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kIdSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kIdSize) && inputSize == kIdSize, + "./v1.bin"); + inputSize = kSrcSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSrcSize) && inputSize == kSrcSize, + "./v2.bin"); + inputSize = kOutSplitMSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, outSplitM0Host, kOutSplitMSize) && inputSize == kOutSplitMSize, + "./v3.bin"); + inputSize = kOutSplitMSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, outSplitM1Host, kOutSplitMSize) && inputSize == kOutSplitMSize, + "./v4.bin"); + inputSize = kOutSplitNSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, outSplitN0Host, kOutSplitNSize) && inputSize == kOutSplitNSize, + "./v5.bin"); + inputSize = kOutSplitNSize; + FILE_CHECK(ReadFile("./v6.bin", inputSize, outSplitN1Host, kOutSplitNSize) && inputSize == kOutSplitNSize, + "./v6.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcSize, srcHost, kSrcSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kIdSize, idHost, kIdSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outSplitM0Device, kOutSplitMSize, outSplitM0Host, kOutSplitMSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outSplitM1Device, kOutSplitMSize, outSplitM1Host, kOutSplitMSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outSplitN0Device, kOutSplitNSize, outSplitN0Host, kOutSplitNSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outSplitN1Device, kOutSplitNSize, outSplitN1Host, kOutSplitNSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_acc_store_dual_ub_cv_kernel( + srcDevice, idDevice, outSplitM0Device, outSplitM1Device, + outSplitN0Device, outSplitN1Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outSplitM0Host, kOutSplitMSize, outSplitM0Device, kOutSplitMSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outSplitM1Host, kOutSplitMSize, outSplitM1Device, kOutSplitMSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outSplitN0Host, kOutSplitNSize, outSplitN0Device, kOutSplitNSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outSplitN1Host, kOutSplitNSize, outSplitN1Device, kOutSplitNSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", outSplitM0Host, kOutSplitMSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", outSplitM1Host, kOutSplitMSize), "./v4.bin"); + FILE_CHECK(WriteFile("./v5.bin", outSplitN0Host, kOutSplitNSize), "./v5.bin"); + FILE_CHECK(WriteFile("./v6.bin", outSplitN1Host, kOutSplitNSize), "./v6.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(outSplitM0Device); + aclrtFree(outSplitM1Device); + aclrtFree(outSplitN0Device); + aclrtFree(outSplitN1Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(outSplitM0Host); + aclrtFreeHost(outSplitM1Host); + aclrtFreeHost(outSplitN0Host); + aclrtFreeHost(outSplitN1Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/compare.py new file mode 100644 index 000000000..e0897bd1e --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] missing file: {golden_path} or {output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + equal = golden.view(np.uint16) == output.view(np.uint16) + equal |= np.isnan(golden) & np.isnan(output) + if bool(np.all(equal)): + return True + diff = np.where(~equal)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(4, 10): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/golden.py new file mode 100644 index 000000000..da13cc461 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/golden.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +import struct +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +FP_QUANT_ELEMS = 64 +FP_TRANSPORT_ELEMS = FP_QUANT_ELEMS * 2 +K_ACTIVE = 6 +CASE_VALUES = np.array( + [ + np.float16(12000.0), + np.float16(-12000.0), + np.float16(np.inf), + np.float16(-np.inf), + np.float16(np.nan), + np.float16(20.0), + ], + dtype=np.float16, +) + + +def encode_scale(scale: float) -> np.uint64: + return np.uint64(struct.unpack("!I", struct.pack("!f", scale))[0]) + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + lhs = np.zeros((M, K), dtype=np.float16) + rhs = np.zeros((K, N), dtype=np.float16) + lhs[:, :K_ACTIVE] = np.float16(1.0) + for col in range(N): + rhs[:K_ACTIVE, col] = CASE_VALUES[col % len(CASE_VALUES)] + fp = np.full(FP_QUANT_ELEMS, encode_scale(1.0), dtype=np.uint64) + + matmul = lhs.astype(np.float32) @ rhs.astype(np.float32) + sat_golden = np.nan_to_num( + matmul, + nan=np.float32(0.0), + posinf=np.finfo(np.float16).max, + neginf=np.finfo(np.float16).min, + ) + sat_golden = np.clip(sat_golden, np.finfo(np.float16).min, + np.finfo(np.float16).max).astype(np.float16) + with np.errstate(over="ignore", invalid="ignore"): + nosat_golden = matmul.astype(np.float16) + if np.array_equal(sat_golden.view(np.uint16), nosat_golden.view(np.uint16)): + raise AssertionError("sat and nosat golden outputs must differ") + + zero = np.zeros((M, N), dtype=np.float16) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + fp.view(np.uint32).reshape(FP_TRANSPORT_ELEMS).tofile(output_dir / "v3.bin") + zero.reshape(-1).tofile(output_dir / "v4.bin") + zero.reshape(-1).tofile(output_dir / "v5.bin") + zero.reshape(-1).tofile(output_dir / "v6.bin") + zero.reshape(-1).tofile(output_dir / "v7.bin") + zero.reshape(-1).tofile(output_dir / "v8.bin") + zero.reshape(-1).tofile(output_dir / "v9.bin") + + sat_golden.reshape(-1).tofile(output_dir / "golden_v4.bin") + nosat_golden.reshape(-1).tofile(output_dir / "golden_v5.bin") + sat_golden.reshape(-1).tofile(output_dir / "golden_v6.bin") + nosat_golden.reshape(-1).tofile(output_dir / "golden_v7.bin") + sat_golden.reshape(-1).tofile(output_dir / "golden_v8.bin") + nosat_golden.reshape(-1).tofile(output_dir / "golden_v9.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/kernel.pto new file mode 100644 index 000000000..68fa2cf2a --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/kernel.pto @@ -0,0 +1,193 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_gm_l1, pto.mte_l1_fb, pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, +// pto.mte_l1_ub, pto.mte_ub_gm, pto.get_ctrl, pto.sbitset1, pto.set_ctrl, +// pto.sync.set, pto.sync.wait +// scenarios: acc-store-ub-gm-l1, sat-ab, qf322f16, ctrl48-restore, ub-to-gm, l1-to-ub +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_acc_store_sat_f16_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %fp_gm: !pto.ptr, + %out_ub_sat_gm: !pto.ptr, + %out_ub_nosat_gm: !pto.ptr, + %out_gm_sat_gm: !pto.ptr, + %out_gm_nosat_gm: !pto.ptr, + %out_l1_sat_gm: !pto.ptr, + %out_l1_nosat_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c262144_i64 = arith.constant 262144 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_fp_raw = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_fp = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_sat = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %l1_nosat = pto.castptr %c262144_i64 : i64 -> !pto.ptr + %fb_fp = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sat = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_nosat = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_l1_sat = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_l1_nosat = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %fp_gm, %l1_fp_raw, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l1_fb %l1_fp, %fb_fp, %c8_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + %ctrl = pto.get_ctrl : i64 + %ctrl_preserve_inf = pto.sbitset1 %ctrl, %c48_i64 : i64, i64 -> i64 + pto.set_ctrl %ctrl_preserve_inf : i64 + + pto.mte_l0c_ub %l0c, %ub_sat, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_ub %l0c, %ub_nosat, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + nosat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_gm_sat_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_nosat_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + nosat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + + pto.mte_l0c_l1 %l0c, %l1_sat, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_nosat, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + nosat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_sat, %ub_l1_sat, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_ub %l1_nosat, %ub_l1_nosat, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_sat, %out_ub_sat_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_nosat, %out_ub_nosat_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_sat, %out_l1_sat_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_nosat, %out_l1_nosat_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/launch.cpp new file mode 100644 index 000000000..76194af90 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_acc_store_sat_f16_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ uint32_t *fp, + __gm__ __fp16 *out_ub_sat, __gm__ __fp16 *out_ub_nosat, + __gm__ __fp16 *out_gm_sat, __gm__ __fp16 *out_gm_nosat, + __gm__ __fp16 *out_l1_sat, __gm__ __fp16 *out_l1_nosat); + +void LaunchFixpipe_acc_store_sat_f16_cv_kernel( + __fp16 *src, __fp16 *id, uint32_t *fp, __fp16 *outUbSat, + __fp16 *outUbNosat, __fp16 *outGmSat, __fp16 *outGmNosat, + __fp16 *outL1Sat, __fp16 *outL1Nosat, void *stream) { + fixpipe_acc_store_sat_f16_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ uint32_t *)fp, + (__gm__ __fp16 *)outUbSat, (__gm__ __fp16 *)outUbNosat, + (__gm__ __fp16 *)outGmSat, (__gm__ __fp16 *)outGmNosat, + (__gm__ __fp16 *)outL1Sat, (__gm__ __fp16 *)outL1Nosat); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/main.cpp new file mode 100644 index 000000000..074707b7b --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-store-sat-f16-cv/main.cpp @@ -0,0 +1,209 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_acc_store_sat_f16_cv_kernel( + __fp16 *src, __fp16 *id, uint32_t *fp, __fp16 *outUbSat, + __fp16 *outUbNosat, __fp16 *outGmSat, __fp16 *outGmNosat, + __fp16 *outL1Sat, __fp16 *outL1Nosat, void *stream); + +int main() { + constexpr size_t kSrcElems = 50 * 64; + constexpr size_t kIdElems = 40 * 50; + constexpr size_t kFpElems = 128; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSrcSize = kSrcElems * sizeof(__fp16); + constexpr size_t kIdSize = kIdElems * sizeof(__fp16); + constexpr size_t kFpSize = kFpElems * sizeof(uint32_t); + constexpr size_t kOutSize = kOutElems * sizeof(__fp16); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + uint32_t *fpHost = nullptr; + __fp16 *outUbSatHost = nullptr; + __fp16 *outUbNosatHost = nullptr; + __fp16 *outGmSatHost = nullptr; + __fp16 *outGmNosatHost = nullptr; + __fp16 *outL1SatHost = nullptr; + __fp16 *outL1NosatHost = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + uint32_t *fpDevice = nullptr; + __fp16 *outUbSatDevice = nullptr; + __fp16 *outUbNosatDevice = nullptr; + __fp16 *outGmSatDevice = nullptr; + __fp16 *outGmNosatDevice = nullptr; + __fp16 *outL1SatDevice = nullptr; + __fp16 *outL1NosatDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSrcSize)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kIdSize)); + ACL_CHECK(aclrtMallocHost((void **)&fpHost, kFpSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbSatHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbNosatHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmSatHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmNosatHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1SatHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1NosatHost, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSrcSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kIdSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&fpDevice, kFpSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbSatDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbNosatDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmSatDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmNosatDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1SatDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1NosatDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kIdSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kIdSize) && inputSize == kIdSize, + "./v1.bin"); + inputSize = kSrcSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSrcSize) && inputSize == kSrcSize, + "./v2.bin"); + inputSize = kFpSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, fpHost, kFpSize) && inputSize == kFpSize, + "./v3.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, outUbSatHost, kOutSize) && inputSize == kOutSize, + "./v4.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, outUbNosatHost, kOutSize) && + inputSize == kOutSize, + "./v5.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v6.bin", inputSize, outGmSatHost, kOutSize) && inputSize == kOutSize, + "./v6.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v7.bin", inputSize, outGmNosatHost, kOutSize) && + inputSize == kOutSize, + "./v7.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v8.bin", inputSize, outL1SatHost, kOutSize) && inputSize == kOutSize, + "./v8.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v9.bin", inputSize, outL1NosatHost, kOutSize) && + inputSize == kOutSize, + "./v9.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcSize, srcHost, kSrcSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kIdSize, idHost, kIdSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(fpDevice, kFpSize, fpHost, kFpSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbSatDevice, kOutSize, outUbSatHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbNosatDevice, kOutSize, outUbNosatHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmSatDevice, kOutSize, outGmSatHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmNosatDevice, kOutSize, outGmNosatHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1SatDevice, kOutSize, outL1SatHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1NosatDevice, kOutSize, outL1NosatHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_acc_store_sat_f16_cv_kernel( + srcDevice, idDevice, fpDevice, outUbSatDevice, outUbNosatDevice, + outGmSatDevice, outGmNosatDevice, outL1SatDevice, outL1NosatDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbSatHost, kOutSize, outUbSatDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outUbNosatHost, kOutSize, outUbNosatDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmSatHost, kOutSize, outGmSatDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmNosatHost, kOutSize, outGmNosatDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1SatHost, kOutSize, outL1SatDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1NosatHost, kOutSize, outL1NosatDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + FILE_CHECK(WriteFile("./v4.bin", outUbSatHost, kOutSize), "./v4.bin"); + FILE_CHECK(WriteFile("./v5.bin", outUbNosatHost, kOutSize), "./v5.bin"); + FILE_CHECK(WriteFile("./v6.bin", outGmSatHost, kOutSize), "./v6.bin"); + FILE_CHECK(WriteFile("./v7.bin", outGmNosatHost, kOutSize), "./v7.bin"); + FILE_CHECK(WriteFile("./v8.bin", outL1SatHost, kOutSize), "./v8.bin"); + FILE_CHECK(WriteFile("./v9.bin", outL1NosatHost, kOutSize), "./v9.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(fpDevice); + aclrtFree(outUbSatDevice); + aclrtFree(outUbNosatDevice); + aclrtFree(outGmSatDevice); + aclrtFree(outGmNosatDevice); + aclrtFree(outL1SatDevice); + aclrtFree(outL1NosatDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(fpHost); + aclrtFreeHost(outUbSatHost); + aclrtFreeHost(outUbNosatHost); + aclrtFreeHost(outGmSatHost); + aclrtFreeHost(outGmNosatHost); + aclrtFreeHost(outL1SatHost); + aclrtFreeHost(outL1NosatHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/compare.py new file mode 100644 index 000000000..ac4e08379 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float32) + output = np.fromfile(output_path, dtype=np.float32) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden - output) > (1e-3 + 1e-3 * np.abs(golden)))[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={float(golden[idx])}, out={float(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/golden.py new file mode 100644 index 000000000..57aa38205 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/golden.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + lhs = (np.arange(40 * 50, dtype=np.float16).reshape(40, 50) * np.float16(0.5) + + np.float16(17)).astype(np.float16) + rhs = (np.arange(50 * 64, dtype=np.float16).reshape(50, 64) * np.float16(0.25) + + np.float16(3)).astype(np.float16) + out = np.zeros((40, 64), dtype=np.float32) + golden = lhs.astype(np.float32) @ rhs.astype(np.float32) + + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "v3.bin") + golden.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/kernel.pto new file mode 100644 index 000000000..45393b6fe --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/kernel.pto @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-acc-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_ub, pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: nd2nz-functional-load, cc-to-ub, ub-to-gm +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_acc_ub_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %out_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64b_i64 = arith.constant 64 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64b_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64b_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64b_i64, %c50_i64 unit_flag(check_and_set) + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l0c_ub %l0c, %ub_out, %c40_i64, %c64b_i64, %c48_i64, %c64b_i64, dst_mode(%c0_i64), + unit_flag(check_only), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64), + sat + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 + + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_out, %out_gm, %c256_i64 + nburst(%c40_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/launch.cpp new file mode 100644 index 000000000..c7712bf11 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_acc_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ float *out); + +void LaunchFixpipe_acc_ub_cv_kernel(__fp16 *src, __fp16 *id, float *out, + void *stream) { + fixpipe_acc_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ float *)out); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/main.cpp new file mode 100644 index 000000000..ebf9feb0c --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-acc-ub-cv/main.cpp @@ -0,0 +1,124 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_acc_ub_cv_kernel(__fp16 *src, __fp16 *id, float *out, + void *stream); + +int main() { + constexpr size_t kSrcElem = 50 * 64; + constexpr size_t kIdElem = 40 * 50; + constexpr size_t kOutElem = 40 * 64; + constexpr size_t kSrcSize = kSrcElem * sizeof(__fp16); + constexpr size_t kIdSize = kIdElem * sizeof(__fp16); + constexpr size_t kOutSize = kOutElem * sizeof(float); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + float *outHost = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + float *outDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSrcSize)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kIdSize)); + ACL_CHECK(aclrtMallocHost((void **)&outHost, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSrcSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kIdSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kIdSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kIdSize) && inputSize == kIdSize, + "./v1.bin"); + inputSize = kSrcSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSrcSize) && inputSize == kSrcSize, + "./v2.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, outHost, kOutSize) && inputSize == kOutSize, + "./v3.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcSize, srcHost, kSrcSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kIdSize, idHost, kIdSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, kOutSize, outHost, kOutSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_acc_ub_cv_kernel(srcDevice, idDevice, outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outHost, kOutSize, outDevice, kOutSize, ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", outHost, kOutSize), "./v3.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(outHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/compare.py new file mode 100644 index 000000000..fd7b17aaf --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > + (1e-3 + 1e-3 * np.abs(golden.astype(np.float32))))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(3, 6): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/golden.py new file mode 100644 index 000000000..c3e56133e --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +CLIP_MAX = np.float16(8.0) +SEED = 211 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float16) + rhs = rng.uniform(-1.5, 1.5, size=(K, N)).astype(np.float16) + lhs32 = lhs.astype(np.float32) + rhs32 = rhs.astype(np.float32) + matmul = np.zeros((M, N), dtype=np.float32) + for k_idx in range(K): + matmul += lhs32[:, k_idx:k_idx + 1] * rhs32[k_idx:k_idx + 1, :] + clip_only = np.minimum(matmul.astype(np.float16), CLIP_MAX).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + for index in range(3, 6): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + clip_only.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/kernel.pto new file mode 100644 index 000000000..a96f80e6e --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/kernel.pto @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-clip-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, pto.mte_l1_ub, +// pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: nd2nz-functional-load, cc-store-ub-gm-l1, standalone-clip, ub-to-gm +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_clip_ub_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %out_ub_clip_gm: !pto.ptr, + %out_gm_clip_gm: !pto.ptr, + %out_l1_clip_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c8_f16 = arith.constant 8.000000e+00 : f16 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_clip = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_clip = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_l1_clip = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l0c_ub %l0c, %ub_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_clip_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_clip, %ub_l1_clip, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_clip, %out_ub_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_clip, %out_l1_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/launch.cpp new file mode 100644 index 000000000..b5e4ddb2e --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_clip_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ __fp16 *out_ub_clip, + __gm__ __fp16 *out_gm_clip, __gm__ __fp16 *out_l1_clip); + +void LaunchFixpipe_clip_ub_cv_kernel(__fp16 *src, __fp16 *id, __fp16 *outUbClip, + __fp16 *outGmClip, __fp16 *outL1Clip, + void *stream) { + fixpipe_clip_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ __fp16 *)outUbClip, + (__gm__ __fp16 *)outGmClip, (__gm__ __fp16 *)outL1Clip); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/main.cpp new file mode 100644 index 000000000..ad5d60710 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/main.cpp @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_clip_ub_cv_kernel(__fp16 *src, __fp16 *id, __fp16 *outUbClip, + __fp16 *outGmClip, __fp16 *outL1Clip, + void *stream); + +int main() { + constexpr size_t kSrcElems = 50 * 64; + constexpr size_t kIdElems = 40 * 50; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSrcSize = kSrcElems * sizeof(__fp16); + constexpr size_t kIdSize = kIdElems * sizeof(__fp16); + constexpr size_t kOutSize = kOutElems * sizeof(__fp16); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + __fp16 *outUbClipHost = nullptr; + __fp16 *outGmClipHost = nullptr; + __fp16 *outL1ClipHost = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + __fp16 *outUbClipDevice = nullptr; + __fp16 *outGmClipDevice = nullptr; + __fp16 *outL1ClipDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSrcSize)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kIdSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbClipHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmClipHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ClipHost, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSrcSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kIdSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbClipDevice, kOutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmClipDevice, kOutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ClipDevice, kOutSize, + ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kIdSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kIdSize) && inputSize == kIdSize, + "./v1.bin"); + inputSize = kSrcSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSrcSize) && inputSize == kSrcSize, + "./v2.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, outUbClipHost, kOutSize) && + inputSize == kOutSize, + "./v3.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, outGmClipHost, kOutSize) && + inputSize == kOutSize, + "./v4.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, outL1ClipHost, kOutSize) && + inputSize == kOutSize, + "./v5.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcSize, srcHost, kSrcSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kIdSize, idHost, kIdSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbClipDevice, kOutSize, outUbClipHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmClipDevice, kOutSize, outGmClipHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ClipDevice, kOutSize, outL1ClipHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_clip_ub_cv_kernel(srcDevice, idDevice, outUbClipDevice, + outGmClipDevice, outL1ClipDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbClipHost, kOutSize, outUbClipDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmClipHost, kOutSize, outGmClipDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ClipHost, kOutSize, outL1ClipDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", outUbClipHost, kOutSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", outGmClipHost, kOutSize), "./v4.bin"); + FILE_CHECK(WriteFile("./v5.bin", outL1ClipHost, kOutSize), "./v5.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(outUbClipDevice); + aclrtFree(outGmClipDevice); + aclrtFree(outL1ClipDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(outUbClipHost); + aclrtFreeHost(outGmClipHost); + aclrtFreeHost(outL1ClipHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/stub.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/stub.cpp new file mode 100644 index 000000000..0dfd378f0 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-clip-ub-cv/stub.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void fixpipe_clip_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ float *out_clip) { + (void)src; + (void)id; + (void)out_clip; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/compare.py new file mode 100644 index 000000000..23bfa3cd2 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > + (1e-3 + 1e-3 * np.abs(golden.astype(np.float32))))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(4, 10): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/golden.py new file mode 100644 index 000000000..18aea50ed --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/golden.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +import struct +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +FP_QUANT_ELEMS = 64 +FP_TRANSPORT_ELEMS = FP_QUANT_ELEMS * 2 +CLIP_MAX = np.float16(8.0) +SEED = 97 + + +def extract_quant_params(quant: np.uint64) -> tuple[float, int, int]: + value = int(quant) + m1_bits = (value >> 13) & 0x7FFFF + offset = (value >> 37) & 0x1FF + sign = (value >> 46) & 0x1 + + sign_bit = (m1_bits >> 18) & 0x1 + exponent = (m1_bits >> 10) & 0xFF + mantissa = m1_bits & 0x3FF + m1 = ((-1) ** sign_bit) * (1 + mantissa / 1024.0) * (2 ** (exponent - 127)) + return m1, offset, sign + + +def qf322f16_pre(data: np.ndarray, quant: np.ndarray) -> np.ndarray: + result = np.zeros(data.shape, dtype=np.float16) + for row in range(data.shape[0]): + for col in range(data.shape[1]): + m1, _, _ = extract_quant_params(quant[col]) + scaled = data[row, col].astype(np.float32) * np.float32(m1) + result[row, col] = np.clip( + scaled, + np.finfo(np.float16).min, + np.finfo(np.float16).max, + ).astype(np.float16) + return result + + +def make_vector_quant_params(n: int) -> np.ndarray: + scales = (np.arange(n, dtype=np.float32) % np.float32(4.0)) + np.float32(1.0) + encoded = scales.astype(np.uint64) + for idx, scale in enumerate(scales): + encoded[idx] = struct.unpack("!I", struct.pack("!f", float(scale)))[0] + return np.frombuffer(encoded, np.uint64) + + +def generate(output_dir: Path, seed: int) -> None: + a = (np.arange(M * K, dtype=np.float32).reshape(M, K) * np.float32(0.01) + + np.float32(0.5)).astype(np.float16) + b = (np.arange(K * N, dtype=np.float32).reshape(K, N) * np.float32(0.005) + + np.float32(0.25)).astype(np.float16) + fp = make_vector_quant_params(FP_QUANT_ELEMS) + matmul = np.zeros((M, N), dtype=np.float32) + a32 = a.astype(np.float32) + b32 = b.astype(np.float32) + for k_idx in range(K): + matmul += a32[:, k_idx:k_idx + 1] * b32[k_idx:k_idx + 1, :] + golden_quant = qf322f16_pre(matmul, fp) + golden_clip = np.minimum(golden_quant, CLIP_MAX).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + fp.view(np.uint32).reshape(FP_TRANSPORT_ELEMS).tofile(output_dir / "v3.bin") + mapping = { + 4: golden_quant, + 5: golden_clip, + 6: golden_quant, + 7: golden_clip, + 8: golden_quant, + 9: golden_clip, + } + for index, golden in mapping.items(): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + golden.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/kernel.pto new file mode 100644 index 000000000..2927108b9 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/kernel.pto @@ -0,0 +1,187 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l1_fb, pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, +// pto.mte_l1_ub, pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: split-compile, cube-producer, vector-consumer, fixpipe-vector-qf322f16, +// standalone-clip-on-f16, fp-load, strict-matmul-golden, cc-store-ub-gm-l1 +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_quant_clip_f16_ub_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %fp_gm: !pto.ptr, + %out_ub_quant_gm: !pto.ptr, + %out_ub_clip_gm: !pto.ptr, + %out_gm_quant_gm: !pto.ptr, + %out_gm_clip_gm: !pto.ptr, + %out_l1_quant_gm: !pto.ptr, + %out_l1_clip_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c262144_i64 = arith.constant 262144 : i64 + %c8_f16 = arith.constant 8.000000e+00 : f16 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_fp_raw = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_fp = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_quant = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %l1_clip = pto.castptr %c262144_i64 : i64 -> !pto.ptr + %fb_fp = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_quant = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_clip = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_l1_quant = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_l1_clip = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %fp_gm, %l1_fp_raw, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l1_fb %l1_fp, %fb_fp, %c8_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_l0c_ub %l0c, %ub_quant, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_ub %l0c, %ub_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr, f16, i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_gm_quant_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_clip_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + !pto.ptr, f16, i64, i64, i64 + + pto.mte_l0c_l1 %l0c, %l1_quant, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + !pto.ptr, f16, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_quant, %ub_l1_quant, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_ub %l1_clip, %ub_l1_clip, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_quant, %out_ub_quant_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_clip, %out_ub_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_quant, %out_l1_quant_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_clip, %out_l1_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/launch.cpp new file mode 100644 index 000000000..de59f8c98 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_quant_clip_f16_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ uint32_t *fp, + __gm__ __fp16 *out_ub_quant, __gm__ __fp16 *out_ub_clip, + __gm__ __fp16 *out_gm_quant, __gm__ __fp16 *out_gm_clip, + __gm__ __fp16 *out_l1_quant, __gm__ __fp16 *out_l1_clip); + +void LaunchFixpipe_quant_clip_f16_ub_cv_kernel( + __fp16 *src, __fp16 *id, uint32_t *fp, __fp16 *outUbQuant, + __fp16 *outUbClip, __fp16 *outGmQuant, __fp16 *outGmClip, + __fp16 *outL1Quant, __fp16 *outL1Clip, void *stream) { + fixpipe_quant_clip_f16_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ uint32_t *)fp, + (__gm__ __fp16 *)outUbQuant, (__gm__ __fp16 *)outUbClip, + (__gm__ __fp16 *)outGmQuant, (__gm__ __fp16 *)outGmClip, + (__gm__ __fp16 *)outL1Quant, (__gm__ __fp16 *)outL1Clip); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/main.cpp new file mode 100644 index 000000000..24c8bfe3f --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/main.cpp @@ -0,0 +1,212 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_quant_clip_f16_ub_cv_kernel( + __fp16 *src, __fp16 *id, uint32_t *fp, __fp16 *outUbQuant, + __fp16 *outUbClip, __fp16 *outGmQuant, __fp16 *outGmClip, + __fp16 *outL1Quant, __fp16 *outL1Clip, void *stream); + +int main() { + constexpr size_t kSizeSrcElems = 50 * 64; + constexpr size_t kSizeIdElems = 40 * 50; + constexpr size_t kFpElems = 128; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSizeSrc = kSizeSrcElems * sizeof(__fp16); + constexpr size_t kSizeId = kSizeIdElems * sizeof(__fp16); + constexpr size_t kSizeFp = kFpElems * sizeof(uint32_t); + constexpr size_t kSizeOut = kOutElems * sizeof(__fp16); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + uint32_t *fpHost = nullptr; + __fp16 *outUbQuantHost = nullptr; + __fp16 *outUbClipHost = nullptr; + __fp16 *outGmQuantHost = nullptr; + __fp16 *outGmClipHost = nullptr; + __fp16 *outL1QuantHost = nullptr; + __fp16 *outL1ClipHost = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + uint32_t *fpDevice = nullptr; + __fp16 *outUbQuantDevice = nullptr; + __fp16 *outUbClipDevice = nullptr; + __fp16 *outGmQuantDevice = nullptr; + __fp16 *outGmClipDevice = nullptr; + __fp16 *outL1QuantDevice = nullptr; + __fp16 *outL1ClipDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSizeSrc)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kSizeId)); + ACL_CHECK(aclrtMallocHost((void **)&fpHost, kSizeFp)); + ACL_CHECK(aclrtMallocHost((void **)&outUbQuantHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outUbClipHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outGmQuantHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outGmClipHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outL1QuantHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ClipHost, kSizeOut)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSizeSrc, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kSizeId, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&fpDevice, kSizeFp, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbQuantDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmQuantDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1QuantDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kSizeId; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kSizeId) && inputSize == kSizeId, + "./v1.bin"); + inputSize = kSizeSrc; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSizeSrc) && inputSize == kSizeSrc, + "./v2.bin"); + inputSize = kSizeFp; + FILE_CHECK(ReadFile("./v3.bin", inputSize, fpHost, kSizeFp) && inputSize == kSizeFp, + "./v3.bin"); + for (int index = 4; index <= 9; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 4: hostBuf = outUbQuantHost; break; + case 5: hostBuf = outUbClipHost; break; + case 6: hostBuf = outGmQuantHost; break; + case 7: hostBuf = outGmClipHost; break; + case 8: hostBuf = outL1QuantHost; break; + case 9: hostBuf = outL1ClipHost; break; + } + inputSize = kSizeOut; + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(ReadFile(path, inputSize, hostBuf, kSizeOut) && inputSize == kSizeOut, + path); + } + + ACL_CHECK(aclrtMemcpy(srcDevice, kSizeSrc, srcHost, kSizeSrc, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kSizeId, idHost, kSizeId, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(fpDevice, kSizeFp, fpHost, kSizeFp, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbQuantDevice, kSizeOut, outUbQuantHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbClipDevice, kSizeOut, outUbClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmQuantDevice, kSizeOut, outGmQuantHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmClipDevice, kSizeOut, outGmClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1QuantDevice, kSizeOut, outL1QuantHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ClipDevice, kSizeOut, outL1ClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_quant_clip_f16_ub_cv_kernel( + srcDevice, idDevice, fpDevice, outUbQuantDevice, outUbClipDevice, + outGmQuantDevice, outGmClipDevice, outL1QuantDevice, outL1ClipDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbQuantHost, kSizeOut, outUbQuantDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outUbClipHost, kSizeOut, outUbClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmQuantHost, kSizeOut, outGmQuantDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmClipHost, kSizeOut, outGmClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1QuantHost, kSizeOut, outL1QuantDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ClipHost, kSizeOut, outL1ClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + + for (int index = 4; index <= 9; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 4: hostBuf = outUbQuantHost; break; + case 5: hostBuf = outUbClipHost; break; + case 6: hostBuf = outGmQuantHost; break; + case 7: hostBuf = outGmClipHost; break; + case 8: hostBuf = outL1QuantHost; break; + case 9: hostBuf = outL1ClipHost; break; + } + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(WriteFile(path, hostBuf, kSizeOut), path); + } + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(fpDevice); + aclrtFree(outUbQuantDevice); + aclrtFree(outUbClipDevice); + aclrtFree(outGmQuantDevice); + aclrtFree(outGmClipDevice); + aclrtFree(outL1QuantDevice); + aclrtFree(outL1ClipDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(fpHost); + aclrtFreeHost(outUbQuantHost); + aclrtFreeHost(outUbClipHost); + aclrtFreeHost(outGmQuantHost); + aclrtFreeHost(outGmClipHost); + aclrtFreeHost(outL1QuantHost); + aclrtFreeHost(outL1ClipHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/stub.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/stub.cpp new file mode 100644 index 000000000..00e51135f --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-clip-f16-ub-cv/stub.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void fixpipe_quant_clip_f16_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ uint32_t *fp, + __gm__ __fp16 *out_quant, __gm__ __fp16 *out_clip) { + (void)src; + (void)id; + (void)fp; + (void)out_quant; + (void)out_clip; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/compare.py new file mode 100644 index 000000000..2d1c4b61c --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > + (1e-3 + 1e-3 * np.abs(golden.astype(np.float32))))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(3, 9): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/golden.py new file mode 100644 index 000000000..dcf29bfea --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +CLIP_MAX = np.float16(8.0) +SEED = 317 + + +def qf322f16_normal_relu(data: np.ndarray) -> np.ndarray: + relu_pre = np.maximum(data, np.float32(0.0)) + return relu_pre.astype(np.float16) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float16) + rhs = rng.uniform(-1.5, 1.5, size=(K, N)).astype(np.float16) + matmul = np.zeros((M, N), dtype=np.float32) + lhs32 = lhs.astype(np.float32) + rhs32 = rhs.astype(np.float32) + for k_idx in range(K): + matmul += lhs32[:, k_idx:k_idx + 1] * rhs32[k_idx:k_idx + 1, :] + relu = qf322f16_normal_relu(matmul) + clip = np.minimum(relu, CLIP_MAX).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + mapping = { + 3: relu, + 4: clip, + 5: relu, + 6: clip, + 7: relu, + 8: clip, + } + for index, golden in mapping.items(): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + golden.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/kernel.pto new file mode 100644 index 000000000..982e718cd --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/kernel.pto @@ -0,0 +1,178 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, pto.mte_l1_ub, +// pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: split-compile, cube-producer, vector-consumer, fixpipe-qf322f16-scalar, +// normal-relu, clip-relu-pre, strict-matmul-golden, cc-store-ub-gm-l1 +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_quant_normal_relu_clip_f16_scalar_ub_cv_kernel( + %lhs_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %out_ub_relu_gm: !pto.ptr, + %out_ub_clip_gm: !pto.ptr, + %out_gm_relu_gm: !pto.ptr, + %out_gm_clip_gm: !pto.ptr, + %out_l1_relu_gm: !pto.ptr, + %out_l1_clip_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c8_f16 = arith.constant 8.000000e+00 : f16 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %false = arith.constant false + + %mat_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_rhs = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_relu = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_clip = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %ub_relu = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_clip = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_l1_relu = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_l1_clip = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %lhs_gm, %mat_lhs, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %rhs_gm, %mat_rhs, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_lhs, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_rhs, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l0c_ub %l0c, %ub_relu, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, i64, i64, i64 + pto.mte_l0c_ub %l0c, %ub_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_gm_relu_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_clip_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + + pto.mte_l0c_l1 %l0c, %l1_relu, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_relu, %ub_l1_relu, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_ub %l1_clip, %ub_l1_clip, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_relu, %out_ub_relu_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_clip, %out_ub_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_relu, %out_l1_relu_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_clip, %out_l1_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/launch.cpp new file mode 100644 index 000000000..b91ab7ea8 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +fixpipe_quant_normal_relu_clip_f16_scalar_ub_cv_kernel( + __gm__ __fp16 *lhs, __gm__ __fp16 *rhs, __gm__ __fp16 *out_ub_relu, + __gm__ __fp16 *out_ub_clip, __gm__ __fp16 *out_gm_relu, + __gm__ __fp16 *out_gm_clip, __gm__ __fp16 *out_l1_relu, + __gm__ __fp16 *out_l1_clip); + +void LaunchFixpipe_quant_normal_relu_clip_f16_scalar_ub_cv_kernel( + __fp16 *lhs, __fp16 *rhs, __fp16 *outUbRelu, __fp16 *outUbClip, + __fp16 *outGmRelu, __fp16 *outGmClip, __fp16 *outL1Relu, + __fp16 *outL1Clip, void *stream) { + fixpipe_quant_normal_relu_clip_f16_scalar_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)lhs, (__gm__ __fp16 *)rhs, (__gm__ __fp16 *)outUbRelu, + (__gm__ __fp16 *)outUbClip, (__gm__ __fp16 *)outGmRelu, + (__gm__ __fp16 *)outGmClip, (__gm__ __fp16 *)outL1Relu, + (__gm__ __fp16 *)outL1Clip); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/main.cpp new file mode 100644 index 000000000..0fac8cb8a --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/main.cpp @@ -0,0 +1,199 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_quant_normal_relu_clip_f16_scalar_ub_cv_kernel( + __fp16 *lhs, __fp16 *rhs, __fp16 *outUbRelu, __fp16 *outUbClip, + __fp16 *outGmRelu, __fp16 *outGmClip, __fp16 *outL1Relu, + __fp16 *outL1Clip, void *stream); + +int main() { + constexpr size_t kSizeLhsElems = 40 * 50; + constexpr size_t kSizeRhsElems = 50 * 64; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSizeLhs = kSizeLhsElems * sizeof(__fp16); + constexpr size_t kSizeRhs = kSizeRhsElems * sizeof(__fp16); + constexpr size_t kSizeOut = kOutElems * sizeof(__fp16); + + __fp16 *lhsHost = nullptr; + __fp16 *rhsHost = nullptr; + __fp16 *outUbReluHost = nullptr; + __fp16 *outUbClipHost = nullptr; + __fp16 *outGmReluHost = nullptr; + __fp16 *outGmClipHost = nullptr; + __fp16 *outL1ReluHost = nullptr; + __fp16 *outL1ClipHost = nullptr; + __fp16 *lhsDevice = nullptr; + __fp16 *rhsDevice = nullptr; + __fp16 *outUbReluDevice = nullptr; + __fp16 *outUbClipDevice = nullptr; + __fp16 *outGmReluDevice = nullptr; + __fp16 *outGmClipDevice = nullptr; + __fp16 *outL1ReluDevice = nullptr; + __fp16 *outL1ClipDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&lhsHost, kSizeLhs)); + ACL_CHECK(aclrtMallocHost((void **)&rhsHost, kSizeRhs)); + ACL_CHECK(aclrtMallocHost((void **)&outUbReluHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outUbClipHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outGmReluHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outGmClipHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ReluHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ClipHost, kSizeOut)); + ACL_CHECK(aclrtMalloc((void **)&lhsDevice, kSizeLhs, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, kSizeRhs, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbReluDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmReluDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ReluDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kSizeLhs; + FILE_CHECK(ReadFile("./v1.bin", inputSize, lhsHost, kSizeLhs) && inputSize == kSizeLhs, + "./v1.bin"); + inputSize = kSizeRhs; + FILE_CHECK(ReadFile("./v2.bin", inputSize, rhsHost, kSizeRhs) && inputSize == kSizeRhs, + "./v2.bin"); + for (int index = 3; index <= 8; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 3: hostBuf = outUbReluHost; break; + case 4: hostBuf = outUbClipHost; break; + case 5: hostBuf = outGmReluHost; break; + case 6: hostBuf = outGmClipHost; break; + case 7: hostBuf = outL1ReluHost; break; + case 8: hostBuf = outL1ClipHost; break; + } + inputSize = kSizeOut; + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(ReadFile(path, inputSize, hostBuf, kSizeOut) && inputSize == kSizeOut, + path); + } + + ACL_CHECK(aclrtMemcpy(lhsDevice, kSizeLhs, lhsHost, kSizeLhs, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, kSizeRhs, rhsHost, kSizeRhs, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbReluDevice, kSizeOut, outUbReluHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbClipDevice, kSizeOut, outUbClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmReluDevice, kSizeOut, outGmReluHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmClipDevice, kSizeOut, outGmClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ReluDevice, kSizeOut, outL1ReluHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ClipDevice, kSizeOut, outL1ClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_quant_normal_relu_clip_f16_scalar_ub_cv_kernel( + lhsDevice, rhsDevice, outUbReluDevice, outUbClipDevice, outGmReluDevice, + outGmClipDevice, outL1ReluDevice, outL1ClipDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbReluHost, kSizeOut, outUbReluDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outUbClipHost, kSizeOut, outUbClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmReluHost, kSizeOut, outGmReluDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmClipHost, kSizeOut, outGmClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ReluHost, kSizeOut, outL1ReluDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ClipHost, kSizeOut, outL1ClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + + for (int index = 3; index <= 8; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 3: hostBuf = outUbReluHost; break; + case 4: hostBuf = outUbClipHost; break; + case 5: hostBuf = outGmReluHost; break; + case 6: hostBuf = outGmClipHost; break; + case 7: hostBuf = outL1ReluHost; break; + case 8: hostBuf = outL1ClipHost; break; + } + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(WriteFile(path, hostBuf, kSizeOut), path); + } + +cleanup: + aclrtFree(lhsDevice); + aclrtFree(rhsDevice); + aclrtFree(outUbReluDevice); + aclrtFree(outUbClipDevice); + aclrtFree(outGmReluDevice); + aclrtFree(outGmClipDevice); + aclrtFree(outL1ReluDevice); + aclrtFree(outL1ClipDevice); + aclrtFreeHost(lhsHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(outUbReluHost); + aclrtFreeHost(outUbClipHost); + aclrtFreeHost(outGmReluHost); + aclrtFreeHost(outGmClipHost); + aclrtFreeHost(outL1ReluHost); + aclrtFreeHost(outL1ClipHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/stub.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/stub.cpp new file mode 100644 index 000000000..11ae7c09e --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-normal-relu-clip-f16-scalar-ub-cv/stub.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +fixpipe_quant_normal_relu_clip_f16_scalar_ub_cv_kernel( + __gm__ __fp16 *lhs, __gm__ __fp16 *rhs, __gm__ __fp16 *out_relu, + __gm__ __fp16 *out_clip) { + (void)lhs; + (void)rhs; + (void)out_relu; + (void)out_clip; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/compare.py new file mode 100644 index 000000000..2d1c4b61c --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > + (1e-3 + 1e-3 * np.abs(golden.astype(np.float32))))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(3, 9): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/golden.py new file mode 100644 index 000000000..7b49cc7a4 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +ALPHA = np.float32(0.25) +CLIP_MAX = np.float16(8.0) +SEED = 313 + + +def qf322f16_scalar_relu(data: np.ndarray) -> np.ndarray: + relu_pre = np.where(data >= np.float32(0.0), data, data * ALPHA) + return relu_pre.astype(np.float16) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float16) + rhs = rng.uniform(-1.5, 1.5, size=(K, N)).astype(np.float16) + matmul = np.zeros((M, N), dtype=np.float32) + lhs32 = lhs.astype(np.float32) + rhs32 = rhs.astype(np.float32) + for k_idx in range(K): + matmul += lhs32[:, k_idx:k_idx + 1] * rhs32[k_idx:k_idx + 1, :] + relu = qf322f16_scalar_relu(matmul) + clip = np.minimum(relu, CLIP_MAX).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + mapping = { + 3: relu, + 4: clip, + 5: relu, + 6: clip, + 7: relu, + 8: clip, + } + for index, golden in mapping.items(): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + golden.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/kernel.pto new file mode 100644 index 000000000..0f5d378c1 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/kernel.pto @@ -0,0 +1,179 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, pto.mte_l1_ub, +// pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: split-compile, cube-producer, vector-consumer, fixpipe-qf322f16-scalar, +// scalar-relu, clip-relu-pre, strict-matmul-golden, cc-store-ub-gm-l1 +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_quant_relu_clip_f16_scalar_ub_cv_kernel( + %lhs_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %out_ub_relu_gm: !pto.ptr, + %out_ub_clip_gm: !pto.ptr, + %out_gm_relu_gm: !pto.ptr, + %out_gm_clip_gm: !pto.ptr, + %out_l1_relu_gm: !pto.ptr, + %out_l1_clip_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c8_f16 = arith.constant 8.000000e+00 : f16 + %c025_f32 = arith.constant 2.500000e-01 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %false = arith.constant false + + %mat_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_rhs = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_relu = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_clip = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %ub_relu = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_clip = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_l1_relu = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_l1_clip = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %lhs_gm, %mat_lhs, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %rhs_gm, %mat_rhs, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_lhs, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_rhs, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l0c_ub %l0c, %ub_relu, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f32, i64, i64, i64 + pto.mte_l0c_ub %l0c, %ub_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f32, f16, i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_gm_relu_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, f32, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_clip_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, f32, f16, i64, i64, i64 + + pto.mte_l0c_l1 %l0c, %l1_relu, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, f32, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, f32, f16, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_relu, %ub_l1_relu, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_ub %l1_clip, %ub_l1_clip, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_relu, %out_ub_relu_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_clip, %out_ub_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_relu, %out_l1_relu_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_clip, %out_l1_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/launch.cpp new file mode 100644 index 000000000..ece02dfaf --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +fixpipe_quant_relu_clip_f16_scalar_ub_cv_kernel( + __gm__ __fp16 *lhs, __gm__ __fp16 *rhs, __gm__ __fp16 *out_ub_relu, + __gm__ __fp16 *out_ub_clip, __gm__ __fp16 *out_gm_relu, + __gm__ __fp16 *out_gm_clip, __gm__ __fp16 *out_l1_relu, + __gm__ __fp16 *out_l1_clip); + +void LaunchFixpipe_quant_relu_clip_f16_scalar_ub_cv_kernel( + __fp16 *lhs, __fp16 *rhs, __fp16 *outUbRelu, __fp16 *outUbClip, + __fp16 *outGmRelu, __fp16 *outGmClip, __fp16 *outL1Relu, + __fp16 *outL1Clip, void *stream) { + fixpipe_quant_relu_clip_f16_scalar_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)lhs, (__gm__ __fp16 *)rhs, (__gm__ __fp16 *)outUbRelu, + (__gm__ __fp16 *)outUbClip, (__gm__ __fp16 *)outGmRelu, + (__gm__ __fp16 *)outGmClip, (__gm__ __fp16 *)outL1Relu, + (__gm__ __fp16 *)outL1Clip); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/main.cpp new file mode 100644 index 000000000..0bdf0044b --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/main.cpp @@ -0,0 +1,199 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_quant_relu_clip_f16_scalar_ub_cv_kernel( + __fp16 *lhs, __fp16 *rhs, __fp16 *outUbRelu, __fp16 *outUbClip, + __fp16 *outGmRelu, __fp16 *outGmClip, __fp16 *outL1Relu, + __fp16 *outL1Clip, void *stream); + +int main() { + constexpr size_t kSizeLhsElems = 40 * 50; + constexpr size_t kSizeRhsElems = 50 * 64; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSizeLhs = kSizeLhsElems * sizeof(__fp16); + constexpr size_t kSizeRhs = kSizeRhsElems * sizeof(__fp16); + constexpr size_t kSizeOut = kOutElems * sizeof(__fp16); + + __fp16 *lhsHost = nullptr; + __fp16 *rhsHost = nullptr; + __fp16 *outUbReluHost = nullptr; + __fp16 *outUbClipHost = nullptr; + __fp16 *outGmReluHost = nullptr; + __fp16 *outGmClipHost = nullptr; + __fp16 *outL1ReluHost = nullptr; + __fp16 *outL1ClipHost = nullptr; + __fp16 *lhsDevice = nullptr; + __fp16 *rhsDevice = nullptr; + __fp16 *outUbReluDevice = nullptr; + __fp16 *outUbClipDevice = nullptr; + __fp16 *outGmReluDevice = nullptr; + __fp16 *outGmClipDevice = nullptr; + __fp16 *outL1ReluDevice = nullptr; + __fp16 *outL1ClipDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&lhsHost, kSizeLhs)); + ACL_CHECK(aclrtMallocHost((void **)&rhsHost, kSizeRhs)); + ACL_CHECK(aclrtMallocHost((void **)&outUbReluHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outUbClipHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outGmReluHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outGmClipHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ReluHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ClipHost, kSizeOut)); + ACL_CHECK(aclrtMalloc((void **)&lhsDevice, kSizeLhs, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, kSizeRhs, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbReluDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmReluDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ReluDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ClipDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kSizeLhs; + FILE_CHECK(ReadFile("./v1.bin", inputSize, lhsHost, kSizeLhs) && inputSize == kSizeLhs, + "./v1.bin"); + inputSize = kSizeRhs; + FILE_CHECK(ReadFile("./v2.bin", inputSize, rhsHost, kSizeRhs) && inputSize == kSizeRhs, + "./v2.bin"); + for (int index = 3; index <= 8; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 3: hostBuf = outUbReluHost; break; + case 4: hostBuf = outUbClipHost; break; + case 5: hostBuf = outGmReluHost; break; + case 6: hostBuf = outGmClipHost; break; + case 7: hostBuf = outL1ReluHost; break; + case 8: hostBuf = outL1ClipHost; break; + } + inputSize = kSizeOut; + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(ReadFile(path, inputSize, hostBuf, kSizeOut) && inputSize == kSizeOut, + path); + } + + ACL_CHECK(aclrtMemcpy(lhsDevice, kSizeLhs, lhsHost, kSizeLhs, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, kSizeRhs, rhsHost, kSizeRhs, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbReluDevice, kSizeOut, outUbReluHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbClipDevice, kSizeOut, outUbClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmReluDevice, kSizeOut, outGmReluHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmClipDevice, kSizeOut, outGmClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ReluDevice, kSizeOut, outL1ReluHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ClipDevice, kSizeOut, outL1ClipHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_quant_relu_clip_f16_scalar_ub_cv_kernel( + lhsDevice, rhsDevice, outUbReluDevice, outUbClipDevice, outGmReluDevice, + outGmClipDevice, outL1ReluDevice, outL1ClipDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbReluHost, kSizeOut, outUbReluDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outUbClipHost, kSizeOut, outUbClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmReluHost, kSizeOut, outGmReluDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmClipHost, kSizeOut, outGmClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ReluHost, kSizeOut, outL1ReluDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ClipHost, kSizeOut, outL1ClipDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + + for (int index = 3; index <= 8; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 3: hostBuf = outUbReluHost; break; + case 4: hostBuf = outUbClipHost; break; + case 5: hostBuf = outGmReluHost; break; + case 6: hostBuf = outGmClipHost; break; + case 7: hostBuf = outL1ReluHost; break; + case 8: hostBuf = outL1ClipHost; break; + } + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(WriteFile(path, hostBuf, kSizeOut), path); + } + +cleanup: + aclrtFree(lhsDevice); + aclrtFree(rhsDevice); + aclrtFree(outUbReluDevice); + aclrtFree(outUbClipDevice); + aclrtFree(outGmReluDevice); + aclrtFree(outGmClipDevice); + aclrtFree(outL1ReluDevice); + aclrtFree(outL1ClipDevice); + aclrtFreeHost(lhsHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(outUbReluHost); + aclrtFreeHost(outUbClipHost); + aclrtFreeHost(outGmReluHost); + aclrtFreeHost(outGmClipHost); + aclrtFreeHost(outL1ReluHost); + aclrtFreeHost(outL1ClipHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/stub.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/stub.cpp new file mode 100644 index 000000000..7373a9d05 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-clip-f16-scalar-ub-cv/stub.cpp @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +fixpipe_quant_relu_clip_f16_scalar_ub_cv_kernel(__gm__ __fp16 *lhs, + __gm__ __fp16 *rhs, + __gm__ __fp16 *out_relu, + __gm__ __fp16 *out_clip) { + (void)lhs; + (void)rhs; + (void)out_relu; + (void)out_clip; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/compare.py new file mode 100644 index 000000000..4964050dc --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] missing file: {golden_path} or {output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > + (1e-3 + 1e-3 * np.abs(golden.astype(np.float32))))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(3, 6): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/golden.py new file mode 100644 index 000000000..d2dde2b83 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +ALPHA = np.float32(0.25) +SEED = 607 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float16) + rhs = rng.uniform(-1.5, 1.5, size=(K, N)).astype(np.float16) + lhs32 = lhs.astype(np.float32) + rhs32 = rhs.astype(np.float32) + matmul = np.zeros((M, N), dtype=np.float32) + for k_idx in range(K): + matmul += lhs32[:, k_idx:k_idx + 1] * rhs32[k_idx:k_idx + 1, :] + relu_pre = np.where(matmul >= np.float32(0.0), matmul, matmul * ALPHA) + golden = relu_pre.astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + for index in range(3, 6): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + golden.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/kernel.pto new file mode 100644 index 000000000..bb7754b8d --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/kernel.pto @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, pto.mte_l1_ub, +// pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: split-compile, cube-producer, vector-consumer, +// fixpipe-qf322f16-scalar-float-payload, scalar-relu-float-payload, +// f32-f16-bf16-payload-ab, cc-store-ub-gm-l1 +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_quant_relu_float_payload_f16_ub_cv_kernel( + %lhs_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %out_ub_gm: !pto.ptr, + %out_gm_gm: !pto.ptr, + %out_l1_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %c025_f32 = arith.constant 2.500000e-01 : f32 + %c1_f16 = arith.constant 1.000000e+00 : f16 + %c025_f16 = arith.constant 2.500000e-01 : f16 + %c1_bf16 = arith.constant 1.000000e+00 : bf16 + %c025_bf16 = arith.constant 2.500000e-01 : bf16 + %false = arith.constant false + + %mat_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_rhs = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_out = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_l1_out = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %lhs_gm, %mat_lhs, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %rhs_gm, %mat_rhs, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_lhs, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_rhs, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + + pto.mte_l0c_ub %l0c, %ub_out, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f32, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f16, mode = qf322f16_pre_scalar), + pre_relu(%c025_f16, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f16, f16, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_out, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_bf16, mode = qf322f16_pre_scalar), + pre_relu(%c025_bf16, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + bf16, bf16, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_out, %ub_l1_out, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_out, %out_ub_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_out, %out_l1_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/launch.cpp new file mode 100644 index 000000000..75cffe932 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +fixpipe_quant_relu_float_payload_f16_ub_cv_kernel( + __gm__ __fp16 *lhs, __gm__ __fp16 *rhs, __gm__ __fp16 *out_ub, + __gm__ __fp16 *out_gm, __gm__ __fp16 *out_l1); + +void LaunchFixpipe_quant_relu_float_payload_f16_ub_cv_kernel( + __fp16 *lhs, __fp16 *rhs, __fp16 *outUb, __fp16 *outGm, __fp16 *outL1, + void *stream) { + fixpipe_quant_relu_float_payload_f16_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)lhs, (__gm__ __fp16 *)rhs, (__gm__ __fp16 *)outUb, + (__gm__ __fp16 *)outGm, (__gm__ __fp16 *)outL1); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/main.cpp new file mode 100644 index 000000000..ade4ae9cd --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/main.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_quant_relu_float_payload_f16_ub_cv_kernel( + __fp16 *lhs, __fp16 *rhs, __fp16 *outUb, __fp16 *outGm, __fp16 *outL1, + void *stream); + +int main() { + constexpr size_t kLhsElems = 40 * 50; + constexpr size_t kRhsElems = 50 * 64; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kLhsSize = kLhsElems * sizeof(__fp16); + constexpr size_t kRhsSize = kRhsElems * sizeof(__fp16); + constexpr size_t kOutSize = kOutElems * sizeof(__fp16); + + __fp16 *lhsHost = nullptr; + __fp16 *rhsHost = nullptr; + __fp16 *outUbHost = nullptr; + __fp16 *outGmHost = nullptr; + __fp16 *outL1Host = nullptr; + __fp16 *lhsDevice = nullptr; + __fp16 *rhsDevice = nullptr; + __fp16 *outUbDevice = nullptr; + __fp16 *outGmDevice = nullptr; + __fp16 *outL1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&lhsHost, kLhsSize)); + ACL_CHECK(aclrtMallocHost((void **)&rhsHost, kRhsSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1Host, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&lhsDevice, kLhsSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, kRhsSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1Device, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kLhsSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, lhsHost, kLhsSize) && inputSize == kLhsSize, + "./v1.bin"); + inputSize = kRhsSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, rhsHost, kRhsSize) && inputSize == kRhsSize, + "./v2.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, outUbHost, kOutSize) && inputSize == kOutSize, + "./v3.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v4.bin", inputSize, outGmHost, kOutSize) && inputSize == kOutSize, + "./v4.bin"); + inputSize = kOutSize; + FILE_CHECK(ReadFile("./v5.bin", inputSize, outL1Host, kOutSize) && inputSize == kOutSize, + "./v5.bin"); + + ACL_CHECK(aclrtMemcpy(lhsDevice, kLhsSize, lhsHost, kLhsSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, kRhsSize, rhsHost, kRhsSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbDevice, kOutSize, outUbHost, kOutSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmDevice, kOutSize, outGmHost, kOutSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1Device, kOutSize, outL1Host, kOutSize, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_quant_relu_float_payload_f16_ub_cv_kernel( + lhsDevice, rhsDevice, outUbDevice, outGmDevice, outL1Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbHost, kOutSize, outUbDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmHost, kOutSize, outGmDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1Host, kOutSize, outL1Device, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v3.bin", outUbHost, kOutSize), "./v3.bin"); + FILE_CHECK(WriteFile("./v4.bin", outGmHost, kOutSize), "./v4.bin"); + FILE_CHECK(WriteFile("./v5.bin", outL1Host, kOutSize), "./v5.bin"); + +cleanup: + aclrtFree(lhsDevice); + aclrtFree(rhsDevice); + aclrtFree(outUbDevice); + aclrtFree(outGmDevice); + aclrtFree(outL1Device); + aclrtFreeHost(lhsHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(outUbHost); + aclrtFreeHost(outGmHost); + aclrtFreeHost(outL1Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/stub.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/stub.cpp new file mode 100644 index 000000000..5b01f5fe2 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-relu-float-payload-f16-ub-cv/stub.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void +fixpipe_quant_relu_float_payload_f16_ub_cv_kernel(__gm__ __fp16 *lhs, + __gm__ __fp16 *rhs, + __gm__ __fp16 *out_ub, + __gm__ __fp16 *out_gm, + __gm__ __fp16 *out_l1) { + (void)lhs; + (void)rhs; + (void)out_ub; + (void)out_gm; + (void)out_l1; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/compare.py new file mode 100644 index 000000000..ab35382d7 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > 1e-3)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(4, 7): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/golden.py new file mode 100644 index 000000000..2afd4cce1 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/golden.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +import struct +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +FP_QUANT_ELEMS = 64 +FP_TRANSPORT_ELEMS = FP_QUANT_ELEMS * 2 +SEED = 97 + + +def extract_quant_params(quant: np.uint64) -> tuple[float, int, int]: + value = int(quant) + m1_bits = (value >> 13) & 0x7FFFF + offset = (value >> 37) & 0x1FF + sign = (value >> 46) & 0x1 + + sign_bit = (m1_bits >> 18) & 0x1 + exponent = (m1_bits >> 10) & 0xFF + mantissa = m1_bits & 0x3FF + m1 = ((-1) ** sign_bit) * (1 + mantissa / 1024.0) * (2 ** (exponent - 127)) + return m1, offset, sign + + +def qf322f16_pre(data: np.ndarray, quant: np.ndarray) -> np.ndarray: + result = np.zeros(data.shape, dtype=np.float16) + for row in range(data.shape[0]): + for col in range(data.shape[1]): + m1, _, _ = extract_quant_params(quant[col]) + scaled = data[row, col].astype(np.float32) * np.float32(m1) + result[row, col] = np.clip( + scaled, + np.finfo(np.float16).min, + np.finfo(np.float16).max, + ).astype(np.float16) + return result + + +def make_vector_quant_params(n: int) -> np.ndarray: + scales = (np.arange(n, dtype=np.float32) % np.float32(4.0)) + np.float32(1.0) + encoded = scales.astype(np.uint64) + for idx, scale in enumerate(scales): + encoded[idx] = struct.unpack("!I", struct.pack("!f", float(scale)))[0] + return np.frombuffer(encoded, np.uint64) + + +def generate(output_dir: Path, seed: int) -> None: + a = (np.arange(M * K, dtype=np.float32).reshape(M, K) * np.float32(0.01) + + np.float32(0.5)).astype(np.float16) + b = (np.arange(K * N, dtype=np.float32).reshape(K, N) * np.float32(0.005) + + np.float32(0.25)).astype(np.float16) + fp = make_vector_quant_params(FP_QUANT_ELEMS) + matmul = np.zeros((M, N), dtype=np.float32) + a32 = a.astype(np.float32) + b32 = b.astype(np.float32) + for k_idx in range(K): + matmul += a32[:, k_idx:k_idx + 1] * b32[k_idx:k_idx + 1, :] + golden = qf322f16_pre(matmul, fp) + + output_dir.mkdir(parents=True, exist_ok=True) + a.reshape(-1).tofile(output_dir / "v1.bin") + b.reshape(-1).tofile(output_dir / "v2.bin") + fp.view(np.uint32).reshape(FP_TRANSPORT_ELEMS).tofile(output_dir / "v3.bin") + for index in range(4, 7): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + golden.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/kernel.pto new file mode 100644 index 000000000..9e84a497a --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/kernel.pto @@ -0,0 +1,146 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-quant-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l1_fb, pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, +// pto.mte_l1_ub, pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: split-compile, cube-producer, vector-consumer, fixpipe-vector-qf322f16, +// fp-load, strict-matmul-golden, cc-store-ub-gm-l1 +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_quant_ub_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %fp_gm: !pto.ptr, + %out_ub_gm: !pto.ptr, + %out_gm_gm: !pto.ptr, + %out_l1_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_fp_raw = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_fp = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_out = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %fb_fp = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_l1_out = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %fp_gm, %l1_fp_raw, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l1_fb %l1_fp, %fb_fp, %c8_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_l0c_ub %l0c, %ub_out, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_out, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%fb_fp, mode = qf322f16_pre_vec), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_out, %ub_l1_out, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_out, %out_ub_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_out, %out_l1_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/launch.cpp new file mode 100644 index 000000000..1a87a2bc8 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_quant_ub_cv_kernel( + __gm__ __fp16 *a, __gm__ __fp16 *b, __gm__ uint32_t *fp, + __gm__ __fp16 *out_ub, __gm__ __fp16 *out_gm, __gm__ __fp16 *out_l1); + +void LaunchFixpipe_quant_ub_cv_kernel(__fp16 *a, __fp16 *b, uint32_t *fp, + __fp16 *outUb, __fp16 *outGm, + __fp16 *outL1, void *stream) { + fixpipe_quant_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)a, (__gm__ __fp16 *)b, (__gm__ uint32_t *)fp, + (__gm__ __fp16 *)outUb, (__gm__ __fp16 *)outGm, (__gm__ __fp16 *)outL1); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/main.cpp new file mode 100644 index 000000000..16d1f5881 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-quant-ub-cv/main.cpp @@ -0,0 +1,163 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_quant_ub_cv_kernel(__fp16 *src, __fp16 *id, uint32_t *fp, + __fp16 *outUb, __fp16 *outGm, + __fp16 *outL1, void *stream); + +int main() { + constexpr size_t kSizeSrcElems = 50 * 64; + constexpr size_t kSizeIdElems = 40 * 50; + constexpr size_t kFpElems = 128; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSizeSrc = kSizeSrcElems * sizeof(__fp16); + constexpr size_t kSizeId = kSizeIdElems * sizeof(__fp16); + constexpr size_t kSizeFp = kFpElems * sizeof(uint32_t); + constexpr size_t kSizeOut = kOutElems * sizeof(__fp16); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + uint32_t *fpHost = nullptr; + __fp16 *outUbHost = nullptr; + __fp16 *outGmHost = nullptr; + __fp16 *outL1Host = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + uint32_t *fpDevice = nullptr; + __fp16 *outUbDevice = nullptr; + __fp16 *outGmDevice = nullptr; + __fp16 *outL1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSizeSrc)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kSizeId)); + ACL_CHECK(aclrtMallocHost((void **)&fpHost, kSizeFp)); + ACL_CHECK(aclrtMallocHost((void **)&outUbHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outGmHost, kSizeOut)); + ACL_CHECK(aclrtMallocHost((void **)&outL1Host, kSizeOut)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSizeSrc, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kSizeId, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&fpDevice, kSizeFp, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmDevice, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1Device, kSizeOut, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kSizeId; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kSizeId) && inputSize == kSizeId, + "./v1.bin"); + inputSize = kSizeSrc; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSizeSrc) && inputSize == kSizeSrc, + "./v2.bin"); + inputSize = kSizeFp; + FILE_CHECK(ReadFile("./v3.bin", inputSize, fpHost, kSizeFp) && inputSize == kSizeFp, + "./v3.bin"); + inputSize = kSizeOut; + FILE_CHECK(ReadFile("./v4.bin", inputSize, outUbHost, kSizeOut) && + inputSize == kSizeOut, + "./v4.bin"); + inputSize = kSizeOut; + FILE_CHECK(ReadFile("./v5.bin", inputSize, outGmHost, kSizeOut) && + inputSize == kSizeOut, + "./v5.bin"); + inputSize = kSizeOut; + FILE_CHECK(ReadFile("./v6.bin", inputSize, outL1Host, kSizeOut) && + inputSize == kSizeOut, + "./v6.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, kSizeSrc, srcHost, kSizeSrc, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kSizeId, idHost, kSizeId, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(fpDevice, kSizeFp, fpHost, kSizeFp, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbDevice, kSizeOut, outUbHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmDevice, kSizeOut, outGmHost, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1Device, kSizeOut, outL1Host, kSizeOut, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_quant_ub_cv_kernel(srcDevice, idDevice, fpDevice, outUbDevice, + outGmDevice, outL1Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbHost, kSizeOut, outUbDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmHost, kSizeOut, outGmDevice, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1Host, kSizeOut, outL1Device, kSizeOut, + ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v4.bin", outUbHost, kSizeOut), "./v4.bin"); + FILE_CHECK(WriteFile("./v5.bin", outGmHost, kSizeOut), "./v5.bin"); + FILE_CHECK(WriteFile("./v6.bin", outL1Host, kSizeOut), "./v6.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(fpDevice); + aclrtFree(outUbDevice); + aclrtFree(outGmDevice); + aclrtFree(outL1Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(fpHost); + aclrtFreeHost(outUbHost); + aclrtFreeHost(outGmHost); + aclrtFreeHost(outL1Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/compare.py new file mode 100644 index 000000000..fc921610e --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] missing file: {golden_path} or {output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > + (1e-3 + 1e-3 * np.abs(golden.astype(np.float32))))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(4, 10): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/golden.py new file mode 100644 index 000000000..a35614428 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/golden.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +import struct +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +FP_RELU_ELEMS = 128 +ALPHA = np.float32(0.25) +VECTOR_ALPHA_PATTERN = np.array([0.125, 0.25, 0.5, 0.75], dtype=np.float32) +SEED = 521 + + +def encode_scale(scale: float) -> np.uint32: + return np.uint32(struct.unpack("!I", struct.pack("!f", scale))[0]) + + +def scalar_relu(data: np.ndarray) -> np.ndarray: + return np.where(data >= np.float32(0.0), data, data * ALPHA) + + +def make_vector_alphas() -> np.ndarray: + return np.resize(VECTOR_ALPHA_PATTERN, N).astype(np.float32) + + +def make_vector_relu_params(vector_alphas: np.ndarray) -> np.ndarray: + payload = np.resize(vector_alphas, FP_RELU_ELEMS).astype(np.float32) + return np.array([encode_scale(float(alpha)) for alpha in payload], dtype=np.uint32) + + +def vector_relu(data: np.ndarray, vector_alphas: np.ndarray) -> np.ndarray: + return np.where( + data >= np.float32(0.0), + data, + data * vector_alphas.reshape(1, N), + ) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float16) + rhs = rng.uniform(-1.5, 1.5, size=(K, N)).astype(np.float16) + vector_alphas = make_vector_alphas() + relu_fp = make_vector_relu_params(vector_alphas) + + lhs32 = lhs.astype(np.float32) + rhs32 = rhs.astype(np.float32) + matmul = np.zeros((M, N), dtype=np.float32) + for k_idx in range(K): + matmul += lhs32[:, k_idx:k_idx + 1] * rhs32[k_idx:k_idx + 1, :] + + scalar_golden = scalar_relu(matmul).astype(np.float16) + vector_golden = vector_relu(matmul, vector_alphas).astype(np.float16) + if np.array_equal(scalar_golden, vector_golden): + raise AssertionError("vector relu golden must differ from scalar relu golden") + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + relu_fp.reshape(-1).tofile(output_dir / "v3.bin") + for index in range(4, 10): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + golden = scalar_golden if index in (4, 6, 8) else vector_golden + golden.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/kernel.pto new file mode 100644 index 000000000..f2780c422 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/kernel.pto @@ -0,0 +1,192 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_gm_l1, pto.mte_l1_l0a, pto.mte_l1_l0b, +// pto.mad, pto.mte_l1_fb, pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, +// pto.mte_l1_ub, pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: split-compile, cube-producer, vector-consumer, fixpipe-qf322f16-scalar, +// scalar-relu, vector-relu, strict-matmul-golden, cc-store-ub-gm-l1 +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_relu_scalar_vector_f16_cv_kernel( + %lhs_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %relu_fp_gm: !pto.ptr, + %out_ub_scalar_gm: !pto.ptr, + %out_ub_vector_gm: !pto.ptr, + %out_gm_scalar_gm: !pto.ptr, + %out_gm_vector_gm: !pto.ptr, + %out_l1_scalar_gm: !pto.ptr, + %out_l1_vector_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c65600_i64 = arith.constant 65600 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c262144_i64 = arith.constant 262144 : i64 + %c025_f32 = arith.constant 2.500000e-01 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %false = arith.constant false + + %mat_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_rhs = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_fp_raw = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_scalar = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %l1_vector = pto.castptr %c262144_i64 : i64 -> !pto.ptr + %fb_relu = pto.castptr %c65600_i64 : i64 -> !pto.ptr + %ub_scalar = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_vector = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_l1_scalar = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_l1_vector = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %lhs_gm, %mat_lhs, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %rhs_gm, %mat_rhs, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1 %relu_fp_gm, %l1_fp_raw, %c512_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_lhs, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_rhs, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l1_fb %l1_fp_raw, %fb_relu, %c8_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.mte_l0c_ub %l0c, %ub_scalar, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f32, i64, i64, i64 + pto.mte_l0c_ub %l0c, %ub_vector, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%fb_relu, mode = vector_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, !pto.ptr, i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_gm_scalar_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, f32, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_vector_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%fb_relu, mode = vector_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, !pto.ptr, i64, i64, i64 + + pto.mte_l0c_l1 %l0c, %l1_scalar, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%c025_f32, mode = scalar_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, f32, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_vector, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(%fb_relu, mode = vector_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_scalar, %ub_l1_scalar, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_ub %l1_vector, %ub_l1_vector, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_scalar, %out_ub_scalar_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_vector, %out_ub_vector_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_scalar, %out_l1_scalar_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_vector, %out_l1_vector_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/launch.cpp new file mode 100644 index 000000000..f72fda638 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_relu_scalar_vector_f16_cv_kernel( + __gm__ __fp16 *lhs, __gm__ __fp16 *rhs, __gm__ uint32_t *relu_fp, + __gm__ __fp16 *out_ub_scalar, __gm__ __fp16 *out_ub_vector, + __gm__ __fp16 *out_gm_scalar, __gm__ __fp16 *out_gm_vector, + __gm__ __fp16 *out_l1_scalar, __gm__ __fp16 *out_l1_vector); + +void LaunchFixpipe_relu_scalar_vector_f16_cv_kernel( + __fp16 *lhs, __fp16 *rhs, uint32_t *reluFp, __fp16 *outUbScalar, + __fp16 *outUbVector, __fp16 *outGmScalar, __fp16 *outGmVector, + __fp16 *outL1Scalar, __fp16 *outL1Vector, void *stream) { + fixpipe_relu_scalar_vector_f16_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)lhs, (__gm__ __fp16 *)rhs, (__gm__ uint32_t *)reluFp, + (__gm__ __fp16 *)outUbScalar, (__gm__ __fp16 *)outUbVector, + (__gm__ __fp16 *)outGmScalar, (__gm__ __fp16 *)outGmVector, + (__gm__ __fp16 *)outL1Scalar, (__gm__ __fp16 *)outL1Vector); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/main.cpp new file mode 100644 index 000000000..cd53565a4 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/main.cpp @@ -0,0 +1,215 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_relu_scalar_vector_f16_cv_kernel( + __fp16 *lhs, __fp16 *rhs, uint32_t *reluFp, __fp16 *outUbScalar, + __fp16 *outUbVector, __fp16 *outGmScalar, __fp16 *outGmVector, + __fp16 *outL1Scalar, __fp16 *outL1Vector, void *stream); + +int main() { + constexpr size_t kLhsElems = 40 * 50; + constexpr size_t kRhsElems = 50 * 64; + constexpr size_t kReluFpElems = 64 * 2; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kLhsSize = kLhsElems * sizeof(__fp16); + constexpr size_t kRhsSize = kRhsElems * sizeof(__fp16); + constexpr size_t kReluFpSize = kReluFpElems * sizeof(uint32_t); + constexpr size_t kOutSize = kOutElems * sizeof(__fp16); + + __fp16 *lhsHost = nullptr; + __fp16 *rhsHost = nullptr; + uint32_t *reluFpHost = nullptr; + __fp16 *outUbScalarHost = nullptr; + __fp16 *outUbVectorHost = nullptr; + __fp16 *outGmScalarHost = nullptr; + __fp16 *outGmVectorHost = nullptr; + __fp16 *outL1ScalarHost = nullptr; + __fp16 *outL1VectorHost = nullptr; + __fp16 *lhsDevice = nullptr; + __fp16 *rhsDevice = nullptr; + uint32_t *reluFpDevice = nullptr; + __fp16 *outUbScalarDevice = nullptr; + __fp16 *outUbVectorDevice = nullptr; + __fp16 *outGmScalarDevice = nullptr; + __fp16 *outGmVectorDevice = nullptr; + __fp16 *outL1ScalarDevice = nullptr; + __fp16 *outL1VectorDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&lhsHost, kLhsSize)); + ACL_CHECK(aclrtMallocHost((void **)&rhsHost, kRhsSize)); + ACL_CHECK(aclrtMallocHost((void **)&reluFpHost, kReluFpSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbScalarHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbVectorHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmScalarHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmVectorHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ScalarHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1VectorHost, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&lhsDevice, kLhsSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, kRhsSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&reluFpDevice, kReluFpSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbScalarDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbVectorDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmScalarDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmVectorDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ScalarDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1VectorDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kLhsSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, lhsHost, kLhsSize) && inputSize == kLhsSize, + "./v1.bin"); + inputSize = kRhsSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, rhsHost, kRhsSize) && inputSize == kRhsSize, + "./v2.bin"); + inputSize = kReluFpSize; + FILE_CHECK(ReadFile("./v3.bin", inputSize, reluFpHost, kReluFpSize) && + inputSize == kReluFpSize, + "./v3.bin"); + + for (int index = 4; index <= 9; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 4: hostBuf = outUbScalarHost; break; + case 5: hostBuf = outUbVectorHost; break; + case 6: hostBuf = outGmScalarHost; break; + case 7: hostBuf = outGmVectorHost; break; + case 8: hostBuf = outL1ScalarHost; break; + case 9: hostBuf = outL1VectorHost; break; + } + inputSize = kOutSize; + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(ReadFile(path, inputSize, hostBuf, kOutSize) && inputSize == kOutSize, + path); + } + + ACL_CHECK(aclrtMemcpy(lhsDevice, kLhsSize, lhsHost, kLhsSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, kRhsSize, rhsHost, kRhsSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(reluFpDevice, kReluFpSize, reluFpHost, kReluFpSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbScalarDevice, kOutSize, outUbScalarHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbVectorDevice, kOutSize, outUbVectorHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmScalarDevice, kOutSize, outGmScalarHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmVectorDevice, kOutSize, outGmVectorHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ScalarDevice, kOutSize, outL1ScalarHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1VectorDevice, kOutSize, outL1VectorHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_relu_scalar_vector_f16_cv_kernel( + lhsDevice, rhsDevice, reluFpDevice, outUbScalarDevice, outUbVectorDevice, + outGmScalarDevice, outGmVectorDevice, outL1ScalarDevice, outL1VectorDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbScalarHost, kOutSize, outUbScalarDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outUbVectorHost, kOutSize, outUbVectorDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmScalarHost, kOutSize, outGmScalarDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmVectorHost, kOutSize, outGmVectorDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ScalarHost, kOutSize, outL1ScalarDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1VectorHost, kOutSize, outL1VectorDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + for (int index = 4; index <= 9; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 4: hostBuf = outUbScalarHost; break; + case 5: hostBuf = outUbVectorHost; break; + case 6: hostBuf = outGmScalarHost; break; + case 7: hostBuf = outGmVectorHost; break; + case 8: hostBuf = outL1ScalarHost; break; + case 9: hostBuf = outL1VectorHost; break; + } + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(WriteFile(path, hostBuf, kOutSize), path); + } + +cleanup: + aclrtFree(lhsDevice); + aclrtFree(rhsDevice); + aclrtFree(reluFpDevice); + aclrtFree(outUbScalarDevice); + aclrtFree(outUbVectorDevice); + aclrtFree(outGmScalarDevice); + aclrtFree(outGmVectorDevice); + aclrtFree(outL1ScalarDevice); + aclrtFree(outL1VectorDevice); + aclrtFreeHost(lhsHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(reluFpHost); + aclrtFreeHost(outUbScalarHost); + aclrtFreeHost(outUbVectorHost); + aclrtFreeHost(outGmScalarHost); + aclrtFreeHost(outGmVectorHost); + aclrtFreeHost(outL1ScalarHost); + aclrtFreeHost(outL1VectorHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/stub.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/stub.cpp new file mode 100644 index 000000000..5eb333377 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-scalar-vector-f16-cv/stub.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void fixpipe_relu_scalar_vector_f16_cv_kernel( + __gm__ __fp16 *lhs, __gm__ __fp16 *rhs, __gm__ uint32_t *relu_fp, + __gm__ __fp16 *out_ub_scalar, __gm__ __fp16 *out_ub_vector, + __gm__ __fp16 *out_gm_scalar, __gm__ __fp16 *out_gm_vector, + __gm__ __fp16 *out_l1_scalar, __gm__ __fp16 *out_l1_vector) { + (void)lhs; + (void)rhs; + (void)relu_fp; + (void)out_ub_scalar; + (void)out_ub_vector; + (void)out_gm_scalar; + (void)out_gm_vector; + (void)out_l1_scalar; + (void)out_l1_vector; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/compare.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/compare.py new file mode 100644 index 000000000..2d1c4b61c --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/compare.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.float16) + output = np.fromfile(output_path, dtype=np.float16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.allclose(golden, output, atol=1e-3, rtol=1e-3): + return True + diff = np.where(np.abs(golden.astype(np.float32) - output.astype(np.float32)) > + (1e-3 + 1e-3 * np.abs(golden.astype(np.float32))))[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] first mismatch at idx={idx}: " + f"golden={float(golden[idx])}, out={float(output[idx])}" + ) + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + for index in range(3, 9): + ok = compare_bin(f"golden_v{index}.bin", f"v{index}.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/golden.py b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/golden.py new file mode 100644 index 000000000..71ec8e961 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +M = 40 +N = 64 +K = 50 +CLIP_MAX = np.float16(8.0) +SEED = 211 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.uniform(-2.0, 2.0, size=(M, K)).astype(np.float16) + rhs = rng.uniform(-1.5, 1.5, size=(K, N)).astype(np.float16) + lhs32 = lhs.astype(np.float32) + rhs32 = rhs.astype(np.float32) + matmul = np.zeros((M, N), dtype=np.float32) + for k_idx in range(K): + matmul += lhs32[:, k_idx:k_idx + 1] * rhs32[k_idx:k_idx + 1, :] + relu = np.maximum(matmul, np.float32(0.0)).astype(np.float16) + clipped = np.minimum(matmul.astype(np.float16), CLIP_MAX).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.reshape(-1).tofile(output_dir / "v1.bin") + rhs.reshape(-1).tofile(output_dir / "v2.bin") + mapping = { + 3: relu, + 4: clipped, + 5: relu, + 6: clipped, + 7: relu, + 8: clipped, + } + for index, golden in mapping.items(): + np.zeros((M, N), dtype=np.float16).reshape(-1).tofile(output_dir / f"v{index}.bin") + golden.reshape(-1).tofile(output_dir / f"golden_v{index}.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/kernel.pto b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/kernel.pto new file mode 100644 index 000000000..7a8f3102c --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/kernel.pto @@ -0,0 +1,176 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/cv-mixed/fixpipe-relu-ub-cv +// family: cv-mixed +// target_ops: pto.mte_gm_l1_frac, pto.mte_l1_l0a, pto.mte_l1_l0b, pto.mad, +// pto.mte_l0c_ub, pto.mte_l0c_gm, pto.mte_l0c_l1, pto.mte_l1_ub, +// pto.mte_ub_gm, pto.sync.set, pto.sync.wait +// scenarios: nd2nz-functional-load, cc-store-ub-gm-l1, normal-relu, standalone-clip, ub-to-gm +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5"} { + func.func @fixpipe_relu_ub_cv_kernel(%src_gm: !pto.ptr, + %id_gm: !pto.ptr, + %out_ub_relu_gm: !pto.ptr, + %out_ub_clip_gm: !pto.ptr, + %out_gm_relu_gm: !pto.ptr, + %out_gm_clip_gm: !pto.ptr, + %out_l1_relu_gm: !pto.ptr, + %out_l1_clip_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c40_i64 = arith.constant 40 : i64 + %c48_i64 = arith.constant 48 : i64 + %c50_i64 = arith.constant 50 : i64 + %c64_i64 = arith.constant 64 : i64 + %c100_i64 = arith.constant 100 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2560_i64 = arith.constant 2560 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c8_f16 = arith.constant 8.000000e+00 : f16 + %c1_f32 = arith.constant 1.000000e+00 : f32 + %false = arith.constant false + + %mat_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %mat_id = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %l1_relu = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %l1_clip = pto.castptr %c196608_i64 : i64 -> !pto.ptr + %ub_relu = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_clip = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_l1_relu = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_l1_clip = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %l0a = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0b = pto.castptr %c0_i64 : i64 -> !pto.ptr + %l0c = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.section.cube { + pto.mte_gm_l1_frac %id_gm, %mat_src, nd2nz, + shape(%c40_i64, %c50_i64), + src_layout(%c100_i64), + dst_group(%c1_i64, %c1_i64, %c48_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + pto.mte_gm_l1_frac %src_gm, %mat_id, nd2nz, + shape(%c50_i64, %c64_i64), + src_layout(%c128_i64), + dst_group(%c1_i64, %c1_i64, %c64_i64, %c0_i64), + ctrl(%c0_i64, %false) + : !pto.ptr, !pto.ptr, nd2nz, + shape i64, i64, src_layout(i64), + dst_group i64, i64, i64, i64, ctrl i64, i1 + + pto.set_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_l0a %mat_src, %l0a, %c40_i64, %c50_i64 + : !pto.ptr, !pto.ptr, i64, i64 + pto.mte_l1_l0b %mat_id, %l0b, %c50_i64, %c64_i64 {transpose = true} + : !pto.ptr, !pto.ptr, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_M", "EVENT_ID0"] + pto.mad %l0a, %l0b, %l0c, %c40_i64, %c64_i64, %c50_i64 + : !pto.ptr, !pto.ptr, !pto.ptr, i64, i64, i64 + + pto.set_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.wait_flag["PIPE_M", "PIPE_FIX", "EVENT_ID1"] + pto.set_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_FIX", "EVENT_ID0"] + pto.mte_l0c_ub %l0c, %ub_relu, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, i64, i64, i64 + pto.mte_l0c_ub %l0c, %ub_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, dst_mode(%c0_i64), + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + + pto.mte_l0c_gm %l0c, %out_gm_relu_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, i64, i64, i64 + pto.mte_l0c_gm %l0c, %out_gm_clip_gm, %c40_i64, %c64_i64, %c48_i64, %c64_i64, %c0_i64, %c0_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + + pto.mte_l0c_l1 %l0c, %l1_relu, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = normal_relu), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, i64, i64, i64 + pto.mte_l0c_l1 %l0c, %l1_clip, %c40_i64, %c64_i64, %c48_i64, %c64_i64, + pre_quant(%c1_f32, mode = qf322f16_pre_scalar), + pre_relu(mode = no_relu, clip = %c8_f16), + nz2nd, + loop3(%c1_i64, %c49152_i64, %c2560_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + f32, f16, i64, i64, i64 + + pto.set_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.wait_flag["PIPE_FIX", "PIPE_MTE1", "EVENT_ID0"] + pto.mte_l1_ub %l1_relu, %ub_l1_relu, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_l1_ub %l1_clip, %ub_l1_clip, %c128_i64 + nburst(%c40_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE1", "PIPE_FIX", "EVENT_ID0"] + pto.sync.set , 1 + pto.sync.set , 17 + } + + pto.section.vector { + %subblock = pto.get_subblock_idx + %is_subblock0 = arith.cmpi eq, %subblock, %c0_i64 : i64 + + scf.if %is_subblock0 { + pto.sync.wait , 1 + pto.mte_ub_gm %ub_relu, %out_ub_relu_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_clip, %out_ub_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_relu, %out_l1_relu_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_l1_clip, %out_l1_clip_gm, %c128_i64 + nburst(%c40_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + } + return + } +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/launch.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/launch.cpp new file mode 100644 index 000000000..fbac3b731 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/launch.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif + +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void fixpipe_relu_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ __fp16 *out_ub_relu, + __gm__ __fp16 *out_ub_clip, __gm__ __fp16 *out_gm_relu, + __gm__ __fp16 *out_gm_clip, __gm__ __fp16 *out_l1_relu, + __gm__ __fp16 *out_l1_clip); + +void LaunchFixpipe_relu_ub_cv_kernel(__fp16 *src, __fp16 *id, __fp16 *outUbRelu, + __fp16 *outUbClip, __fp16 *outGmRelu, + __fp16 *outGmClip, __fp16 *outL1Relu, + __fp16 *outL1Clip, void *stream) { + fixpipe_relu_ub_cv_kernel<<<1, nullptr, stream>>>( + (__gm__ __fp16 *)src, (__gm__ __fp16 *)id, (__gm__ __fp16 *)outUbRelu, + (__gm__ __fp16 *)outUbClip, (__gm__ __fp16 *)outGmRelu, + (__gm__ __fp16 *)outGmClip, (__gm__ __fp16 *)outL1Relu, + (__gm__ __fp16 *)outL1Clip); +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/main.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/main.cpp new file mode 100644 index 000000000..2028dbd9b --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/main.cpp @@ -0,0 +1,200 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchFixpipe_relu_ub_cv_kernel(__fp16 *src, __fp16 *id, __fp16 *outUbRelu, + __fp16 *outUbClip, __fp16 *outGmRelu, + __fp16 *outGmClip, __fp16 *outL1Relu, + __fp16 *outL1Clip, void *stream); + +int main() { + constexpr size_t kSrcElems = 50 * 64; + constexpr size_t kIdElems = 40 * 50; + constexpr size_t kOutElems = 40 * 64; + constexpr size_t kSrcSize = kSrcElems * sizeof(__fp16); + constexpr size_t kIdSize = kIdElems * sizeof(__fp16); + constexpr size_t kOutSize = kOutElems * sizeof(__fp16); + + __fp16 *srcHost = nullptr; + __fp16 *idHost = nullptr; + __fp16 *outUbReluHost = nullptr; + __fp16 *outUbClipHost = nullptr; + __fp16 *outGmReluHost = nullptr; + __fp16 *outGmClipHost = nullptr; + __fp16 *outL1ReluHost = nullptr; + __fp16 *outL1ClipHost = nullptr; + __fp16 *srcDevice = nullptr; + __fp16 *idDevice = nullptr; + __fp16 *outUbReluDevice = nullptr; + __fp16 *outUbClipDevice = nullptr; + __fp16 *outGmReluDevice = nullptr; + __fp16 *outGmClipDevice = nullptr; + __fp16 *outL1ReluDevice = nullptr; + __fp16 *outL1ClipDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, kSrcSize)); + ACL_CHECK(aclrtMallocHost((void **)&idHost, kIdSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbReluHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outUbClipHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmReluHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outGmClipHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ReluHost, kOutSize)); + ACL_CHECK(aclrtMallocHost((void **)&outL1ClipHost, kOutSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, kSrcSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&idDevice, kIdSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbReluDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outUbClipDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmReluDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outGmClipDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ReluDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outL1ClipDevice, kOutSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = kIdSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, idHost, kIdSize) && inputSize == kIdSize, + "./v1.bin"); + inputSize = kSrcSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, srcHost, kSrcSize) && inputSize == kSrcSize, + "./v2.bin"); + for (int index = 3; index <= 8; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 3: hostBuf = outUbReluHost; break; + case 4: hostBuf = outUbClipHost; break; + case 5: hostBuf = outGmReluHost; break; + case 6: hostBuf = outGmClipHost; break; + case 7: hostBuf = outL1ReluHost; break; + case 8: hostBuf = outL1ClipHost; break; + } + inputSize = kOutSize; + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(ReadFile(path, inputSize, hostBuf, kOutSize) && inputSize == kOutSize, + path); + } + + ACL_CHECK(aclrtMemcpy(srcDevice, kSrcSize, srcHost, kSrcSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idDevice, kIdSize, idHost, kIdSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbReluDevice, kOutSize, outUbReluHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outUbClipDevice, kOutSize, outUbClipHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmReluDevice, kOutSize, outGmReluHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outGmClipDevice, kOutSize, outGmClipHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ReluDevice, kOutSize, outL1ReluHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outL1ClipDevice, kOutSize, outL1ClipHost, kOutSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchFixpipe_relu_ub_cv_kernel(srcDevice, idDevice, outUbReluDevice, + outUbClipDevice, outGmReluDevice, + outGmClipDevice, outL1ReluDevice, + outL1ClipDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(outUbReluHost, kOutSize, outUbReluDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outUbClipHost, kOutSize, outUbClipDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmReluHost, kOutSize, outGmReluDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outGmClipHost, kOutSize, outGmClipDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ReluHost, kOutSize, outL1ReluDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outL1ClipHost, kOutSize, outL1ClipDevice, kOutSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + for (int index = 3; index <= 8; ++index) { + __fp16 *hostBuf = nullptr; + switch (index) { + case 3: hostBuf = outUbReluHost; break; + case 4: hostBuf = outUbClipHost; break; + case 5: hostBuf = outGmReluHost; break; + case 6: hostBuf = outGmClipHost; break; + case 7: hostBuf = outL1ReluHost; break; + case 8: hostBuf = outL1ClipHost; break; + } + char path[16]; + std::snprintf(path, sizeof(path), "./v%d.bin", index); + FILE_CHECK(WriteFile(path, hostBuf, kOutSize), path); + } + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idDevice); + aclrtFree(outUbReluDevice); + aclrtFree(outUbClipDevice); + aclrtFree(outGmReluDevice); + aclrtFree(outGmClipDevice); + aclrtFree(outL1ReluDevice); + aclrtFree(outL1ClipDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(idHost); + aclrtFreeHost(outUbReluHost); + aclrtFreeHost(outUbClipHost); + aclrtFreeHost(outGmReluHost); + aclrtFreeHost(outGmClipHost); + aclrtFreeHost(outL1ReluHost); + aclrtFreeHost(outL1ClipHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/stub.cpp b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/stub.cpp new file mode 100644 index 000000000..fc690e569 --- /dev/null +++ b/test/vpto/cases/micro-op/cv-mixed/fixpipe-relu-ub-cv/stub.cpp @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include + +#ifndef __global__ +#define __global__ +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +extern "C" __global__ [aicore] void fixpipe_relu_ub_cv_kernel( + __gm__ __fp16 *src, __gm__ __fp16 *id, __gm__ float *out_relu, + __gm__ float *out_clip) { + (void)src; + (void)id; + (void)out_relu; + (void)out_clip; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/compare.py new file mode 100755 index 000000000..ae42e6822 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vaxpy-f32 +# family: dsa-sfu +# target_ops: pto.vaxpy +# scenarios: core-f32, scalar-operand, fused-op +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/golden.py new file mode 100755 index 000000000..e99a6f22f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vaxpy-f32 +# family: dsa-sfu +# target_ops: pto.vaxpy +# scenarios: core-f32, scalar-operand, fused-op +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = (ALPHA * v1 + v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto new file mode 100644 index 000000000..53cbb2f11 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vaxpy-f32 +// family: dsa-sfu +// target_ops: pto.vaxpy +// scenarios: core-f32, scalar-operand, fused-op +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaxpy_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %alpha = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_addend = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_addend, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %addend = pto.vlds %ub_addend[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vaxpy %vec, %addend, %alpha, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/launch.cpp new file mode 100644 index 000000000..00c358ce2 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vaxpy-f32 +// family: dsa-sfu +// target_ops: pto.vaxpy +// scenarios: core-f32, scalar-operand, fused-op +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vaxpy_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVaxpy_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vaxpy_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp new file mode 100644 index 000000000..62c80a5fb --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vaxpy-f32/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vaxpy-f32 +// family: dsa-sfu +// target_ops: pto.vaxpy +// scenarios: core-f32, scalar-operand, fused-op +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaxpy_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaxpy_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/compare.py new file mode 100644 index 000000000..efadc7cd0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vbitsort +# family: dsa-sfu +# target_ops: pto.vbitsort +# scenarios: index-generation, layout-transform + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/golden.py new file mode 100644 index 000000000..7a15acfe0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vbitsort +# family: dsa-sfu +# target_ops: pto.vbitsort +# scenarios: index-generation, layout-transform + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +PROPOSALS = 32 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + scores = np.array([ + 3.5, -2.0, 7.0, 7.0, 1.5, 4.25, 0.0, 9.5, + -8.0, 9.5, 2.0, 2.0, 6.0, 6.0, -1.0, 5.75, + 5.75, 4.25, 8.0, 8.0, 3.0, -4.5, 1.25, 1.25, + 10.0, 10.0, -3.0, 0.5, 12.0, 12.0, -7.0, 6.5, + ], dtype=np.float32) + indices = np.array([ + 100, 203, 77, 88, 12, 45, 501, 9, + 333, 7, 900, 901, 31, 32, 400, 62, + 63, 46, 73, 74, 15, 16, 120, 121, + 5, 6, 700, 701, 1, 2, 808, 90, + ], dtype=np.uint32) + + order = np.argsort(-scores, kind="stable") + sorted_scores = scores[order] + sorted_indices = indices[order] + + packed = np.empty(PROPOSALS * 2, dtype=np.uint32) + packed[0::2] = sorted_scores.view(np.uint32) + packed[1::2] = sorted_indices + + output_dir.mkdir(parents=True, exist_ok=True) + scores.tofile(output_dir / "v1.bin") + indices.tofile(output_dir / "v2.bin") + np.zeros(PROPOSALS * 2, dtype=np.uint32).tofile(output_dir / "v3.bin") + packed.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto new file mode 100644 index 000000000..4aa050245 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/kernel.pto @@ -0,0 +1,41 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vbitsort +// family: dsa-sfu +// target_ops: pto.vbitsort +// scenarios: index-generation, layout-transform +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vbitsort_kernel_f32(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1 = arith.constant 1 : index + %false = arith.constant false + + %ub_scores = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_indices = pto.castptr %c128_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c256_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_scores, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_indices, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vbitsort %ub_out, %ub_scores, %ub_indices, %c1 : !pto.ptr, !pto.ptr, !pto.ptr, index + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/launch.cpp new file mode 100644 index 000000000..767eefd22 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbitsort_kernel_f32(__gm__ float *scores, + __gm__ uint32_t *indices, + __gm__ uint32_t *output); + +void LaunchVbitsort_kernel_f32(float *scores, uint32_t *indices, uint32_t *output, + void *stream) { + vbitsort_kernel_f32<<<1, nullptr, stream>>>((__gm__ float *)scores, + (__gm__ uint32_t *)indices, + (__gm__ uint32_t *)output); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vbitsort/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/main.cpp new file mode 100644 index 000000000..5d9f5a1b2 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vbitsort/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vbitsort +// family: dsa-sfu +// target_ops: pto.vbitsort +// scenarios: index-generation, layout-transform +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbitsort_kernel_f32(float *scores, uint32_t *indices, uint32_t *output, + void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 32; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVbitsort_kernel_f32(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/compare.py new file mode 100755 index 000000000..8c2628b88 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 0.001) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/golden.py new file mode 100755 index 000000000..c19fcdb99 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 1 +COLS = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + v1 = np.zeros((ROWS, COLS), dtype=np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.arange(ROWS * COLS, dtype=np.float16).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto new file mode 100644 index 000000000..52410408f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/kernel.pto @@ -0,0 +1,26 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vci_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 0.0 : f16 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %indices = pto.vci %cst {order = "ASC"} : f16 -> !pto.vreg<128xf16> + pto.vsts %indices, %ub_out[%c0], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c2_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/launch.cpp new file mode 100644 index 000000000..8647dab79 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVci_kernel_2d(aclFloat16 *v1, aclFloat16 *v2, void *stream) { + vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-f16/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/main.cpp new file mode 100644 index 000000000..b628b0747 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-f16/main.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVci_kernel_2d(aclFloat16 *v1, aclFloat16 *v2, void *stream); + +int main() { + size_t elemCount_v1 = 128; + size_t fileSize_v1 = elemCount_v1 * sizeof(aclFloat16); + size_t elemCount_v2 = 128; + size_t fileSize_v2 = elemCount_v2 * sizeof(aclFloat16); + aclFloat16 *v1Host = nullptr; + aclFloat16 *v1Device = nullptr; + aclFloat16 *v2Host = nullptr; + aclFloat16 *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVci_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/compare.py new file mode 100755 index 000000000..326fa7450 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int8, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/golden.py new file mode 100755 index 000000000..b3482d94d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + v1 = np.zeros((ROWS, COLS), dtype=np.int8) + v2 = np.zeros((ROWS, COLS), dtype=np.int8) + golden_v2 = np.arange(ROWS * COLS, dtype=np.int32).astype(np.int8).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto new file mode 100644 index 000000000..5782e7d51 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/kernel.pto @@ -0,0 +1,32 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vci_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c8_i64 = arith.constant 8 : i64 + %c128_i64 = arith.constant 128 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %base = arith.index_castui %offset : index to i8 + %indices = pto.vci %base {order = "ASC"} : i8 -> !pto.vreg<256xsi8> + pto.vsts %indices, %ub_out[%offset], %mask : !pto.vreg<256xsi8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c8_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/launch.cpp new file mode 100644 index 000000000..0b3f33084 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int8_t *v1, + __gm__ int8_t *v2); + +void LaunchVci_kernel_2d(int8_t *v1, int8_t *v2, void *stream) { + vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1, + (__gm__ int8_t *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci-si8/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/main.cpp new file mode 100644 index 000000000..204d6efa9 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci-si8/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVci_kernel_2d(int8_t *v1, int8_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int8_t); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + int8_t *v2Host = nullptr; + int8_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVci_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vci/compare.py new file mode 100755 index 000000000..4a2c212c6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int32, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vci/golden.py new file mode 100755 index 000000000..f044e1819 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vci +# family: dsa-sfu / conversion +# target_ops: pto.vci +# scenarios: index-generation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + _ = seed + v1 = np.zeros((ROWS, COLS), dtype=np.int32) + v2 = np.zeros((ROWS, COLS), dtype=np.int32) + golden_v2 = np.arange(ROWS * COLS, dtype=np.int32).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto new file mode 100644 index 000000000..3c4c15889 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/kernel.pto @@ -0,0 +1,32 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vci_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %base = arith.index_castui %offset : index to i32 + %indices = pto.vci %base {order = "ASC"} : i32 -> !pto.vreg<64xsi32> + pto.vsts %indices, %ub_out[%offset], %mask : !pto.vreg<64xsi32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp new file mode 100644 index 000000000..0ce203973 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vci_kernel_2d(__gm__ int32_t *v1, + __gm__ int32_t *v2); + +void LaunchVci_kernel_2d(int32_t *v1, int32_t *v2, void *stream) { + vci_kernel_2d<<<1, nullptr, stream>>>((__gm__ int32_t *)v1, + (__gm__ int32_t *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp new file mode 100644 index 000000000..0baf928bd --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vci/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVci_kernel_2d(int32_t *v1, int32_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int32_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + int32_t *v2Host = nullptr; + int32_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVci_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py new file mode 100755 index 000000000..5353e5df9 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdif-boundary +# family: dsa-sfu +# target_ops: pto.vexpdif +# scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py new file mode 100755 index 000000000..b4b417320 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdif-boundary +# family: dsa-sfu +# target_ops: pto.vexpdif +# scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + del seed + src_pattern = np.array( + [ + 0.0, 88.0, -120.0, np.nan, np.inf, -np.inf, 1.0, -1.0, + 90.0, -90.0, 50.0, -50.0, 3.0, -3.0, 10.0, -10.0, + ], + dtype=np.float32, + ) + max_pattern = np.array( + [ + 0.0, 0.0, 0.0, 1.0, np.inf, -np.inf, -1.0, 1.0, + 0.0, 0.0, 100.0, -100.0, 3.0, -3.0, 20.0, -20.0, + ], + dtype=np.float32, + ) + flat_src = np.resize(src_pattern, ROWS * COLS).astype(np.float32, copy=False) + flat_max = np.resize(max_pattern, ROWS * COLS).astype(np.float32, copy=False) + v1 = flat_src.reshape(ROWS, COLS) + v2 = flat_max.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.exp(flat_src - flat_max).astype(np.float32, copy=False).reshape(ROWS, COLS) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto new file mode 100644 index 000000000..ab3ed968a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-boundary +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexpdif_boundary_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_max = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_max, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vexpdif %vec, %max, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp new file mode 100644 index 000000000..e2f5057e6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-boundary +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexpdif_boundary_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVexpdiff_boundary_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vexpdif_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp new file mode 100644 index 000000000..3f29604cb --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-boundary/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-boundary +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f32, fused-expdiff, exceptional-values, floating-overflow-underflow +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexpdiff_boundary_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexpdiff_boundary_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py new file mode 100644 index 000000000..8ca6af6cf --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdif-f16-part +# family: dsa-sfu +# target_ops: pto.vexpdif +# scenarios: core-f16, fused-expdiff, part-even-odd + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py new file mode 100644 index 000000000..1d493c5bb --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdif-f16-part +# family: dsa-sfu +# target_ops: pto.vexpdif +# scenarios: core-f16, fused-expdiff, part-even-odd + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 31 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float16) + v2 = rng.uniform(-2.0, 2.0, size=(ROWS, COLS)).astype(np.float16) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + flat1 = v1.reshape(-1) + flat2 = v2.reshape(-1) + golden = np.empty((ROWS * COLS,), dtype=np.float32) + for base in range(0, ROWS * COLS, 128): + chunk1 = flat1[base : base + 128].astype(np.float32) + chunk2 = flat2[base : base + 128].astype(np.float32) + golden[base : base + 64] = np.exp(chunk1[0::2] - chunk2[0::2]).astype( + np.float32 + ) + golden[base + 64 : base + 128] = np.exp( + chunk1[1::2] - chunk2[1::2] + ).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + flat1.tofile(output_dir / "v1.bin") + flat2.tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto new file mode 100644 index 000000000..363956318 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/kernel.pto @@ -0,0 +1,64 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-f16-part +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f16, fused-expdiff, part-even-odd +// NOTE: validates that ODD/EVEN selects odd/even lanes from f16 inputs. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexpdif_f16_part_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_max = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_max, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %input = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %max = pto.vlds %ub_max[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %even_mask, %remaining_after_even = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %odd_mask, %next_remaining = pto.plt_b32 %remaining_after_even : i32 -> !pto.mask, i32 + %even = pto.vexpdif %input, %max, %full_mask, "EVEN" : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vexpdif %input, %max, %full_mask, "ODD" : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd_offset = arith.addi %offset, %c64 : index + pto.vsts %even, %ub_out[%offset], %even_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %odd, %ub_out[%odd_offset], %odd_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp new file mode 100644 index 000000000..78f8bef63 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-f16-part +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f16, fused-expdiff, part-even-odd +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexpdif_f16_part_kernel_2d(__gm__ half *v1, + __gm__ half *v2, + __gm__ float *v3); + +void LaunchVexpdiff_f16_part_kernel_2d(uint16_t *v1, uint16_t *v2, float *v3, + void *stream) { + vexpdif_f16_part_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp new file mode 100644 index 000000000..58b1f6c5d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f16-part/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-f16-part +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f16, fused-expdiff, part-even-odd +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexpdiff_f16_part_kernel_2d(uint16_t *v1, uint16_t *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexpdiff_f16_part_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py new file mode 100755 index 000000000..8575e7aa5 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdif-f32 +# family: dsa-sfu +# target_ops: pto.vexpdif +# scenarios: core-f32, fused-expdiff +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py new file mode 100755 index 000000000..874d5b6e5 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vexpdif-f32 +# family: dsa-sfu +# target_ops: pto.vexpdif +# scenarios: core-f32, fused-expdiff +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.ones((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto new file mode 100644 index 000000000..d3913956c --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-f32 +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f32, fused-expdiff +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexpdif_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vexpdif %vec, %vec, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp new file mode 100644 index 000000000..00ada867d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-f32 +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f32, fused-expdiff +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexpdif_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexpdiff_kernel_2d(float *v1, float *v2, void *stream) { + vexpdif_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp new file mode 100644 index 000000000..4afacde3a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vexpdiff-f32/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vexpdif-f32 +// family: dsa-sfu +// target_ops: pto.vexpdif +// scenarios: core-f32, fused-expdiff +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexpdiff_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexpdiff_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/compare.py new file mode 100755 index 000000000..4717fd3e8 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vlrelu-f16 +# family: dsa-sfu +# target_ops: pto.vlrelu +# scenarios: core-f16, full-mask, scalar-operand +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/golden.py new file mode 100755 index 000000000..bc7c328b9 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vlrelu-f16 +# family: dsa-sfu +# target_ops: pto.vlrelu +# scenarios: core-f16, full-mask, scalar-operand +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float16(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.where(v1 >= 0.0, v1, v1 * ALPHA).astype(np.float16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto new file mode 100644 index 000000000..249f947b0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vlrelu-f16 +// family: dsa-sfu +// target_ops: pto.vlrelu +// scenarios: core-f16, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %cst = arith.constant 1.250000e-01 : f16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<128xf16>, f16, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/launch.cpp new file mode 100644 index 000000000..da89bb6f0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vlrelu-f16 +// family: dsa-sfu +// target_ops: pto.vlrelu +// scenarios: core-f16, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/main.cpp new file mode 100644 index 000000000..73e868d99 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f16/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vlrelu-f16 +// family: dsa-sfu +// target_ops: pto.vlrelu +// scenarios: core-f16, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/golden.py new file mode 100644 index 000000000..938b69b9d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -8.0, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.where(v1 >= 0.0, v1, v1 * ALPHA).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto new file mode 100644 index 000000000..95005cf5a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/launch.cpp new file mode 100644 index 000000000..44c07c249 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/main.cpp new file mode 100644 index 000000000..fcb42331f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32-exceptional/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/golden.py new file mode 100644 index 000000000..dd0899be8 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.where(v1 >= 0.0, v1, v1 * ALPHA).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto new file mode 100644 index 000000000..95005cf5a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/launch.cpp new file mode 100644 index 000000000..44c07c249 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/main.cpp new file mode 100644 index 000000000..fcb42331f --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-f32/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/golden.py new file mode 100644 index 000000000..2544a92ff --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +ALPHA = np.float32(0.125) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + flat = v1.reshape(-1) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.where( + flat[:LOGICAL_ELEMS] >= 0.0, flat[:LOGICAL_ELEMS], flat[:LOGICAL_ELEMS] * ALPHA + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto new file mode 100644 index 000000000..abecf2979 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 1.250000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vlrelu %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/launch.cpp new file mode 100644 index 000000000..b4cd46470 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vadds_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vlrelu-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/compare.py new file mode 100644 index 000000000..c0f54ec75 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/compare.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import struct +import sys + +import numpy as np + + +PAIR_FMT = "fI" +PAIR_SIZE = struct.calcsize(PAIR_FMT) +PAIR_COUNT = 4 + + +def read_pairs(path: str): + values = [] + indices = [] + with open(path, "rb") as f: + for _ in range(PAIR_COUNT): + data = f.read(PAIR_SIZE) + if len(data) != PAIR_SIZE: + break + value, index = struct.unpack(PAIR_FMT, data) + values.append(value) + indices.append(index) + return np.array(values, dtype=np.float32), np.array(indices, dtype=np.uint32) + + +def read_counts(path: str): + with open(path, "rb") as f: + data = f.read(8) + return np.array(struct.unpack("4h", data), dtype=np.int16) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden_values, golden_indices = read_pairs("golden_v2.bin") + output_values, output_indices = read_pairs("v2.bin") + golden_counts = read_counts("golden_v3.bin") + output_counts = read_counts("v3.bin") + ok = ( + golden_values.shape == output_values.shape + and golden_indices.shape == output_indices.shape + and np.allclose(golden_values, output_values) + and np.array_equal(golden_indices, output_indices) + and np.array_equal(golden_counts, output_counts) + ) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/golden.py new file mode 100644 index 000000000..cd3da3ea0 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +import struct +from pathlib import Path + + +PAIR_FMT = "fI" + + +def write_pairs(path: Path, pairs) -> None: + with path.open("wb") as f: + for score, index in pairs: + f.write(struct.pack(PAIR_FMT, score, index)) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + + out = args.output_dir + out.mkdir(parents=True, exist_ok=True) + + src = [(9.0, 90), (7.0, 70), (8.0, 80), (6.0, 60)] + golden = [(9.0, 90), (0.0, 0), (0.0, 0), (0.0, 0)] + + write_pairs(out / "v1.bin", src) + write_pairs(out / "v2.bin", [(0.0, 0)] * 4) + write_pairs(out / "golden_v2.bin", golden) + (out / "v3.bin").write_bytes(struct.pack("4h", 0, 0, 0, 0)) + (out / "golden_v3.bin").write_bytes(struct.pack("4h", 1, 0, 0, 0)) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/kernel.pto new file mode 100644 index 000000000..35a58675c --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/kernel.pto @@ -0,0 +1,51 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmrgsort4_kernel_f32(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c6 = arith.constant 6 : index + %c32_i64 = arith.constant 32 : i64 + %c_count = arith.constant 281479271743489 : i64 + %c_config = arith.constant 7937 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c32_i64 : i64 -> !pto.ptr + %src1 = pto.addptr %ub_in, %c2 : !pto.ptr -> !pto.ptr + %src2 = pto.addptr %ub_in, %c4 : !pto.ptr -> !pto.ptr + %src3 = pto.addptr %ub_in, %c6 : !pto.ptr -> !pto.ptr + + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vmrgsort4 %ub_out, %ub_in, %src1, %src2, %src3, %c_count, %c_config + : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, + !pto.ptr, i64, i64 + + pto.set_flag["PIPE_V", "PIPE_S", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_S", "EVENT_ID0"] + %list0, %list1, %list2, %list3 = pto.get_vms4_sr : i16, i16, i16, i16 + pto.store_scalar %list0, %arg2[%c0] : !pto.ptr, i16 + pto.store_scalar %list1, %arg2[%c1] : !pto.ptr, i16 + pto.store_scalar %list2, %arg2[%c2] : !pto.ptr, i16 + pto.store_scalar %list3, %arg2[%c3] : !pto.ptr, i16 + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.mte_ub_gm %ub_out, %arg1, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/launch.cpp new file mode 100644 index 000000000..7b3d6f877 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/launch.cpp @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#include + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmrgsort4_kernel_f32(__gm__ float *src, + __gm__ float *dst, + __gm__ int16_t *counts); + +void LaunchVmrgsort4_kernel_f32(float *src, float *dst, int16_t *counts, + void *stream) { + vmrgsort4_kernel_f32<<<1, nullptr, stream>>>((__gm__ float *)src, + (__gm__ float *)dst, + (__gm__ int16_t *)counts); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/main.cpp new file mode 100644 index 000000000..7f5a94a97 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmrgsort4/main.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmrgsort4_kernel_f32(float *src, float *dst, int16_t *counts, + void *stream); + +int main() { + size_t inputBytes = 32; + size_t outputBytes = 32; + size_t countsBytes = 8; + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int16_t *countsHost = nullptr; + int16_t *countsDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), inputBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), outputBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&countsHost), countsBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, inputBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, outputBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&countsDevice, countsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + if (!ReadFile("./v1.bin", inputBytes, srcHost, inputBytes)) { + std::fprintf(stderr, "[ERROR] failed to read v1.bin\n"); + rc = 1; + goto cleanup; + } + if (!ReadFile("./v2.bin", outputBytes, dstHost, outputBytes)) { + std::fprintf(stderr, "[ERROR] failed to read v2.bin\n"); + rc = 1; + goto cleanup; + } + if (!ReadFile("./v3.bin", countsBytes, countsHost, countsBytes)) { + std::fprintf(stderr, "[ERROR] failed to read v3.bin\n"); + rc = 1; + goto cleanup; + } + + ACL_CHECK(aclrtMemcpy(srcDevice, inputBytes, srcHost, inputBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, outputBytes, dstHost, outputBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(countsDevice, countsBytes, countsHost, countsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmrgsort4_kernel_f32(srcDevice, dstDevice, countsDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, outputBytes, dstDevice, outputBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(countsHost, countsBytes, countsDevice, countsBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, outputBytes); + WriteFile("./v3.bin", countsHost, countsBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFree(countsDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + aclrtFreeHost(countsHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/compare.py new file mode 100755 index 000000000..e7e8af91d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula-accumulator-boundary +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + +ACTIVE_ELEMS = 65 + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.size == count and output.size == count and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, ACTIVE_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/golden.py new file mode 100755 index 000000000..6c0d8c252 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula-accumulator-boundary +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator, boundary +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + np.abs(v1) * np.abs(v1)).astype(np.float32, copy=False) + golden_v2.reshape(-1)[65:] = 0.0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto new file mode 100644 index 000000000..9b1ba1b38 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula-accumulator-boundary +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c65_i32 = arith.constant 65 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c65_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %acc = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %lhs = pto.vabs %acc, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vabs %lhs, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %sum = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/launch.cpp new file mode 100644 index 000000000..8dcf35197 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula-accumulator-boundary +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/main.cpp new file mode 100644 index 000000000..c9b1f36c3 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula-accumulator-boundary/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula-accumulator-boundary +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmula/compare.py new file mode 100755 index 000000000..cfc4e190a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmula/golden.py new file mode 100755 index 000000000..2110c7144 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmula +# family: dsa-sfu +# target_ops: pto.vmula +# scenarios: core-f32, fused-op, accumulator +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + np.abs(v1) * np.abs(v1)).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto new file mode 100644 index 000000000..41708a815 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %acc = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %lhs = pto.vabs %acc, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vabs %lhs, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %sum = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula/launch.cpp new file mode 100644 index 000000000..fa6ae4bb5 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmula/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmula/main.cpp new file mode 100644 index 000000000..54508a12d --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmula/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmula +// family: dsa-sfu +// target_ops: pto.vmula +// scenarios: core-f32, fused-op, accumulator +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vmull/compare.py new file mode 100755 index 000000000..e2e0b0eef --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmull +# family: dsa-sfu +# target_ops: pto.vmull +# scenarios: widening-op, hi-lo-split +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int32, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vmull/golden.py new file mode 100755 index 000000000..96b892823 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vmull +# family: dsa-sfu +# target_ops: pto.vmull +# scenarios: widening-op, hi-lo-split +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.integers(-10000, 10000, size=(ROWS // 2, COLS), dtype=np.int32) + rhs = rng.integers(-10000, 10000, size=(ROWS // 2, COLS), dtype=np.int32) + v1 = np.concatenate([lhs, rhs], axis=0).astype(np.int32, copy=False) + prod = lhs.astype(np.int64) * rhs.astype(np.int64) + low = (prod & np.int64(0xFFFFFFFF)).astype(np.uint32).view(np.int32) + high = ((prod >> np.int64(32)) & np.int64(0xFFFFFFFF)).astype(np.uint32).view(np.int32) + golden_v2 = np.concatenate([low, high], axis=0).astype(np.int32, copy=False) + v2 = np.zeros((ROWS, COLS), dtype=np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto new file mode 100644 index 000000000..3565dc57a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmull +// family: dsa-sfu +// target_ops: pto.vmull +// scenarios: widening-op, hi-lo-split +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmull_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c512_i32 = arith.constant 512 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %gm_in = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %gm_in, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c64 iter_args(%remaining = %c512_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %rhs_offset = arith.addi %offset, %c512 : index + %lhs = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %rhs = pto.vlds %ub_in[%rhs_offset] : !pto.ptr -> !pto.vreg<64xi32> + %low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32>, !pto.vreg<64xi32> + pto.vsts %low, %ub_out[%offset], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + pto.vsts %high, %ub_out[%rhs_offset], %mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmull/launch.cpp new file mode 100644 index 000000000..1b7ce84c6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmull +// family: dsa-sfu +// target_ops: pto.vmull +// scenarios: widening-op, hi-lo-split +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmull_kernel_2d(__gm__ int *v1, + __gm__ int *v2); + +void LaunchVmull_kernel_2d(int *v1, int *v2, void *stream) { + vmull_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, + (__gm__ int *)v2); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vmull/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vmull/main.cpp new file mode 100644 index 000000000..d853eda16 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vmull/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vmull +// family: dsa-sfu +// target_ops: pto.vmull +// scenarios: widening-op, hi-lo-split +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmull_kernel_2d(int *v1, int *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + int *v1Host = nullptr; + int *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmull_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/compare.py new file mode 100755 index 000000000..35caa0aa6 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-f32 +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/golden.py new file mode 100755 index 000000000..e1fd7c683 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-f32 +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(0.05, 0.5, size=(ROWS, COLS)).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.where(v1 >= 0.0, v1, v1 * v2).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto new file mode 100644 index 000000000..cb1b92910 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-f32 +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vprelu_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_alpha = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_alpha, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %alpha = pto.vlds %ub_alpha[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vprelu %vec, %alpha, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/launch.cpp new file mode 100644 index 000000000..d6002ce63 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-f32 +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vprelu_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVprelu_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vprelu_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/main.cpp new file mode 100644 index 000000000..6a2738912 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-f32/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-f32 +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVprelu_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVprelu_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/compare.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/compare.py new file mode 100755 index 000000000..bbc6ab65a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-tail +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha, tail-mask +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + +ACTIVE_ELEMS = 1000 + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.size == count and output.size == count and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v3.bin", "v3.bin", np.float32, 1e-4, ACTIVE_ELEMS) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/golden.py b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/golden.py new file mode 100755 index 000000000..a9e101569 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/dsa-sfu/vprelu-tail +# family: dsa-sfu +# target_ops: pto.vprelu +# scenarios: core-f32, vector-alpha, tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(0.05, 0.5, size=(ROWS, COLS)).astype(np.float32) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v3 = np.where(v1 >= 0.0, v1, v1 * v2).astype(np.float32, copy=False) + golden_v3.reshape(-1)[1000:] = 0.0 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + golden_v3.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto new file mode 100644 index 000000000..8bb0c56ac --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/kernel.pto @@ -0,0 +1,58 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-tail +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vprelu_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_alpha = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_alpha, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %alpha = pto.vlds %ub_alpha[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vprelu %vec, %alpha, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/launch.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/launch.cpp new file mode 100644 index 000000000..b4a675c6a --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-tail +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vprelu_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ float *v3); + +void LaunchVprelu_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream) { + vprelu_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/main.cpp b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/main.cpp new file mode 100644 index 000000000..27b55f701 --- /dev/null +++ b/test/vpto/cases/micro-op/dsa-sfu/vprelu-tail/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/dsa-sfu/vprelu-tail +// family: dsa-sfu +// target_ops: pto.vprelu +// scenarios: core-f32, vector-alpha, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVprelu_tail_kernel_2d(float *v1, float *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVprelu_tail_kernel_2d(v1Device, v2Device, v3Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/compare.py new file mode 100755 index 000000000..c932750d2 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2-duplicate-index +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py new file mode 100755 index 000000000..4a5c343b6 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2-duplicate-index +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + pair_ids = ((np.arange((ROWS * COLS) // 2, dtype=np.int32) * 29) + 5) % (ROWS * COLS) + offsets = np.repeat(pair_ids, 2) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + for base in range(0, ROWS * COLS, 64): + lanes = np.arange(base + 8, base + 64, dtype=np.int32) + gathered[lanes] = flat[offsets[lanes]] + gathered = gathered.reshape(ROWS, COLS) + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2 duplicate-index validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto new file mode 100644 index 000000000..41bb9d841 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2-duplicate-index +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgather2_duplicate_index_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c8_i32 = arith.constant 8 : i32 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %prefix_mask, %next_remaining = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %suffix_mask = pto.pnot %prefix_mask, %full_mask : !pto.mask, !pto.mask -> !pto.mask + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgather2 %ub_in, %offsets, %suffix_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/launch.cpp new file mode 100644 index 000000000..1a2d0359e --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2-duplicate-index +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_duplicate_index_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVgather2_duplicate_index_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vgather2_duplicate_index_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/main.cpp new file mode 100644 index 000000000..df5907af3 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2-duplicate-index/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2-duplicate-index +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_duplicate_index_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_duplicate_index_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2/compare.py new file mode 100755 index 000000000..41f5c3f65 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2 +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2/golden.py new file mode 100755 index 000000000..714c54c63 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2 +# family: gather-scatter +# target_ops: pto.vgather2 +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 17) + 3) % (ROWS * COLS) + gathered = flat[offsets].reshape(ROWS, COLS) + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto new file mode 100644 index 000000000..969b7a666 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/kernel.pto @@ -0,0 +1,75 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2 +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgather2_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgather2 %ub_in, %offsets, %mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2/launch.cpp new file mode 100644 index 000000000..e99c6741a --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2 +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVgather2_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vgather2_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2/main.cpp new file mode 100644 index 000000000..e2a9b4804 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2 +// family: gather-scatter +// target_ops: pto.vgather2 +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/compare.py new file mode 100755 index 000000000..83f25f4ab --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc-sparse-mask +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/golden.py new file mode 100755 index 000000000..e0cea7841 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc-sparse-mask +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 17) + 3) % (ROWS * COLS) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + active = offsets < 64 + gathered[active] = flat[offsets[active]] + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2_bc sparse-mask validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto new file mode 100644 index 000000000..b37d1de0e --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/kernel.pto @@ -0,0 +1,58 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc-sparse-mask +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgather2_bc_sparse_mask_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c64_i32 = arith.constant 64 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %gather_mask = pto.vcmps %offsets, %c64_i32, %full_mask, "lt" : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + %out = pto.vgather2_bc %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/launch.cpp new file mode 100644 index 000000000..333288162 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc-sparse-mask +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_bc_sparse_mask_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVgather2_bc_sparse_mask_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vgather2_bc_sparse_mask_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/main.cpp new file mode 100644 index 000000000..66ab70307 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc-sparse-mask/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc-sparse-mask +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_bc_sparse_mask_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_bc_sparse_mask_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/compare.py new file mode 100755 index 000000000..4ebeae5d2 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/golden.py new file mode 100755 index 000000000..da03fb5a7 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgather2_bc +# family: gather-scatter +# target_ops: pto.vgather2_bc +# scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 17) + 3) % (ROWS * COLS) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + active = offsets < 256 + gathered[active] = flat[offsets[active]] + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgather2_bc validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto new file mode 100644 index 000000000..826bba01e --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/kernel.pto @@ -0,0 +1,58 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgather2_bc_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i32 = arith.constant 256 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %gather_mask = pto.vcmps %offsets, %c256_i32, %full_mask, "lt" : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + %out = pto.vgather2_bc %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/launch.cpp new file mode 100644 index 000000000..2c60a591c --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgather2_bc_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVgather2_bc_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vgather2_bc_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/main.cpp new file mode 100644 index 000000000..73ba0e412 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgather2_bc/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgather2_bc +// family: gather-scatter +// target_ops: pto.vgather2_bc +// scenarios: core-f32, full-mask, non-contiguous, masked-gather, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgather2_bc_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgather2_bc_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/compare.py new file mode 100755 index 000000000..6d777d58f --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb-block-boundary +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/golden.py new file mode 100755 index 000000000..2bfb0c0d3 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb-block-boundary +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +BLOCK_FLOATS = 8 +BLOCKS_PER_ITER = 8 +ITER_ELEMS = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + blocks = flat.reshape(-1, BLOCK_FLOATS) + offsets = np.zeros((ROWS * COLS,), dtype=np.int32) + gathered = np.zeros((ROWS * COLS,), dtype=np.float32) + boundary_patterns = np.array([0, 1, 15, 16, 31, 32, 63, 127], dtype=np.int32) + + for chunk in range((ROWS * COLS) // ITER_ELEMS): + block_ids = (boundary_patterns + chunk * 3) % blocks.shape[0] + offsets[chunk * ITER_ELEMS:chunk * ITER_ELEMS + BLOCKS_PER_ITER] = block_ids * 32 + gathered[chunk * ITER_ELEMS:(chunk + 1) * ITER_ELEMS] = blocks[block_ids].reshape(-1) + + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgatherb block-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto new file mode 100644 index 000000000..e44003903 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/kernel.pto @@ -0,0 +1,58 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb-block-boundary +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgatherb_block_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %gather_mask, %_tail = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgatherb %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/launch.cpp new file mode 100644 index 000000000..fb6f40c39 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb-block-boundary +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgatherb_block_boundary_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVgatherb_block_boundary_kernel_2d(float *v1, int *v2, float *v3, + void *stream) { + vgatherb_block_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/main.cpp new file mode 100644 index 000000000..77c2cb46f --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb-block-boundary/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb-block-boundary +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgatherb_block_boundary_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgatherb_block_boundary_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/compare.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb/compare.py new file mode 100755 index 000000000..e9d439e1a --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/golden.py b/test/vpto/cases/micro-op/gather-scatter/vgatherb/golden.py new file mode 100755 index 000000000..e102cecfe --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vgatherb +# family: gather-scatter +# target_ops: pto.vgatherb +# scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +BLOCK_FLOATS = 8 +BLOCKS_PER_ITER = 8 +ITER_ELEMS = 64 +SEED = 19 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + blocks = flat.reshape(-1, BLOCK_FLOATS) + offsets = np.zeros((ROWS * COLS,), dtype=np.int32) + gathered = np.full((ROWS * COLS,), OUT_SENTINEL, dtype=np.float32) + + for chunk in range((ROWS * COLS) // ITER_ELEMS): + block_ids = ((np.arange(BLOCKS_PER_ITER, dtype=np.int32) + chunk * 11) * 7 + 3) % blocks.shape[0] + offsets[chunk * ITER_ELEMS:chunk * ITER_ELEMS + BLOCKS_PER_ITER] = block_ids * 32 + gathered[chunk * ITER_ELEMS:(chunk + 1) * ITER_ELEMS] = blocks[block_ids].reshape(-1) + + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + gathered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vgatherb validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto new file mode 100644 index 000000000..edfcfad54 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/kernel.pto @@ -0,0 +1,58 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgatherb_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c8_i32 = arith.constant 8 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %full_mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %gather_mask, %_tail = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %out = pto.vgatherb %ub_in, %offsets, %gather_mask : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb/launch.cpp new file mode 100644 index 000000000..589f236be --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vgatherb_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVgatherb_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vgatherb_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vgatherb/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vgatherb/main.cpp new file mode 100644 index 000000000..e16952c96 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vgatherb/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vgatherb +// family: gather-scatter +// target_ops: pto.vgatherb +// scenarios: core-f32, full-mask, block-gather, aligned-base, load-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVgatherb_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVgatherb_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/compare.py b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/compare.py new file mode 100755 index 000000000..016bfa5b7 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter-out-of-order-index +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py new file mode 100755 index 000000000..99761f514 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter-out-of-order-index +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 43) + 11) % (ROWS * COLS) + scattered = np.zeros((ROWS * COLS,), dtype=np.float32) + for base in range(0, ROWS * COLS, 64): + lanes = np.arange(base + 8, base + 64, dtype=np.int32) + scattered[offsets[lanes]] = flat[lanes] + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + scattered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vscatter validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto new file mode 100644 index 000000000..d38f20f7f --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/kernel.pto @@ -0,0 +1,62 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter-out-of-order-index +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vscatter_out_of_order_index_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c8_i32 = arith.constant 8 : i32 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg2, %ub_out, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %prefix_mask, %next_remaining = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %full_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %suffix_mask = pto.pnot %prefix_mask, %full_mask : !pto.mask, !pto.mask -> !pto.mask + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + pto.vscatter %vec, %ub_out, %offsets, %suffix_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, !pto.mask + scf.yield %remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/launch.cpp new file mode 100644 index 000000000..87b02ee5d --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/launch.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter-out-of-order-index +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vscatter_out_of_order_index_kernel_2d( + __gm__ float *v1, __gm__ int *v2, __gm__ float *v3); + +void LaunchVscatter_out_of_order_index_kernel_2d(float *v1, int *v2, + float *v3, void *stream) { + vscatter_out_of_order_index_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ int *)v2, (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/main.cpp new file mode 100644 index 000000000..f762aa293 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter-out-of-order-index/main.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter-out-of-order-index +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVscatter_out_of_order_index_kernel_2d(float *v1, int *v2, float *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVscatter_out_of_order_index_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/compare.py b/test/vpto/cases/micro-op/gather-scatter/vscatter/compare.py new file mode 100755 index 000000000..ada19a30e --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v3.bin", "v3.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/golden.py b/test/vpto/cases/micro-op/gather-scatter/vscatter/golden.py new file mode 100755 index 000000000..252356095 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/gather-scatter/vscatter +# family: gather-scatter +# target_ops: pto.vscatter +# scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=(ROWS * COLS,)).astype(np.float32) + offsets = ((np.arange(ROWS * COLS, dtype=np.int32) * 29) + 7) % (ROWS * COLS) + scattered = np.zeros((ROWS * COLS,), dtype=np.float32) + scattered[offsets] = flat + v1 = flat.reshape(ROWS, COLS) + v2 = offsets.reshape(ROWS, COLS) + v3 = np.zeros((ROWS, COLS), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + v3.reshape(-1).tofile(output_dir / "v3.bin") + scattered.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vscatter validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto new file mode 100644 index 000000000..7849e4824 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/kernel.pto @@ -0,0 +1,78 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vscatter_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_offsets = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_offsets, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg2, %ub_out, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %offsets = pto.vlds %ub_offsets[%offset] : !pto.ptr -> !pto.vreg<64xi32> + pto.vscatter %vec, %ub_out, %offsets, %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/launch.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter/launch.cpp new file mode 100644 index 000000000..79296467b --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/launch.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vscatter_kernel_2d(__gm__ float *v1, + __gm__ int *v2, + __gm__ float *v3); + +void LaunchVscatter_kernel_2d(float *v1, int *v2, float *v3, void *stream) { + vscatter_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ int *)v2, + (__gm__ float *)v3); +} diff --git a/test/vpto/cases/micro-op/gather-scatter/vscatter/main.cpp b/test/vpto/cases/micro-op/gather-scatter/vscatter/main.cpp new file mode 100644 index 000000000..613ee0282 --- /dev/null +++ b/test/vpto/cases/micro-op/gather-scatter/vscatter/main.cpp @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/gather-scatter/vscatter +// family: gather-scatter +// target_ops: pto.vscatter +// scenarios: core-f32, full-mask, non-contiguous, explicit-index-pattern, scatter-store, store-effect-validation, no-alias +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVscatter_kernel_2d(float *v1, int *v2, float *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int); + size_t elemCount_v3 = 1024; + size_t fileSize_v3 = elemCount_v3 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + int *v2Host = nullptr; + int *v2Device = nullptr; + float *v3Host = nullptr; + float *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVscatter_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pand/compare.py new file mode 100755 index 000000000..546f4445e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pand +# family: materialization-predicate +# target_ops: pto.pand +# scenarios: predicate-transform, logical-and +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pand/golden.py new file mode 100755 index 000000000..6d0506864 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pand +# family: materialization-predicate +# target_ops: pto.pand +# scenarios: predicate-transform, logical-and +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +PREFIX_BITS = 13 +SUFFIX_BITS = 7 +PREDICATE_BITS = 256 +NIBBLE_COUNT = PREDICATE_BITS // 2 + + +def pack_nibbles(nibbles: np.ndarray) -> np.ndarray: + words = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + for idx, nibble in enumerate(nibbles): + words[idx // 8] |= np.uint32(int(nibble) & 0xF) << np.uint32((idx % 8) * 4) + return words + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + lhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + rhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + lhs[:PREFIX_BITS] = 1 + rhs[:SUFFIX_BITS] = 1 + golden = pack_nibbles(lhs & rhs) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto new file mode 100644 index 000000000..2b65c8fdd --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pand +// family: materialization-predicate +// target_ops: pto.pand +// scenarios: predicate-transform +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pand_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c7 = arith.constant 7 : i32 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs, %lhs_next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %rhs, %rhs_next = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %out = pto.pand %lhs, %rhs, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pand/launch.cpp new file mode 100644 index 000000000..bf665aec7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pand +// family: materialization-predicate +// target_ops: pto.pand +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pand_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPand(uint32_t *v1, void *stream) { + pand_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pand/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pand/main.cpp new file mode 100644 index 000000000..751eed2d9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pand/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pand +// family: materialization-predicate +// target_ops: pto.pand +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPand(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPand(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/compare.py new file mode 100755 index 000000000..25e69c8a2 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/golden.py new file mode 100755 index 000000000..814eb34b5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([85, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 85, 0, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto new file mode 100644 index 000000000..16cc2a84d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pdintlv_b16_nontrivial_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b16 "PAT_M4" : !pto.mask + %low, %high = pto.pdintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/launch.cpp new file mode 100644 index 000000000..182f92536 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b16_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB16Nontrivial(uint32_t *v1, void *stream) { + pdintlv_b16_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/main.cpp new file mode 100644 index 000000000..02d8a3875 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB16Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB16Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/compare.py new file mode 100755 index 000000000..2a01bf650 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16 +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/golden.py new file mode 100755 index 000000000..e9227db10 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b16 +# family: materialization-predicate +# target_ops: pto.pdintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0, 1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto new file mode 100644 index 000000000..af38280cf --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16 +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pdintlv_b16_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b16 "PAT_ALLF" : !pto.mask + %low, %high = pto.pdintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/launch.cpp new file mode 100644 index 000000000..519e90a51 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16 +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b16_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB16(uint32_t *v1, void *stream) { + pdintlv_b16_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/main.cpp new file mode 100644 index 000000000..e2491af41 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b16/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b16 +// family: materialization-predicate +// target_ops: pto.pdintlv_b16 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB16(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB16(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/compare.py new file mode 100755 index 000000000..13e93d501 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/golden.py new file mode 100755 index 000000000..013f7751e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([4369, 0, 0, 0, 16843009, 16843009, 16843009, 16843009, 4369, 0, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto new file mode 100644 index 000000000..815b3b373 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pdintlv_b32_nontrivial_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b32 "PAT_M4" : !pto.mask + %low, %high = pto.pdintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/launch.cpp new file mode 100644 index 000000000..9a6cd6a5e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b32_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB32Nontrivial(uint32_t *v1, void *stream) { + pdintlv_b32_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/main.cpp new file mode 100644 index 000000000..97d56c906 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB32Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB32Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/compare.py new file mode 100755 index 000000000..fab797df6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32 +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/golden.py new file mode 100755 index 000000000..cd1487eef --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b32 +# family: materialization-predicate +# target_ops: pto.pdintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 286331153, 286331153, 286331153, 0, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto new file mode 100644 index 000000000..eb6ee504e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32 +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pdintlv_b32_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b32 "PAT_ALLF" : !pto.mask + %low, %high = pto.pdintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/launch.cpp new file mode 100644 index 000000000..316cfc086 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32 +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b32_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB32(uint32_t *v1, void *stream) { + pdintlv_b32_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/main.cpp new file mode 100644 index 000000000..7af8a309d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b32/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b32 +// family: materialization-predicate +// target_ops: pto.pdintlv_b32 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB32(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB32(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/compare.py new file mode 100755 index 000000000..12db124bd --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/golden.py new file mode 100755 index 000000000..1c58d3b87 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([15, 0, 0, 0, 1431655765, 1431655765, 1431655765, 1431655765, 15, 0, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto new file mode 100644 index 000000000..a65d6bec5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pdintlv_b8_nontrivial_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b8 "PAT_M4" : !pto.mask + %low, %high = pto.pdintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/launch.cpp new file mode 100644 index 000000000..e6e2949f5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b8_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB8Nontrivial(uint32_t *v1, void *stream) { + pdintlv_b8_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/main.cpp new file mode 100644 index 000000000..71e67e085 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB8Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB8Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/compare.py new file mode 100755 index 000000000..305f17e97 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8 +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/golden.py new file mode 100755 index 000000000..e0ed75f95 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pdintlv_b8 +# family: materialization-predicate +# target_ops: pto.pdintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([4294967295, 4294967295, 4294967295, 4294967295, 0, 0, 0, 0, 4294967295, 4294967295, 4294967295, 4294967295, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto new file mode 100644 index 000000000..29170023f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8 +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pdintlv_b8_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b8 "PAT_ALLF" : !pto.mask + %low, %high = pto.pdintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/launch.cpp new file mode 100644 index 000000000..a8a45dada --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8 +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pdintlv_b8_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPdintlvB8(uint32_t *v1, void *stream) { + pdintlv_b8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/main.cpp new file mode 100644 index 000000000..d0edfba37 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pdintlv_b8/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pdintlv_b8 +// family: materialization-predicate +// target_ops: pto.pdintlv_b8 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPdintlvB8(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPdintlvB8(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/compare.py new file mode 100755 index 000000000..8700b0d56 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask, boundary +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/golden.py new file mode 100755 index 000000000..67f3b65df --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask, boundary +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=1, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=1, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=1, bit_stride=4, store_bytes=32) + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto new file mode 100644 index 000000000..f906d01ce --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/kernel.pto @@ -0,0 +1,39 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pge_tail_mask_boundary_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pge_b8 "PAT_VL1" : !pto.mask + %m1 = pto.pge_b16 "PAT_VL1" : !pto.mask + %m2 = pto.pge_b32 "PAT_VL1" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/launch.cpp new file mode 100644 index 000000000..6a52d74b3 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask, boundary +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pge_tail_mask_boundary_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPgeTailMaskBoundary(uint32_t *v1, void *stream) { + pge_tail_mask_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/main.cpp new file mode 100644 index 000000000..c0ca90a8c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask-boundary/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask, boundary +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPgeTailMaskBoundary(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPgeTailMaskBoundary(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/compare.py new file mode 100755 index 000000000..2e598e7ae --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/golden.py new file mode 100755 index 000000000..823fc8889 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pge-tail-mask +# family: materialization-predicate +# target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +# scenarios: tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=8, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=8, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=8, bit_stride=4, store_bytes=32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op pge-tail-mask validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto new file mode 100644 index 000000000..4aa80ce57 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/kernel.pto @@ -0,0 +1,39 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pge_tail_mask_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pge_b8 "PAT_VL8" : !pto.mask + %m1 = pto.pge_b16 "PAT_VL8" : !pto.mask + %m2 = pto.pge_b32 "PAT_VL8" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/launch.cpp new file mode 100644 index 000000000..c38434d88 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pge_tail_mask_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPgeTailMask(uint32_t *v1, void *stream) { + pge_tail_mask_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/main.cpp new file mode 100644 index 000000000..dea4fa6c5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pge-tail-mask/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pge-tail-mask +// family: materialization-predicate +// target_ops: pto.pge_b16, pto.pge_b32, pto.pge_b8 +// scenarios: tail-mask +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPgeTailMask(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPgeTailMask(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/compare.py new file mode 100755 index 000000000..7704bbfb7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/golden.py new file mode 100755 index 000000000..f1729a845 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286593301, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148, 262148], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto new file mode 100644 index 000000000..f72ab5de7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pintlv_b16_nontrivial_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b16 "PAT_M4" : !pto.mask + %low, %high = pto.pintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/launch.cpp new file mode 100644 index 000000000..57939cac6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b16_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB16Nontrivial(uint32_t *v1, void *stream) { + pintlv_b16_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/main.cpp new file mode 100644 index 000000000..aca9caf7a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB16Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB16Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/compare.py new file mode 100755 index 000000000..9c1deb9c8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16 +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/golden.py new file mode 100755 index 000000000..a52cba6a2 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b16 +# family: materialization-predicate +# target_ops: pto.pintlv_b16 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto new file mode 100644 index 000000000..1bbce0777 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16 +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pintlv_b16_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b16 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b16 "PAT_ALLF" : !pto.mask + %low, %high = pto.pintlv_b16 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/launch.cpp new file mode 100644 index 000000000..262d87427 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16 +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b16_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB16(uint32_t *v1, void *stream) { + pintlv_b16_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/main.cpp new file mode 100644 index 000000000..29156f86a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b16/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b16 +// family: materialization-predicate +// target_ops: pto.pintlv_b16 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB16(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB16(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/compare.py new file mode 100755 index 000000000..daad8dae1 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/golden.py new file mode 100755 index 000000000..c28bc3f71 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([16843025, 16843025, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto new file mode 100644 index 000000000..8540c0783 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pintlv_b32_nontrivial_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b32 "PAT_M4" : !pto.mask + %low, %high = pto.pintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/launch.cpp new file mode 100644 index 000000000..06dcd4072 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b32_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB32Nontrivial(uint32_t *v1, void *stream) { + pintlv_b32_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/main.cpp new file mode 100644 index 000000000..befee95c9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB32Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB32Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/compare.py new file mode 100755 index 000000000..b3050eb69 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32 +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/golden.py new file mode 100755 index 000000000..67cb39fc8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b32 +# family: materialization-predicate +# target_ops: pto.pintlv_b32 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009, 16843009], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto new file mode 100644 index 000000000..d2c3d75ee --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32 +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pintlv_b32_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b32 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b32 "PAT_ALLF" : !pto.mask + %low, %high = pto.pintlv_b32 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/launch.cpp new file mode 100644 index 000000000..bb990592f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32 +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b32_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB32(uint32_t *v1, void *stream) { + pintlv_b32_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/main.cpp new file mode 100644 index 000000000..d0ef0696d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b32/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b32 +// family: materialization-predicate +// target_ops: pto.pintlv_b32 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB32(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB32(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/compare.py new file mode 100755 index 000000000..16b0c224d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/golden.py new file mode 100755 index 000000000..de8ae6216 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8-nontrivial +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([33707863, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto new file mode 100644 index 000000000..1786577ad --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pintlv_b8_nontrivial_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_VL8" : !pto.mask + %rhs = pto.pset_b8 "PAT_M4" : !pto.mask + %low, %high = pto.pintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/launch.cpp new file mode 100644 index 000000000..c466ef9c8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b8_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB8Nontrivial(uint32_t *v1, void *stream) { + pintlv_b8_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/main.cpp new file mode 100644 index 000000000..d27a5f08c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8-nontrivial +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB8Nontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB8Nontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/compare.py new file mode 100755 index 000000000..d6cae1168 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8 +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/golden.py new file mode 100755 index 000000000..bae1a196c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pintlv_b8 +# family: materialization-predicate +# target_ops: pto.pintlv_b8 +# scenarios: predicate-transform, lane-order +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765, 1431655765], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto new file mode 100644 index 000000000..863cffee4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/kernel.pto @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8 +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pintlv_b8_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %lhs = pto.pset_b8 "PAT_ALL" : !pto.mask + %rhs = pto.pset_b8 "PAT_ALLF" : !pto.mask + %low, %high = pto.pintlv_b8 %lhs, %rhs : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + pto.psts %low, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %high, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/launch.cpp new file mode 100644 index 000000000..d0299e575 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8 +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pintlv_b8_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPintlvB8(uint32_t *v1, void *stream) { + pintlv_b8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/main.cpp new file mode 100644 index 000000000..b4b856773 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pintlv_b8/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pintlv_b8 +// family: materialization-predicate +// target_ops: pto.pintlv_b8 +// scenarios: predicate-transform, lane-order +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPintlvB8(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPintlvB8(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/compare.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/compare.py new file mode 100755 index 000000000..8eac93173 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out, boundary +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/golden.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/golden.py new file mode 100755 index 000000000..c812554d8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask-boundary +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out, boundary +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=1, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=1, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=1, bit_stride=4, store_bytes=32) + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto new file mode 100644 index 000000000..6aff2409b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/kernel.pto @@ -0,0 +1,40 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @plt_tail_mask_boundary_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + %c1_i32 = arith.constant 1 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0, %s0 = pto.plt_b8 %c1_i32 : i32 -> !pto.mask, i32 + %m1, %s1 = pto.plt_b16 %c1_i32 : i32 -> !pto.mask, i32 + %m2, %s2 = pto.plt_b32 %c1_i32 : i32 -> !pto.mask, i32 + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/launch.cpp new file mode 100644 index 000000000..fad09577d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out, boundary +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void plt_tail_mask_boundary_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPltTailMaskBoundary(uint32_t *v1, void *stream) { + plt_tail_mask_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/main.cpp new file mode 100644 index 000000000..0dfe9d502 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask-boundary/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask-boundary +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out, boundary +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPltTailMaskBoundary(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPltTailMaskBoundary(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/compare.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/compare.py new file mode 100755 index 000000000..b2466b6ab --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 32 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print(f"[ERROR] Unexpected word count: golden={golden.size} out={output.size}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed predicate words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/golden.py b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/golden.py new file mode 100755 index 000000000..1ef2a1a79 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/plt-tail-mask +# family: materialization-predicate +# target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +# scenarios: tail-mask, scalar-carry-out +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 + + +def _pack_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + out[bit_index // 8] |= np.uint8(1 << (bit_index % 8)) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + golden[0:32] = _pack_prefix(active_lanes=13, bit_stride=1, store_bytes=32) + golden[32:64] = _pack_prefix(active_lanes=7, bit_stride=2, store_bytes=32) + golden[64:96] = _pack_prefix(active_lanes=3, bit_stride=4, store_bytes=32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op plt-tail-mask validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto new file mode 100644 index 000000000..869d06b17 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/kernel.pto @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @plt_tail_mask_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + %c13 = arith.constant 13 : i32 + %c7 = arith.constant 7 : i32 + %c3 = arith.constant 3 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0, %s0 = pto.plt_b8 %c13 : i32 -> !pto.mask, i32 + %m1, %s1 = pto.plt_b16 %c7 : i32 -> !pto.mask, i32 + %m2, %s2 = pto.plt_b32 %c3 : i32 -> !pto.mask, i32 + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/launch.cpp new file mode 100644 index 000000000..1c9b21d24 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void plt_tail_mask_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPltTailMask(uint32_t *v1, void *stream) { + plt_tail_mask_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/main.cpp new file mode 100644 index 000000000..7fdb5fe93 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/plt-tail-mask/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/plt-tail-mask +// family: materialization-predicate +// target_ops: pto.plt_b16, pto.plt_b32, pto.plt_b8 +// scenarios: tail-mask, scalar-carry-out +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPltTailMask(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPltTailMask(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pnot/compare.py new file mode 100755 index 000000000..a75c98c84 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pnot +# family: materialization-predicate +# target_ops: pto.pnot +# scenarios: predicate-transform, logical-not +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pnot/golden.py new file mode 100755 index 000000000..cafe4d7d0 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pnot +# family: materialization-predicate +# target_ops: pto.pnot +# scenarios: predicate-transform, logical-not +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([0, 286261248, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto new file mode 100644 index 000000000..02e6b942e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pnot +// family: materialization-predicate +// target_ops: pto.pnot +// scenarios: predicate-transform +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pnot_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %half, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %out = pto.pnot %half, %all : !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pnot/launch.cpp new file mode 100644 index 000000000..50cf29220 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pnot +// family: materialization-predicate +// target_ops: pto.pnot +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pnot_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPnot(uint32_t *v1, void *stream) { + pnot_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pnot/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pnot/main.cpp new file mode 100644 index 000000000..64b153376 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pnot/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pnot +// family: materialization-predicate +// target_ops: pto.pnot +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPnot(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPnot(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/compare.py b/test/vpto/cases/micro-op/materialization-predicate/por/compare.py new file mode 100755 index 000000000..2d6c341a8 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/por +# family: materialization-predicate +# target_ops: pto.por +# scenarios: predicate-transform, logical-or +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/golden.py b/test/vpto/cases/micro-op/materialization-predicate/por/golden.py new file mode 100755 index 000000000..c9c5dfe1e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/por +# family: materialization-predicate +# target_ops: pto.por +# scenarios: predicate-transform, logical-or +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +PREFIX_BITS = 13 +SUFFIX_BITS = 7 +PREDICATE_BITS = 256 +NIBBLE_COUNT = PREDICATE_BITS // 2 + + +def pack_nibbles(nibbles: np.ndarray) -> np.ndarray: + words = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + for idx, nibble in enumerate(nibbles): + words[idx // 8] |= np.uint32(int(nibble) & 0xF) << np.uint32((idx % 8) * 4) + return words + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + lhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + rhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + lhs[:PREFIX_BITS] = 1 + rhs[:SUFFIX_BITS] = 1 + golden = pack_nibbles(lhs | rhs) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto new file mode 100644 index 000000000..3eea2cb87 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/por +// family: materialization-predicate +// target_ops: pto.por +// scenarios: predicate-transform +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @por_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c7 = arith.constant 7 : i32 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs, %lhs_next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %rhs, %rhs_next = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %out = pto.por %lhs, %rhs, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/por/launch.cpp new file mode 100644 index 000000000..caa8684a4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/por +// family: materialization-predicate +// target_ops: pto.por +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void por_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPor(uint32_t *v1, void *stream) { + por_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/por/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/por/main.cpp new file mode 100644 index 000000000..527116eff --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/por/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/por +// family: materialization-predicate +// target_ops: pto.por +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPor(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPor(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/compare.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/compare.py new file mode 100755 index 000000000..2585ff1e4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack-nontrivial +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/golden.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/golden.py new file mode 100755 index 000000000..7923992f9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack-nontrivial +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([16843009, 16843009, 16843009, 16843009, 0, 0, 0, 0, 65537, 65537, 65537, 65537, 65537, 65537, 65537, 65537], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto new file mode 100644 index 000000000..b387f0b78 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack-nontrivial +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip, nontrivial-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ppack_punpack_nontrivial_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src = pto.pset_b32 "PAT_M4" : !pto.mask + %packed = pto.ppack %src, "LOWER" : !pto.mask -> !pto.mask + %roundtrip = pto.punpack %packed, "LOWER" : !pto.mask -> !pto.mask + pto.psts %packed, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %roundtrip, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/launch.cpp new file mode 100644 index 000000000..aac69efd1 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack-nontrivial +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip, nontrivial-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void ppack_punpack_nontrivial_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPpackPunpackNontrivial(uint32_t *v1, void *stream) { + ppack_punpack_nontrivial_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/main.cpp new file mode 100644 index 000000000..30feafd01 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack-nontrivial/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack-nontrivial +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip, nontrivial-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPpackPunpackNontrivial(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPpackPunpackNontrivial(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/compare.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/compare.py new file mode 100755 index 000000000..08e54575d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/golden.py b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/golden.py new file mode 100755 index 000000000..05d8b350b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/ppack-punpack +# family: materialization-predicate +# target_ops: pto.ppack, pto.punpack +# scenarios: pack-unpack-roundtrip +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto new file mode 100644 index 000000000..3e701025c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ppack_punpack_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src = pto.pset_b32 "PAT_ALL" : !pto.mask + %packed = pto.ppack %src, "LOWER" : !pto.mask -> !pto.mask + %roundtrip = pto.punpack %packed, "LOWER" : !pto.mask -> !pto.mask + pto.psts %packed, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %roundtrip, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/launch.cpp new file mode 100644 index 000000000..2dc4b848d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void ppack_punpack_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPpackPunpack(uint32_t *v1, void *stream) { + ppack_punpack_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/main.cpp new file mode 100644 index 000000000..0ad47bb5d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/ppack-punpack/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/ppack-punpack +// family: materialization-predicate +// target_ops: pto.ppack, pto.punpack +// scenarios: pack-unpack-roundtrip +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPpackPunpack(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPpackPunpack(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/compare.py b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/compare.py new file mode 100755 index 000000000..d334ff512 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel-tail-predicate +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select, tail-mask +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 16 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/golden.py b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/golden.py new file mode 100755 index 000000000..0144a0f36 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel-tail-predicate +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select, tail-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 69905, 0, 0, 0, 0, 0, 0, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153, 286331153], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto new file mode 100644 index 000000000..4dab858c2 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/kernel.pto @@ -0,0 +1,41 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel-tail-predicate +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select, tail-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psel_tail_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %sel, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %out = pto.psel %src0, %sel, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + %out_next = pto.psel %sel, %src0, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out_next, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/launch.cpp new file mode 100644 index 000000000..e4c8692cf --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel-tail-predicate +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select, tail-mask +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psel_tail_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPselTailPredicate(uint32_t *v1, void *stream) { + psel_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/main.cpp new file mode 100644 index 000000000..5a5996be6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel-tail-predicate/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel-tail-predicate +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select, tail-mask +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPselTailPredicate(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPselTailPredicate(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/compare.py b/test/vpto/cases/micro-op/materialization-predicate/psel/compare.py new file mode 100755 index 000000000..fb258a13e --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/golden.py b/test/vpto/cases/micro-op/materialization-predicate/psel/golden.py new file mode 100755 index 000000000..101269c58 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/psel +# family: materialization-predicate +# target_ops: pto.psel +# scenarios: predicate-transform, predicate-select +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([286331153, 69905, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto new file mode 100644 index 000000000..ebe2e5782 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/kernel.pto @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psel_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %sel, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %out = pto.psel %src0, %sel, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel/launch.cpp new file mode 100644 index 000000000..34e1641cb --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psel_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPsel(uint32_t *v1, void *stream) { + psel_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/psel/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/psel/main.cpp new file mode 100644 index 000000000..bfb6d3558 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/psel/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/psel +// family: materialization-predicate +// target_ops: pto.psel +// scenarios: predicate-transform, predicate-select +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPsel(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPsel(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/compare.py new file mode 100755 index 000000000..2de1b2000 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern-fragment +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-vl, representative-logical-elements +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 24 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/golden.py new file mode 100755 index 000000000..fcf402e5a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern-fragment +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-vl, representative-logical-elements +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +GOLDEN_PREFIX_WORDS = np.array([1227133513, 2454267026, 613566756, 1227133513, 2454267026, 613566756, 1227133513, 2454267026, 1431655765, 1431655765, 1431655765, 1431655765, 0, 0, 0, 0, 286331153, 286331153, 0, 0, 0, 0, 0, 0], dtype=np.uint32) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + golden[: GOLDEN_PREFIX_WORDS.size] = GOLDEN_PREFIX_WORDS + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto new file mode 100644 index 000000000..e98a0b62c --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/kernel.pto @@ -0,0 +1,39 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern-fragment +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, fragment-pattern +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pset_pattern_fragment_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pset_b8 "PAT_M3" : !pto.mask + %m1 = pto.pset_b16 "PAT_H" : !pto.mask + %m2 = pto.pset_b32 "PAT_Q" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/launch.cpp new file mode 100644 index 000000000..c74390013 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern-fragment +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, fragment-pattern +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pset_pattern_fragment_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPsetPatternFragment(uint32_t *v1, void *stream) { + pset_pattern_fragment_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/main.cpp new file mode 100644 index 000000000..116e3e7ae --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern-fragment/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern-fragment +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, fragment-pattern +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPsetPatternFragment(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPsetPatternFragment(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/compare.py new file mode 100755 index 000000000..10290abc6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/compare.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-all, pat-vl +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 24 + + +def compare_packed_words(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/golden.py new file mode 100755 index 000000000..dcc083810 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pset-pattern +# family: materialization-predicate +# target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +# scenarios: pattern-mask, pat-all, pat-vl +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 24 + + +def _pack_pset_prefix(active_lanes: int, bit_stride: int, store_bytes: int) -> np.ndarray: + out = np.zeros((store_bytes,), dtype=np.uint8) + for lane in range(active_lanes): + bit_index = lane * bit_stride + byte_index = bit_index // 8 + bit_in_byte = bit_index % 8 + out[byte_index] |= np.uint8(1 << bit_in_byte) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + + out = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + out[0:32] = _pack_pset_prefix(active_lanes=256, bit_stride=1, store_bytes=32) + out[32:48] = _pack_pset_prefix(active_lanes=8, bit_stride=2, store_bytes=16) + out[64:80] = _pack_pset_prefix(active_lanes=16, bit_stride=4, store_bytes=16) + golden = out.view(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op pset-pattern validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto new file mode 100644 index 000000000..42c94f782 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/kernel.pto @@ -0,0 +1,39 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, pat-all, pat-vl +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pset_pattern_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c96_i64 = arith.constant 96 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %m0 = pto.pset_b8 "PAT_ALL" : !pto.mask + %m1 = pto.pset_b16 "PAT_VL8" : !pto.mask + %m2 = pto.pset_b32 "PAT_VL16" : !pto.mask + pto.psts %m0, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m1, %ub_out[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %m2, %ub_out[%c64], "NORM" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c96_i64 + nburst(%c1_i64, %c96_i64, %c96_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/launch.cpp new file mode 100644 index 000000000..01ec8b624 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/launch.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, pat-all, pat-vl +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pset_pattern_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPset_pattern_kernel_2d(uint32_t *v1, void *stream) { + pset_pattern_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/main.cpp new file mode 100644 index 000000000..15a7b4181 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pset-pattern/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pset-pattern +// family: materialization-predicate +// target_ops: pto.pset_b16, pto.pset_b32, pto.pset_b8 +// scenarios: pattern-mask, pat-all, pat-vl +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPset_pattern_kernel_2d(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 24; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPset_pattern_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/compare.py b/test/vpto/cases/micro-op/materialization-predicate/pxor/compare.py new file mode 100755 index 000000000..0652ae4a5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pxor +# family: materialization-predicate +# target_ops: pto.pxor +# scenarios: predicate-transform, logical-xor +# coding=utf-8 + +import os +import sys + +import numpy as np + + +EXPECTED_WORDS = 32 +PREFIX_WORDS = 8 + + +def compare_words(golden_path, output_path): + if not os.path.exists(output_path) or not os.path.exists(golden_path): + print(f"[ERROR] Missing file: golden={golden_path} out={output_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + print( + f"[ERROR] Unexpected word count: golden={golden.size} " + f"out={output.size} expected={EXPECTED_WORDS}" + ) + return False + golden = golden[:PREFIX_WORDS] + output = output[:PREFIX_WORDS] + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed predicate words): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_words("golden_v1.bin", "v1.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/golden.py b/test/vpto/cases/micro-op/materialization-predicate/pxor/golden.py new file mode 100755 index 000000000..16f212335 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/pxor +# family: materialization-predicate +# target_ops: pto.pxor +# scenarios: predicate-transform, logical-xor +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 32 +PREFIX_BITS = 13 +SUFFIX_BITS = 7 +PREDICATE_BITS = 256 +NIBBLE_COUNT = PREDICATE_BITS // 2 + + +def pack_nibbles(nibbles: np.ndarray) -> np.ndarray: + words = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + for idx, nibble in enumerate(nibbles): + words[idx // 8] |= np.uint32(int(nibble) & 0xF) << np.uint32((idx % 8) * 4) + return words + + +def generate(output_dir: Path, seed: int) -> None: + del seed + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + lhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + rhs = np.zeros((NIBBLE_COUNT,), dtype=np.uint8) + lhs[:PREFIX_BITS] = 1 + rhs[:SUFFIX_BITS] = 1 + golden = pack_nibbles(np.bitwise_xor(lhs, rhs)) + + output_dir.mkdir(parents=True, exist_ok=True) + output_init.tofile(output_dir / "v1.bin") + golden.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto new file mode 100644 index 000000000..e986c3e74 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pxor +// family: materialization-predicate +// target_ops: pto.pxor +// scenarios: predicate-transform +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pxor_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c7 = arith.constant 7 : i32 + %c13 = arith.constant 13 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs, %lhs_next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %rhs, %rhs_next = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %out = pto.pxor %lhs, %rhs, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + pto.psts %out, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/pxor/launch.cpp new file mode 100644 index 000000000..55f3770dc --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pxor +// family: materialization-predicate +// target_ops: pto.pxor +// scenarios: predicate-transform +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pxor_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPxor(uint32_t *v1, void *stream) { + pxor_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/pxor/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/pxor/main.cpp new file mode 100644 index 000000000..6bf82fb80 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/pxor/main.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/pxor +// family: materialization-predicate +// target_ops: pto.pxor +// scenarios: predicate-transform +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); const char *_recent = aclGetRecentErrMsg(); if (_recent != nullptr && _recent[0] != '\0') { std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); } rc = 1; goto cleanup; } } while (0) + +void LaunchPxor(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 32; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchPxor(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/compare.py new file mode 100644 index 000000000..8a9b82365 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/golden.py new file mode 100644 index 000000000..86a637182 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.float32(1.25) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto new file mode 100644 index 000000000..f8a8f9481 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/kernel.pto @@ -0,0 +1,49 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vbr_f32_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 1.250000e+00 : f32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vbr %cst : f32 -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/launch.cpp new file mode 100644 index 000000000..cf1d57866 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_f32_kernel_2d(__gm__ float *v1); + +void LaunchVbr_f32_kernel_2d(float *v1, void *stream) { + vbr_f32_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/main.cpp new file mode 100644 index 000000000..0fce80155 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-f32/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_f32_kernel_2d(float *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_f32_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/compare.py new file mode 100644 index 000000000..78ba4226d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.int32, 0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/golden.py new file mode 100644 index 000000000..68e367caf --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.int32(7) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.int32) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-i32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto new file mode 100644 index 000000000..08710a41f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/kernel.pto @@ -0,0 +1,49 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vbr_i32_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 7 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vbr %cst : i32 -> !pto.vreg<64xi32> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/launch.cpp new file mode 100644 index 000000000..8db4dd6a4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_i32_kernel_2d(__gm__ int *v1); + +void LaunchVbr_i32_kernel_2d(int *v1, void *stream) { + vbr_i32_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/main.cpp new file mode 100644 index 000000000..cae6ce5ec --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i32/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_i32_kernel_2d(int *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int); + int *v1Host = nullptr; + int *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_i32_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/compare.py new file mode 100644 index 000000000..ceb7196aa --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.int8, 0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/golden.py new file mode 100644 index 000000000..ea6c7ad81 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.int8(-7) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.int8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.int8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-i8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto new file mode 100644 index 000000000..a5be8b151 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/kernel.pto @@ -0,0 +1,49 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vbr_i8_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -7 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c256 { + %vec = pto.vbr %cst : i8 -> !pto.vreg<256xsi8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xsi8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/launch.cpp new file mode 100644 index 000000000..d3b16ce39 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_i8_kernel_2d(__gm__ int8_t *v1); + +void LaunchVbr_i8_kernel_2d(int8_t *v1, void *stream) { + vbr_i8_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/main.cpp new file mode 100644 index 000000000..e9b756b4d --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-i8/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_i8_kernel_2d(int8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_i8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/compare.py new file mode 100644 index 000000000..207bac38b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.uint8, 0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/golden.py new file mode 100644 index 000000000..41a4e75f9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.uint8(201) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.uint8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vbr-u8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto new file mode 100644 index 000000000..70b58f31f --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/kernel.pto @@ -0,0 +1,49 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vbr_u8_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -55 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c256 { + %vec = pto.vbr %cst : i8 -> !pto.vreg<256xui8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/launch.cpp new file mode 100644 index 000000000..20b8a37c7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vbr_u8_kernel_2d(__gm__ uint8_t *v1); + +void LaunchVbr_u8_kernel_2d(uint8_t *v1, void *stream) { + vbr_u8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/main.cpp new file mode 100644 index 000000000..53637c15b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vbr-u8/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVbr_u8_kernel_2d(uint8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVbr_u8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/compare.py new file mode 100755 index 000000000..f22dfc0f5 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/compare.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/vdup-lane +# family: materialization-predicate +# target_ops: pto.vdup +# scenarios: core-f32, vector-input, lowest-highest +# coding=utf-8 + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(diff)) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={float(diff[idx])} " + f"at idx={idx} (golden={golden[idx]}, out={output[idx]})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_low.bin", "out_low.bin", np.float32, 1e-4) and ok + ok = compare_bin("golden_high.bin", "out_high.bin", np.float32, 1e-4) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/golden.py new file mode 100755 index 000000000..712a2adbd --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/materialization-predicate/vdup-lane +# family: materialization-predicate +# target_ops: pto.vdup +# scenarios: core-f32, vector-input, lowest-highest +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.normal(loc=1.0, scale=3.0, size=(ROWS, COLS)).astype(np.float32) + + src_flat = src.reshape(-1) + low_flat = np.empty_like(src_flat) + high_flat = np.empty_like(src_flat) + block = 64 + for begin in range(0, src_flat.size, block): + chunk = src_flat[begin : begin + block] + low_flat[begin : begin + block] = chunk[0] + high_flat[begin : begin + block] = chunk[-1] + + low = low_flat.reshape(src.shape) + high = high_flat.reshape(src.shape) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "src.bin") + low.reshape(-1).tofile(output_dir / "golden_low.bin") + high.reshape(-1).tofile(output_dir / "golden_high.bin") + np.zeros_like(src.reshape(-1)).tofile(output_dir / "out_low.bin") + np.zeros_like(src.reshape(-1)).tofile(output_dir / "out_high.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO vector-input vdup validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto new file mode 100644 index 000000000..536a68700 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-lane +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f32, vector-input, lowest-highest +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdup_lane_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) attributes {pto.kernel} { + %c8192_i64 = arith.constant 8192 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_low = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_high = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %src = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %low = pto.vdup %src, %active {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %high = pto.vdup %src, %active {position = "HIGHEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %low, %ub_low[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %high, %ub_high[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_low, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_high, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/launch.cpp new file mode 100644 index 000000000..522a86731 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-lane +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f32, vector-input, lowest-highest +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_lane_kernel_2d(__gm__ float *src, + __gm__ float *outLow, + __gm__ float *outHigh); + +void LaunchVdup_lane_kernel_2d(float *src, float *outLow, float *outHigh, + void *stream) { + vdup_lane_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)src, + (__gm__ float *)outLow, + (__gm__ float *)outHigh); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/main.cpp new file mode 100644 index 000000000..685317502 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-lane/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-lane +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f32, vector-input, lowest-highest +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_lane_kernel_2d(float *src, float *outLow, float *outHigh, void *stream); + +int main() { + size_t elemCount = 1024; + size_t fileSize = elemCount * sizeof(float); + float *srcHost = nullptr; + float *outLowHost = nullptr; + float *outHighHost = nullptr; + float *srcDevice = nullptr; + float *outLowDevice = nullptr; + float *outHighDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&outLowHost), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&outHighHost), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outLowDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outHighDevice, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./src.bin", fileSize, srcHost, fileSize); + ACL_CHECK(aclrtMemcpy(srcDevice, fileSize, srcHost, fileSize, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemset(outLowDevice, fileSize, 0, fileSize)); + ACL_CHECK(aclrtMemset(outHighDevice, fileSize, 0, fileSize)); + + LaunchVdup_lane_kernel_2d(srcDevice, outLowDevice, outHighDevice, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outLowHost, fileSize, outLowDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHighHost, fileSize, outHighDevice, fileSize, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./out_low.bin", outLowHost, fileSize); + WriteFile("./out_high.bin", outHighHost, fileSize); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outLowDevice); + aclrtFree(outHighDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outLowHost); + aclrtFreeHost(outHighHost); + if (stream != nullptr) { + const aclError ret = aclrtDestroyStream(stream); + if (ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] aclrtDestroyStream(stream) failed: %d (%s:%d)\n", + (int)ret, __FILE__, __LINE__); + } + if (deviceSet) { + const aclError ret = aclrtResetDevice(deviceId); + if (ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] aclrtResetDevice(deviceId) failed: %d (%s:%d)\n", + (int)ret, __FILE__, __LINE__); + } + if (aclInited) { + const aclError ret = aclFinalize(); + if (ret != ACL_SUCCESS) + std::fprintf(stderr, "[ERROR] aclFinalize() failed: %d (%s:%d)\n", + (int)ret, __FILE__, __LINE__); + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/compare.py new file mode 100644 index 000000000..e423b7707 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/compare.py @@ -0,0 +1,60 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.float16, 0.001) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/golden.py new file mode 100644 index 000000000..3f3ad08ba --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.float16(1.25) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar-f16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto new file mode 100644 index 000000000..4ca7c6cd6 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-scalar-f16 +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-f16, scalar-broadcast +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdup_scalar_f16_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant 1.250000e+00 : f16 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vdup %cst, %active : f16, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/launch.cpp new file mode 100644 index 000000000..664c961dd --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_f16_kernel_2d(__gm__ half *v1); + +void LaunchVdup_scalar_f16_kernel_2d(aclFloat16 *v1, void *stream) { + vdup_scalar_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/main.cpp new file mode 100644 index 000000000..b8f469441 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-f16/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_f16_kernel_2d(aclFloat16 *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(aclFloat16); + aclFloat16 *v1Host = nullptr; + aclFloat16 *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_f16_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/compare.py new file mode 100644 index 000000000..bc7b42e1a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.int8) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/golden.py new file mode 100644 index 000000000..90e2a44bb --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.int8(-83) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.int8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.int8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar-i8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto new file mode 100644 index 000000000..f6cc925e4 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-scalar-i8 +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-i8, scalar-broadcast +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdup_scalar_i8_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -83 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vdup %cst, %active : i8, !pto.mask -> !pto.vreg<256xsi8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xsi8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/launch.cpp new file mode 100644 index 000000000..5054b1777 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_i8_kernel_2d(__gm__ int8_t *v1); + +void LaunchVdup_scalar_i8_kernel_2d(int8_t *v1, void *stream) { + vdup_scalar_i8_kernel_2d<<<1, nullptr, stream>>>((__gm__ int8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/main.cpp new file mode 100644 index 000000000..fde52caa7 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-i8/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_i8_kernel_2d(int8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int8_t); + int8_t *v1Host = nullptr; + int8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_i8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/compare.py new file mode 100644 index 000000000..54831ce84 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.uint8) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/golden.py new file mode 100644 index 000000000..20eed0a35 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.uint8(173) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.uint8) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar-u8 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto new file mode 100644 index 000000000..14e74bb52 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/kernel.pto @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/materialization-predicate/vdup-scalar-u8 +// family: materialization-predicate +// target_ops: pto.vdup +// scenarios: core-u8, scalar-broadcast-signless +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdup_scalar_u8_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -83 : i8 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b8 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vdup %cst, %active : i8, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/launch.cpp new file mode 100644 index 000000000..d25c635c9 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_u8_kernel_2d(__gm__ uint8_t *v1); + +void LaunchVdup_scalar_u8_kernel_2d(uint8_t *v1, void *stream) { + vdup_scalar_u8_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint8_t *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/main.cpp new file mode 100644 index 000000000..32dfff59a --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar-u8/main.cpp @@ -0,0 +1,101 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_u8_kernel_2d(uint8_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint8_t); + uint8_t *v1Host = nullptr; + uint8_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_u8_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/compare.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/compare.py new file mode 100644 index 000000000..8a9b82365 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v1.bin", "v1.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/golden.py b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/golden.py new file mode 100644 index 000000000..3153c6005 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +VALUE = np.float32(-2.5) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v1 = np.full((ROWS, COLS), VALUE, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + golden_v1.reshape(-1).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vdup-scalar validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto new file mode 100644 index 000000000..de419713b --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/kernel.pto @@ -0,0 +1,49 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vdup_scalar_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant -2.500000e+00 : f32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %active = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %vec = pto.vdup %cst, %active : f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %active : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/launch.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/launch.cpp new file mode 100644 index 000000000..02754e6d2 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/launch.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vdup_scalar_kernel_2d(__gm__ float *v1); + +void LaunchVdup_scalar_kernel_2d(float *v1, void *stream) { + vdup_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1); +} diff --git a/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/main.cpp b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/main.cpp new file mode 100644 index 000000000..6aff66657 --- /dev/null +++ b/test/vpto/cases/micro-op/materialization-predicate/vdup-scalar/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVdup_scalar_kernel_2d(float *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVdup_scalar_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/_predicate_load_store_case.py b/test/vpto/cases/micro-op/predicate-load-store/_predicate_load_store_case.py new file mode 100644 index 000000000..494247b52 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/_predicate_load_store_case.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +OUTPUT_BYTES = ROWS * COLS +PREDICATE_BITS = 256 +# For the current A5 predicate load/store surface used by these composition +# cases, the user-visible packed NORM footprint is 16 bytes. Bytes beyond that +# range are not part of the checked result footprint. +NORM_STORAGE_BYTES = 16 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((PREDICATE_BITS,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def pk_us_compose(bits: np.ndarray) -> np.ndarray: + packed = bits[::2] + return np.repeat(packed, 2).astype(np.uint8, copy=False) + + +def norm_ds_compose(bits: np.ndarray) -> np.ndarray: + source = np.concatenate( + [bits.astype(np.uint8, copy=False), np.zeros_like(bits, dtype=np.uint8)] + ) + return source[::2][:PREDICATE_BITS].astype(np.uint8, copy=False) + + +def norm_store_bytes(bits: np.ndarray) -> np.ndarray: + packed = np.packbits(bits.astype(np.uint8, copy=False), bitorder="little") + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:NORM_STORAGE_BYTES] = packed[:NORM_STORAGE_BYTES] + return out + + +def write_default_inputs(output_dir: Path) -> None: + np.zeros((ROWS * COLS,), dtype=np.float32).tofile(output_dir / "v1.bin") + np.zeros((ROWS * COLS,), dtype=np.float32).tofile(output_dir / "v2.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v3.bin") + + +def write_case(output_dir: Path, bits: np.ndarray) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + write_default_inputs(output_dir) + norm_store_bytes(bits).tofile(output_dir / "golden_v3.bin") + + +def compare_norm_store(golden_path: str, output_path: str) -> bool: + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.size < NORM_STORAGE_BYTES or output.size < NORM_STORAGE_BYTES: + return False + if not np.array_equal(golden[:NORM_STORAGE_BYTES], output[:NORM_STORAGE_BYTES]): + diff = np.nonzero(golden[:NORM_STORAGE_BYTES] != output[:NORM_STORAGE_BYTES])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (predicate load/store composition): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + return False + return True diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/compare.py new file mode 100644 index 000000000..1f197eec2 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pldi-norm +# family: predicate-load-store +# target_ops: pto.pldi +# scenarios: packed-load, immediate-offset, representative-logical-elements + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.size < 256 or output.size < 256: + print( + f"[ERROR] Packed buffer too small: golden={golden.size} out={output.size}" + ) + raise SystemExit(2) + if not np.array_equal(golden[:256], output[:256]): + diff = np.nonzero(golden[:256] != output[:256])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (pldi NORM -> vsel): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + raise SystemExit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/golden.py new file mode 100644 index 000000000..d4811cdb0 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/golden.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pldi-norm +# family: predicate-load-store +# target_ops: pto.pldi +# scenarios: packed-load, immediate-offset, representative-logical-elements + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +ACTIVE_BITS = 145 +OUTPUT_BYTES = 1024 +VECTOR_BYTES = 256 +PACKED_BYTES = 32 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((256,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def make_input_buffer(bits: np.ndarray) -> np.ndarray: + packed = np.packbits(bits.astype(np.uint8, copy=False), bitorder="little") + ones = np.ones((VECTOR_BYTES,), dtype=np.uint8) + zeros = np.zeros((VECTOR_BYTES,), dtype=np.uint8) + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:PACKED_BYTES] = packed[:PACKED_BYTES] + out[PACKED_BYTES : PACKED_BYTES + VECTOR_BYTES] = ones + out[PACKED_BYTES + VECTOR_BYTES : PACKED_BYTES + 2 * VECTOR_BYTES] = zeros + return out + + +def expected_selected_bytes(bits: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:VECTOR_BYTES] = bits.astype(np.uint8, copy=False) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + bits = prefix_bits(ACTIVE_BITS) + input_buffer = make_input_buffer(bits) + golden = expected_selected_bytes(bits) + + output_dir.mkdir(parents=True, exist_ok=True) + input_buffer.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate raw packed predicate input/golden for VPTO micro-op pldi-norm validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto new file mode 100644 index 000000000..9b16dd30f --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pldi-norm +// family: predicate-load-store +// target_ops: pto.pldi +// scenarios: packed-load, immediate-offset, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pldi_norm_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c32 = arith.constant 32 : index + %c288 = arith.constant 288 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c4_i64 = arith.constant 4 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %c256_i32 = arith.constant 256 : i32 + %c256_loop_i32 = arith.constant 256 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c256 step %c256 iter_args(%remaining = %c256_loop_i32) -> (i32) { + %loaded = pto.pldi %ub_in[%c0], "NORM" : !pto.ptr, index -> !pto.mask + %full_mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %ones_offset = arith.addi %offset, %c32 : index + %zeros_offset = arith.addi %offset, %c288 : index + %ones = pto.vlds %ub_in[%ones_offset] : !pto.ptr -> !pto.vreg<256xui8> + %zeros = pto.vlds %ub_in[%zeros_offset] : !pto.ptr -> !pto.vreg<256xui8> + %out = pto.vsel %ones, %zeros, %loaded : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/launch.cpp new file mode 100644 index 000000000..8044d8893 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pldi_norm_kernel_2d(__gm__ unsigned char *v1, + __gm__ unsigned char *v2); + +void LaunchPldi_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream) { + pldi_norm_kernel_2d<<<1, nullptr, stream>>>((__gm__ unsigned char *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/main.cpp new file mode 100644 index 000000000..a1a8204c2 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pldi-norm/main.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pldi-norm +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPldi_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(unsigned char); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + unsigned char *v1Host = nullptr; + unsigned char *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPldi_norm_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/compare.py b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/compare.py new file mode 100644 index 000000000..bd3820b2e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/plds-norm +# family: predicate-load-store +# target_ops: pto.plds +# scenarios: packed-load, dynamic-offset, representative-logical-elements + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.size < 256 or output.size < 256: + print( + f"[ERROR] Packed buffer too small: golden={golden.size} out={output.size}" + ) + raise SystemExit(2) + if not np.array_equal(golden[:256], output[:256]): + diff = np.nonzero(golden[:256] != output[:256])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (plds NORM -> vsel): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + raise SystemExit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/golden.py b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/golden.py new file mode 100644 index 000000000..e6cf2fb1b --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/golden.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/plds-norm +# family: predicate-load-store +# target_ops: pto.plds +# scenarios: packed-load, dynamic-offset, representative-logical-elements + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +ACTIVE_BITS = 145 +OUTPUT_BYTES = 1024 +VECTOR_BYTES = 256 +PACKED_BYTES = 32 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((256,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def make_input_buffer(bits: np.ndarray) -> np.ndarray: + packed = np.packbits(bits.astype(np.uint8, copy=False), bitorder="little") + ones = np.ones((VECTOR_BYTES,), dtype=np.uint8) + zeros = np.zeros((VECTOR_BYTES,), dtype=np.uint8) + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:PACKED_BYTES] = packed[:PACKED_BYTES] + out[PACKED_BYTES : PACKED_BYTES + VECTOR_BYTES] = ones + out[PACKED_BYTES + VECTOR_BYTES : PACKED_BYTES + 2 * VECTOR_BYTES] = zeros + return out + + +def expected_selected_bytes(bits: np.ndarray) -> np.ndarray: + out = np.zeros((OUTPUT_BYTES,), dtype=np.uint8) + out[:VECTOR_BYTES] = bits.astype(np.uint8, copy=False) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + bits = prefix_bits(ACTIVE_BITS) + input_buffer = make_input_buffer(bits) + golden = expected_selected_bytes(bits) + + output_dir.mkdir(parents=True, exist_ok=True) + input_buffer.tofile(output_dir / "v1.bin") + np.zeros((OUTPUT_BYTES,), dtype=np.uint8).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate raw packed predicate input/golden for VPTO micro-op plds-norm validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto new file mode 100644 index 000000000..d64654b22 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/kernel.pto @@ -0,0 +1,58 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/plds-norm +// family: predicate-load-store +// target_ops: pto.plds +// scenarios: packed-load, dynamic-offset, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @plds_norm_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c32 = arith.constant 32 : index + %c288 = arith.constant 288 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c256_i64 = arith.constant 256 : i64 + %c256_i32 = arith.constant 256 : i32 + %c256_loop_i32 = arith.constant 256 : i32 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_in = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c256 step %c256 iter_args(%remaining = %c256_loop_i32) -> (i32) { + %byte_offset = arith.addi %offset, %c0 : index + %loaded = pto.plds %ub_in[%byte_offset], "NORM" : !pto.ptr, index -> !pto.mask + %full_mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %ones_offset = arith.addi %offset, %c32 : index + %zeros_offset = arith.addi %offset, %c288 : index + %ones = pto.vlds %ub_in[%ones_offset] : !pto.ptr -> !pto.vreg<256xui8> + %zeros = pto.vlds %ub_in[%zeros_offset] : !pto.ptr -> !pto.vreg<256xui8> + %out = pto.vsel %ones, %zeros, %loaded : !pto.vreg<256xui8>, !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %full_mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c256_i64, %c256_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg1, %c256_i64 + nburst(%c4_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/launch.cpp new file mode 100644 index 000000000..9a1e9de07 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void plds_norm_kernel_2d(__gm__ unsigned char *v1, + __gm__ unsigned char *v2); + +void LaunchPlds_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream) { + plds_norm_kernel_2d<<<1, nullptr, stream>>>((__gm__ unsigned char *)v1, + (__gm__ unsigned char *)v2); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/plds-norm/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/main.cpp new file mode 100644 index 000000000..30f136c67 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/plds-norm/main.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/plds-norm +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPlds_norm_kernel_2d(unsigned char *v1, unsigned char *v2, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(unsigned char); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + unsigned char *v1Host = nullptr; + unsigned char *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPlds_norm_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/compare.py new file mode 100644 index 000000000..9a9e71168 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-norm-pldi-ds +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/golden.py new file mode 100644 index 000000000..fa4f55ccd --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-norm-pldi-ds +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import norm_ds_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 143 + + +def generate(output_dir: Path, seed: int, src_elem_bytes: int) -> None: + del seed + del src_elem_bytes + write_case(output_dir, norm_ds_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for psti-norm-pldi-ds." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument("--seed", type=int, default=SEED, help="Numpy random seed.") + parser.add_argument( + "--src-elem-bytes", + type=int, + default=4, + help="Unused compatibility option kept for the shared runner surface.", + ) + args = parser.parse_args() + generate(args.output_dir, args.seed, args.src_elem_bytes) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto new file mode 100644 index 000000000..e09d6999b --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-norm-pldi-ds +// family: predicate-load-store +// target_ops: pto.pldi, pto.psti +// scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psti_norm_pldi_ds_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c143 = arith.constant 143 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg1, %ub_mid, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c143) -> (i32) { + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psti %src, %ub_mid[%c32], "NORM" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.pldi %ub_mid[%c32], "DS" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/launch.cpp new file mode 100644 index 000000000..b1d57a8e8 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psti_norm_pldi_ds_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsti_norm_pldi_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psti_norm_pldi_ds_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/main.cpp new file mode 100644 index 000000000..db6a66d3f --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-norm-pldi-ds/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-norm-pldi-ds +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsti_norm_pldi_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsti_norm_pldi_ds_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/compare.py new file mode 100644 index 000000000..5adbcb96c --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk-pldi-us +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/golden.py new file mode 100644 index 000000000..eb6a105ed --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk-pldi-us +# family: predicate-load-store +# target_ops: pto.pldi, pto.psti +# scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import pk_us_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 145 + + +def generate(output_dir: Path, seed: int, src_elem_bytes: int) -> None: + del seed + del src_elem_bytes + write_case(output_dir, pk_us_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for psti-pk-pldi-us.") + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument("--seed", type=int, default=SEED, help="Numpy random seed.") + parser.add_argument( + "--src-elem-bytes", + type=int, + default=4, + help="Unused compatibility option kept for the shared runner surface.", + ) + args = parser.parse_args() + generate(args.output_dir, args.seed, args.src_elem_bytes) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto new file mode 100644 index 000000000..db2f693f0 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk-pldi-us +// family: predicate-load-store +// target_ops: pto.pldi, pto.psti +// scenarios: predicate-load-store-composition, immediate-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psti_pk_pldi_us_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c145 = arith.constant 145 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg1, %ub_mid, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c145) -> (i32) { + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psti %src, %ub_mid[%c8], "PK" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.pldi %ub_mid[%c8], "US" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/launch.cpp new file mode 100644 index 000000000..40a6df38d --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psti_pk_pldi_us_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsti_pk_pldi_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psti_pk_pldi_us_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/main.cpp new file mode 100644 index 000000000..0f2567d2c --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk-pldi-us/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk-pldi-us +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsti_pk_pldi_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsti_pk_pldi_us_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/compare.py new file mode 100644 index 000000000..e0ca16a26 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk +# family: predicate-load-store +# target_ops: pto.psti +# scenarios: packed-store, immediate-offset, representative-logical-elements + +import numpy as np + + +EXPECTED_WORDS = 8 +PK_STORAGE_BYTES = 16 + + +def main() -> None: + golden = np.fromfile("golden_v1.bin", dtype=np.uint8) + output = np.fromfile("v1.bin", dtype=np.uint8) + expected_bytes = EXPECTED_WORDS * 4 + if golden.size != expected_bytes or output.size != expected_bytes: + print( + f"[ERROR] Unexpected byte count: golden={golden.size} " + f"out={output.size} expected={expected_bytes}" + ) + raise SystemExit(2) + if not np.array_equal(golden[:PK_STORAGE_BYTES], output[:PK_STORAGE_BYTES]): + diff = np.nonzero(golden[:PK_STORAGE_BYTES] != output[:PK_STORAGE_BYTES])[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (psti PK raw packed store): idx={idx} " + f"golden={int(golden[idx])} out={int(output[idx])}" + ) + raise SystemExit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/golden.py new file mode 100644 index 000000000..42fd1b842 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psti-pk +# family: predicate-load-store +# target_ops: pto.psti +# scenarios: packed-store, immediate-offset, representative-logical-elements + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +OUTPUT_WORDS = 8 +ACTIVE_BITS = 145 +PK_STORAGE_BYTES = 16 + + +def prefix_bits(active_bits: int) -> np.ndarray: + bits = np.zeros((256,), dtype=np.uint8) + bits[:active_bits] = 1 + return bits + + +def generate(output_dir: Path, seed: int) -> None: + del seed + bits = prefix_bits(ACTIVE_BITS) + packed_pk = np.packbits(bits[::2], bitorder="little") + out = np.zeros((OUTPUT_WORDS * 4,), dtype=np.uint8) + out[:PK_STORAGE_BYTES] = packed_pk[:PK_STORAGE_BYTES] + + output_dir.mkdir(parents=True, exist_ok=True) + np.zeros((OUTPUT_WORDS,), dtype=np.uint32).tofile(output_dir / "v1.bin") + out.view(np.uint32).tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate packed predicate golden for VPTO micro-op psti-pk validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto new file mode 100644 index 000000000..3987d3e5e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/kernel.pto @@ -0,0 +1,34 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk +// family: predicate-load-store +// target_ops: pto.psti +// scenarios: packed-store, immediate-offset, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psti_pk_kernel_2d(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c145 = arith.constant 145 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + %gm_out = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %src, %next = pto.plt_b8 %c145 : i32 -> !pto.mask, i32 + pto.psti %src, %ub_out[%c0], "PK" : !pto.mask, !pto.ptr, index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/launch.cpp new file mode 100644 index 000000000..5be1e518d --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psti_pk_kernel_2d(__gm__ uint32_t *v1); + +void LaunchPsti_pk_kernel_2d(uint32_t *v1, void *stream) { + psti_pk_kernel_2d<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psti-pk/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/main.cpp new file mode 100644 index 000000000..8be1b45c4 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psti-pk/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psti-pk +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsti_pk_kernel_2d(uint32_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 8; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsti_pk_kernel_2d(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/compare.py new file mode 100644 index 000000000..39299f639 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-norm-plds-ds +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/golden.py new file mode 100644 index 000000000..19a16409e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-norm-plds-ds +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import norm_ds_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 175 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + write_case(output_dir, norm_ds_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for psts-norm-plds-ds.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto new file mode 100644 index 000000000..dac3c026f --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-norm-plds-ds +// family: predicate-load-store +// target_ops: pto.plds, pto.psts +// scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psts_norm_plds_ds_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c175 = arith.constant 175 : i32 + %c0_i32 = arith.constant 0 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg1, %ub_mid, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c175) -> (i32) { + %byte_offset = arith.addi %iv, %c32 : index + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %zero, %_unused = pto.plt_b8 %c0_i32 : i32 -> !pto.mask, i32 + pto.psts %src, %ub_mid[%byte_offset], "NORM" : !pto.mask, !pto.ptr, index + pto.psts %zero, %ub_mid[%c64], "NORM" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.plds %ub_mid[%byte_offset], "DS" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/launch.cpp new file mode 100644 index 000000000..7c9a920cc --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psts_norm_plds_ds_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsts_norm_plds_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psts_norm_plds_ds_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/main.cpp new file mode 100644 index 000000000..2f0be35a7 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-norm-plds-ds/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-norm-plds-ds +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsts_norm_plds_ds_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsts_norm_plds_ds_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/compare.py new file mode 100644 index 000000000..fe48e2bdb --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/golden.py new file mode 100644 index 000000000..f2ab5e6e3 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import pk_us_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 173 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + write_case(output_dir, pk_us_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for psts-pk-plds-us-prefix-boundary." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto new file mode 100644 index 000000000..c6f014273 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +// family: predicate-load-store +// target_ops: pto.plds, pto.psts +// scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psts_pk_plds_us_prefix_boundary_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c173 = arith.constant 173 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg1, %ub_mid, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c173) -> (i32) { + %byte_offset = arith.addi %iv, %c16 : index + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psts %src, %ub_mid[%byte_offset], "PK" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.plds %ub_mid[%byte_offset], "US" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/launch.cpp new file mode 100644 index 000000000..a2d8377c7 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/launch.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psts_pk_plds_us_prefix_boundary_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsts_pk_plds_us_prefix_boundary_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psts_pk_plds_us_prefix_boundary_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ unsigned char *)v2, (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/main.cpp new file mode 100644 index 000000000..d76ca6ac1 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us-prefix-boundary +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsts_pk_plds_us_prefix_boundary_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsts_pk_plds_us_prefix_boundary_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/compare.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/compare.py new file mode 100644 index 000000000..8a88450d6 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import os +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import compare_norm_store + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_norm_store("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/golden.py b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/golden.py new file mode 100644 index 000000000..cbbb855be --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/psts-pk-plds-us +# family: predicate-load-store +# target_ops: pto.plds, pto.psts +# scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements + +import argparse +from pathlib import Path +import sys + +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from _predicate_load_store_case import pk_us_compose, prefix_bits, write_case + + +SEED = 19 +ACTIVE_BITS = 171 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + write_case(output_dir, pk_us_compose(prefix_bits(ACTIVE_BITS))) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate inputs/golden for psts-pk-plds-us.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto new file mode 100644 index 000000000..6af9f336a --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/kernel.pto @@ -0,0 +1,52 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us +// family: predicate-load-store +// target_ops: pto.plds, pto.psts +// scenarios: predicate-load-store-composition, dynamic-offset, load-store-pair-preservation, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @psts_pk_plds_us_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c171 = arith.constant 171 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c10240_i64 = arith.constant 10240 : i64 + %false = arith.constant false + + %ub_mid = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c10240_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg1, %ub_mid, %c0_i64, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %iv = %c0 to %c1 step %c1 iter_args(%remaining = %c171) -> (i32) { + %byte_offset = arith.addi %iv, %c16 : index + %src, %next = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + pto.psts %src, %ub_mid[%byte_offset], "PK" : !pto.mask, !pto.ptr, index + pto.mem_bar "VST_VLD" + %loaded = pto.plds %ub_mid[%byte_offset], "US" : !pto.ptr, index -> !pto.mask + pto.psts %loaded, %ub_out[%c0], "NORM" : !pto.mask, !pto.ptr, index + scf.yield %next : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop1_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.set_loop2_stride_ubtoout %c32_i64, %c32_i64 : i64, i64 + pto.mte_ub_gm %ub_out, %arg2, %c32_i64 + nburst(%c32_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/launch.cpp new file mode 100644 index 000000000..acc8dfb7b --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void psts_pk_plds_us_kernel_2d(__gm__ float *v1, + __gm__ unsigned char *v2, + __gm__ unsigned char *v3); + +void LaunchPsts_pk_plds_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream) { + psts_pk_plds_us_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ unsigned char *)v2, + (__gm__ unsigned char *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/main.cpp new file mode 100644 index 000000000..1462814f4 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/psts-pk-plds-us/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/psts-pk-plds-us +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPsts_pk_plds_us_kernel_2d(float *v1, unsigned char *v2, + unsigned char *v3, void *stream); + +int main() { + size_t fileSize_v1 = 1024 * sizeof(float); + size_t fileSize_v2 = 1024 * sizeof(unsigned char); + size_t fileSize_v3 = 1024 * sizeof(unsigned char); + float *v1Host = nullptr; + float *v1Device = nullptr; + unsigned char *v2Host = nullptr; + unsigned char *v2Device = nullptr; + unsigned char *v3Host = nullptr; + unsigned char *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPsts_pk_plds_us_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/compare.py new file mode 100644 index 000000000..845f5233e --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-init-align-outside-loop +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 8 + + +def compare_packed_pred_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint32) + output = np.fromfile(output_path, dtype=np.uint32) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_pred_mask("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/golden.py new file mode 100644 index 000000000..bf0e5a2ab --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/golden.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-init-align-outside-loop +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +PACKED_BYTES_PER_STORE = 8 +OUTPUT_WORDS = 8 + + +def _pack_mask_b32(active_lanes: int) -> np.ndarray: + if active_lanes < 0 or active_lanes > 64: + raise ValueError(f"active_lanes must be in [0, 64], got {active_lanes}") + logical = np.zeros((64,), dtype=np.uint8) + logical[:active_lanes] = 1 + packed = np.packbits(logical, bitorder="little") + out = np.zeros((PACKED_BYTES_PER_STORE,), dtype=np.uint8) + out[: packed.size] = packed + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-1.0, 1.0, size=(ROWS, COLS)).astype(np.float32) + + first = _pack_mask_b32(13) + second = _pack_mask_b32(7) + packed = np.zeros((OUTPUT_WORDS * np.dtype(np.uint32).itemsize,), dtype=np.uint8) + packed[:PACKED_BYTES_PER_STORE] = first + packed[PACKED_BYTES_PER_STORE : 2 * PACKED_BYTES_PER_STORE] = second + packed[2 * PACKED_BYTES_PER_STORE : 3 * PACKED_BYTES_PER_STORE] = first + packed[3 * PACKED_BYTES_PER_STORE : 4 * PACKED_BYTES_PER_STORE] = second + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + output_init.tofile(output_dir / "v3.bin") + packed.view(np.uint32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op pstu-init-align-outside-loop validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto new file mode 100644 index 000000000..474a9011f --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/kernel.pto @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-init-align-outside-loop +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pstu_init_align_outside_loop_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c0_i32 = arith.constant 0 : i32 + %c13 = arith.constant 13 : i32 + %c7 = arith.constant 7 : i32 + + %ub_mask = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + %align_init = pto.init_align : !pto.align + %align_final, %base_final = scf.for %iter = %c0 to %c2 step %c1 + iter_args(%align_iter = %align_init, %base_iter = %ub_mask) + -> (!pto.align, !pto.ptr) { + %value, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %align_out, %base_out = pto.pstu %align_iter, %value, %base_iter : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + %value_tail, %next_tail = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %align_tail, %base_tail = pto.pstu %align_out, %value_tail, %base_out : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + scf.yield %align_tail, %base_tail : !pto.align, !pto.ptr + } + pto.vstas %align_final, %base_final, %c0_i32 : !pto.align, !pto.ptr, i32 + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_mask, %arg2, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/launch.cpp new file mode 100644 index 000000000..1d97093b5 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-init-align-outside-loop +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +pstu_init_align_outside_loop_kernel_2d(__gm__ float *v1, __gm__ float *v2, + __gm__ uint32_t *v3); + +void LaunchPstu_init_align_outside_loop_kernel_2d(float *v1, float *v2, + uint32_t *v3, void *stream) { + pstu_init_align_outside_loop_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ uint32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/main.cpp new file mode 100644 index 000000000..ad4157e2f --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-init-align-outside-loop/main.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-init-align-outside-loop +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPstu_init_align_outside_loop_kernel_2d(float *v1, float *v2, + uint32_t *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 8; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPstu_init_align_outside_loop_kernel_2d(v1Device, v2Device, v3Device, + stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/compare.py new file mode 100755 index 000000000..bf213031a --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-state-advance-boundary +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 16 + + +def compare_packed_pred_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint16) + output = np.fromfile(output_path, dtype=np.uint16) + if golden.size != EXPECTED_WORDS or output.size != EXPECTED_WORDS: + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_pred_mask("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/golden.py new file mode 100755 index 000000000..f9db13980 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/golden.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu-state-advance-boundary +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +PACKED_BYTES_PER_STORE = 16 +OUTPUT_WORDS = 16 + + +def _pack_mask_b16(active_lanes: int) -> np.ndarray: + if active_lanes < 0 or active_lanes > 128: + raise ValueError(f"active_lanes must be in [0, 128], got {active_lanes}") + logical = np.zeros((128,), dtype=np.uint8) + logical[:active_lanes] = 1 + packed = np.packbits(logical, bitorder="little") + out = np.zeros((PACKED_BYTES_PER_STORE,), dtype=np.uint8) + out[: packed.size] = packed + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-1.0, 1.0, size=(ROWS, COLS)).astype(np.float32) + + first = _pack_mask_b16(1) + second = _pack_mask_b16(127) + packed = np.concatenate([first, second]).astype(np.uint8, copy=False) + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + output_init.tofile(output_dir / "v3.bin") + packed.view(np.uint16).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op pstu-state-advance-boundary validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto new file mode 100644 index 000000000..69ab6c300 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/kernel.pto @@ -0,0 +1,41 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-state-advance-boundary +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pstu_state_advance_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c127 = arith.constant 127 : i32 + + %ub_mask = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %align0 = pto.init_align : !pto.align + %value0, %next0 = pto.plt_b16 %c1_i32 : i32 -> !pto.mask, i32 + %align1, %base1 = pto.pstu %align0, %value0, %ub_mask : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + %value1, %next1 = pto.plt_b16 %c127 : i32 -> !pto.mask, i32 + %align2, %base2 = pto.pstu %align1, %value1, %base1 : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + pto.vstas %align2, %base2, %c0_i32 : !pto.align, !pto.ptr, i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_mask, %arg2, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/launch.cpp new file mode 100644 index 000000000..2c01b6ceb --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/launch.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-state-advance-boundary +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pstu_state_advance_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ uint16_t *v3); + +void LaunchPstu_state_advance_kernel_2d(float *v1, float *v2, uint16_t *v3, + void *stream) { + pstu_state_advance_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2, (__gm__ uint16_t *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/main.cpp new file mode 100644 index 000000000..1f97fcc70 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu-state-advance-boundary/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu-state-advance-boundary +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, boundary, b16-mask, typed-ptr-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPstu_state_advance_kernel_2d(float *v1, float *v2, uint16_t *v3, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 16; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint16_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + uint16_t *v3Host = nullptr; + uint16_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPstu_state_advance_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/compare.py b/test/vpto/cases/micro-op/predicate-load-store/pstu/compare.py new file mode 100755 index 000000000..e452bd612 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements +# NOTE: bulk-generated coverage skeleton. + +import os +import sys +import numpy as np + +EXPECTED_WORDS = 8 +VALID_BYTES = 16 + + +def compare_packed_pred_mask(golden_path, output_path): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + expected_bytes = EXPECTED_WORDS * np.dtype(np.uint32).itemsize + if golden.size != expected_bytes or output.size != expected_bytes: + return False + if not np.array_equal(golden[:VALID_BYTES], output[:VALID_BYTES]): + diff = np.nonzero(golden[:VALID_BYTES] != output[:VALID_BYTES])[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] Mismatch (packed mask words): idx={idx} golden={int(golden[idx])} out={int(output[idx])}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_packed_pred_mask("golden_v3.bin", "v3.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/golden.py b/test/vpto/cases/micro-op/predicate-load-store/pstu/golden.py new file mode 100755 index 000000000..678185011 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/golden.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/predicate-load-store/pstu +# family: predicate-load-store +# target_ops: pto.pstu +# scenarios: unaligned-predicate-store, state-update, representative-logical-elements +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +PACKED_BYTES_PER_STORE = 8 +OUTPUT_WORDS = 8 + + +def _pack_mask_b32(active_lanes: int) -> np.ndarray: + if active_lanes < 0 or active_lanes > 64: + raise ValueError(f"active_lanes must be in [0, 64], got {active_lanes}") + logical = np.zeros((64,), dtype=np.uint8) + logical[:active_lanes] = 1 + packed = np.packbits(logical, bitorder="little") + out = np.zeros((PACKED_BYTES_PER_STORE,), dtype=np.uint8) + out[: packed.size] = packed + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-3.0, 3.0, size=(ROWS, COLS)).astype(np.float32) + v2 = rng.uniform(-1.0, 1.0, size=(ROWS, COLS)).astype(np.float32) + + first = _pack_mask_b32(13) + second = _pack_mask_b32(7) + packed = np.zeros((OUTPUT_WORDS * np.dtype(np.uint32).itemsize,), dtype=np.uint8) + packed[:PACKED_BYTES_PER_STORE] = first + packed[PACKED_BYTES_PER_STORE : 2 * PACKED_BYTES_PER_STORE] = second + output_init = np.zeros((OUTPUT_WORDS,), dtype=np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + output_init.tofile(output_dir / "v3.bin") + packed.view(np.uint32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op pstu validation." + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("."), + help="Directory where v1.bin/v2.bin/v3.bin/golden_v3.bin are written.", + ) + parser.add_argument( + "--seed", + type=int, + default=SEED, + help="Numpy random seed.", + ) + args = parser.parse_args() + + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto new file mode 100644 index 000000000..f879fec94 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/kernel.pto @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @pstu_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c0_i32 = arith.constant 0 : i32 + %c13 = arith.constant 13 : i32 + %c7 = arith.constant 7 : i32 + + %ub_mask = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.vecscope { + scf.for %iter = %c0 to %c1 step %c1 { + %align = pto.init_align : !pto.align + %value, %next = pto.plt_b32 %c13 : i32 -> !pto.mask, i32 + %align_out, %base_out = pto.pstu %align, %value, %ub_mask : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + %value_tail, %next_tail = pto.plt_b32 %c7 : i32 -> !pto.mask, i32 + %align_tail, %base_tail = pto.pstu %align_out, %value_tail, %base_out : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr + pto.vstas %align_tail, %base_tail, %c0_i32 : !pto.align, !pto.ptr, i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_mask, %arg2, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/launch.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu/launch.cpp new file mode 100644 index 000000000..977155180 --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/launch.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void pstu_kernel_2d(__gm__ float *v1, + __gm__ float *v2, + __gm__ uint32_t *v3); + +void LaunchPstu_kernel_2d(float *v1, float *v2, uint32_t *v3, void *stream) { + pstu_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2, + (__gm__ uint32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/predicate-load-store/pstu/main.cpp b/test/vpto/cases/micro-op/predicate-load-store/pstu/main.cpp new file mode 100644 index 000000000..5c31f323d --- /dev/null +++ b/test/vpto/cases/micro-op/predicate-load-store/pstu/main.cpp @@ -0,0 +1,110 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/predicate-load-store/pstu +// family: predicate-load-store +// target_ops: pto.pstu +// scenarios: unaligned-predicate-store, state-update, representative-logical-elements +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchPstu_kernel_2d(float *v1, float *v2, uint32_t *v3, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + size_t elemCount_v3 = 8; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchPstu_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/compare.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/compare.py new file mode 100755 index 000000000..f2a3e0459 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/golden.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/golden.py new file mode 100755 index 000000000..bc2746fab --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: paired vintlv+vdintlv roundtrip should recover the original input, including lane-boundary patterns. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + flat = rng.uniform(-8.0, 8.0, size=ROWS * COLS).astype(np.float32) + for base in range(0, flat.size, 128): + flat[base + 62 : base + 66] = np.array([-62.0, -1.0, 1.0, 62.0], dtype=np.float32) + v1 = flat.reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vintlv+vdintlv lane-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto new file mode 100644 index 000000000..370672fc3 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order, boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vintlv_vdintlv_boundary_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %rhs_offset = arith.addi %offset, %c64 : index + %lhs = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_in[%rhs_offset] : !pto.ptr -> !pto.vreg<64xf32> + %ilow, %ihigh = pto.vintlv %lhs, %rhs : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %dlow, %dhigh = pto.vdintlv %ilow, %ihigh : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vsts %dlow, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %dhigh, %ub_out[%rhs_offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/launch.cpp new file mode 100644 index 000000000..f7bb8bf5a --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vintlv_vdintlv_boundary_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVintlv_vdintlv_boundary_kernel_2d(float *v1, float *v2, void *stream) { + vintlv_vdintlv_boundary_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/main.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/main.cpp new file mode 100644 index 000000000..f6fb0606b --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv-lane-boundary/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv-lane-boundary +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVintlv_vdintlv_boundary_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVintlv_vdintlv_boundary_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/compare.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/compare.py new file mode 100755 index 000000000..afa3d870f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/golden.py b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/golden.py new file mode 100755 index 000000000..8878cae52 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vintlv-vdintlv +# family: rearrangement +# target_ops: pto.vdintlv, pto.vintlv +# scenarios: paired-roundtrip, lane-order +# NOTE: paired vintlv+vdintlv roundtrip should recover the original input. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vintlv+vdintlv validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto new file mode 100644 index 000000000..4dc0cc115 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vintlv_vdintlv_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %rhs_offset = arith.addi %offset, %c64 : index + %lhs = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %rhs = pto.vlds %ub_in[%rhs_offset] : !pto.ptr -> !pto.vreg<64xf32> + %ilow, %ihigh = pto.vintlv %lhs, %rhs : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %dlow, %dhigh = pto.vdintlv %ilow, %ihigh : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vsts %dlow, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %dhigh, %ub_out[%rhs_offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/launch.cpp new file mode 100644 index 000000000..27baf3164 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vintlv_vdintlv_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVintlv_vdintlv_kernel_2d(float *v1, float *v2, void *stream) { + vintlv_vdintlv_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/main.cpp b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/main.cpp new file mode 100644 index 000000000..0a66ddb10 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vintlv-vdintlv/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vintlv-vdintlv +// family: rearrangement +// target_ops: pto.vdintlv, pto.vintlv +// scenarios: paired-roundtrip, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVintlv_vdintlv_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVintlv_vdintlv_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/compare.py b/test/vpto/cases/micro-op/rearrangement/vpack-higher/compare.py new file mode 100644 index 000000000..c318abde7 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-higher +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, higher-half-placement, zero-fill-lower-half +# coding=utf-8 +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/golden.py b/test/vpto/cases/micro-op/rearrangement/vpack-higher/golden.py new file mode 100644 index 000000000..b97089067 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-higher +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, higher-half-placement, zero-fill-lower-half, post-pack-consumer +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ELEMS = ROWS * COLS +CHUNK = 64 +OUTPUT_ELEMS = ELEMS * 2 +SEED = 19 +BIAS = np.uint16(1) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(1 << 20), 1 << 20, size=ELEMS, dtype=np.int32) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + + narrowed = v1.astype(np.uint16, copy=False) + for chunk_base in range(0, ELEMS, CHUNK): + chunk = narrowed[chunk_base : chunk_base + CHUNK] + out_base = (chunk_base // CHUNK) * (CHUNK * 2) + golden_v2[out_base : out_base + CHUNK] = BIAS + golden_v2[out_base + CHUNK : out_base + 2 * CHUNK] = ( + chunk.astype(np.uint32) + int(BIAS) + ).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vpack-higher validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto new file mode 100644 index 000000000..cb86876bf --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-higher +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, higher-half-placement, zero-fill-lower-half, post-pack-consumer +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vpack_higher_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1_i16 = arith.constant 1 : i16 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %src_offset = %c0 to %c1024 step %c64 { + %dst_offset = arith.muli %src_offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<64xi32> + %packed = pto.vpack %vec, "HIGHER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> + %observed = pto.vadds %packed, %c1_i16, %store_mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %observed, %ub_out[%dst_offset], %store_mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-higher/launch.cpp new file mode 100644 index 000000000..3d5224385 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-higher +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, higher-half-placement, zero-fill-lower-half +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vpack_higher_kernel_2d(__gm__ int *v1, + __gm__ uint16_t *v2); + +void LaunchVpack_higher_kernel_2d(int32_t *v1, uint16_t *v2, void *stream) { + vpack_higher_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, + (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-higher/main.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-higher/main.cpp new file mode 100644 index 000000000..b28ee6e85 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-higher/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-higher +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, higher-half-placement, zero-fill-lower-half +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVpack_higher_kernel_2d(int32_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); + size_t elemCount_v2 = 2048; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVpack_higher_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/compare.py b/test/vpto/cases/micro-op/rearrangement/vpack-lower/compare.py new file mode 100644 index 000000000..0caf7195c --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-lower +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, lower-half-placement, zero-fill-upper-half +# coding=utf-8 +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/golden.py b/test/vpto/cases/micro-op/rearrangement/vpack-lower/golden.py new file mode 100644 index 000000000..37ca69a25 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/golden.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vpack-lower +# family: rearrangement +# target_ops: pto.vpack +# scenarios: narrowing, lower-half-placement, zero-fill-upper-half, post-pack-consumer +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ELEMS = ROWS * COLS +CHUNK = 64 +OUTPUT_ELEMS = ELEMS * 2 +SEED = 19 +BIAS = np.uint16(1) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(1 << 20), 1 << 20, size=ELEMS, dtype=np.int32) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint16) + + narrowed = v1.astype(np.uint16, copy=False) + for chunk_base in range(0, ELEMS, CHUNK): + chunk = narrowed[chunk_base : chunk_base + CHUNK] + out_base = (chunk_base // CHUNK) * (CHUNK * 2) + golden_v2[out_base : out_base + CHUNK] = ( + chunk.astype(np.uint32) + int(BIAS) + ).astype(np.uint16) + golden_v2[out_base + CHUNK : out_base + 2 * CHUNK] = BIAS + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vpack-lower validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto new file mode 100644 index 000000000..10294f0f7 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-lower +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, lower-half-placement, zero-fill-upper-half, post-pack-consumer +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vpack_lower_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1_i16 = arith.constant 1 : i16 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %src_offset = %c0 to %c1024 step %c64 { + %dst_offset = arith.muli %src_offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<64xi32> + %packed = pto.vpack %vec, "LOWER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> + %observed = pto.vadds %packed, %c1_i16, %store_mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %observed, %ub_out[%dst_offset], %store_mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-lower/launch.cpp new file mode 100644 index 000000000..3bf8b0da1 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-lower +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, lower-half-placement, zero-fill-upper-half +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vpack_lower_kernel_2d(__gm__ int *v1, + __gm__ uint16_t *v2); + +void LaunchVpack_lower_kernel_2d(int32_t *v1, uint16_t *v2, void *stream) { + vpack_lower_kernel_2d<<<1, nullptr, stream>>>((__gm__ int *)v1, + (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vpack-lower/main.cpp b/test/vpto/cases/micro-op/rearrangement/vpack-lower/main.cpp new file mode 100644 index 000000000..5cc58448c --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vpack-lower/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vpack-lower +// family: rearrangement +// target_ops: pto.vpack +// scenarios: narrowing, lower-half-placement, zero-fill-upper-half +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVpack_lower_kernel_2d(int32_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int32_t); + size_t elemCount_v2 = 2048; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVpack_lower_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/compare.py b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/compare.py new file mode 100755 index 000000000..4cca2c574 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/compare.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order, nontrivial-mask +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/golden.py b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/golden.py new file mode 100755 index 000000000..e7d456a40 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order, nontrivial-mask +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +BLOCKS = ROWS * COLS // LANES +ACTIVE_POSITIONS = [1, 4, 5, 9, 12, 16, 21, 24, 29, 33, 36, 40, 45, 49, 54, 60] +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + values = rng.uniform(-8.0, 8.0, size=(BLOCKS, LANES)).astype(np.float32) + mask_seed = np.full((BLOCKS, LANES), -1.0, dtype=np.float32) + golden = np.zeros((BLOCKS, LANES), dtype=np.float32) + + for block in range(BLOCKS): + for pos in ACTIVE_POSITIONS: + mask_seed[block, pos] = 1.0 + kept = values[block, ACTIVE_POSITIONS] + golden[block, :kept.size] = kept + + output_dir.mkdir(parents=True, exist_ok=True) + values.reshape(-1).tofile(output_dir / "v1.bin") + mask_seed.reshape(-1).tofile(output_dir / "v2.bin") + golden.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate nontrivial-mask inputs/golden for VPTO micro-op vsqz validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto new file mode 100644 index 000000000..b2e00d2df --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/kernel.pto @@ -0,0 +1,65 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order, nontrivial-mask +// ----------------------------------------------------------------------------- +// Validate nontrivial predicate-driven compaction: +// - arg0 provides input values. +// - arg1 provides a mask seed (positive => keep lane; non-positive => drop lane) +// and receives the compacted output. +// For each 64-lane chunk: +// 1. Build placement mask via vcmps(mask_seed > 0). +// 2. Run vsqz using that placement mask. +// 3. Store full compacted vector (kept lanes first, tail zeroed) back to UB. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsqz_nontrivial_mask_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %zero_f32 = arith.constant 0.0 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mask_seed = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_mask_seed, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c64 { + %store_mask, %unused = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %mask_seed = pto.vlds %ub_mask_seed[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %place = pto.vcmps %mask_seed, %zero_f32, %store_mask, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + %out = pto.vsqz %vec, %place : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/launch.cpp new file mode 100644 index 000000000..be43cc98f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsqz_nontrivial_mask_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsqzNontrivialMask_kernel_2d(float *v1, float *v2, void *stream) { + vsqz_nontrivial_mask_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/main.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/main.cpp new file mode 100644 index 000000000..8d467ea02 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz-nontrivial-mask/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsqzNontrivialMask_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsqzNontrivialMask_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/compare.py b/test/vpto/cases/micro-op/rearrangement/vsqz/compare.py new file mode 100755 index 000000000..f10e14e5f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/golden.py b/test/vpto/cases/micro-op/rearrangement/vsqz/golden.py new file mode 100755 index 000000000..5722d4362 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsqz +# family: rearrangement +# target_ops: pto.vsqz +# scenarios: predicate-driven-rearrangement, stable-order +# NOTE: full-mask compaction should preserve original lane order. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsqz validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto new file mode 100644 index 000000000..614d4ec87 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsqz_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz/launch.cpp new file mode 100644 index 000000000..511bb9f16 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsqz_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsqz_kernel_2d(float *v1, float *v2, void *stream) { + vsqz_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsqz/main.cpp b/test/vpto/cases/micro-op/rearrangement/vsqz/main.cpp new file mode 100644 index 000000000..ddb3c318b --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsqz/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsqz +// family: rearrangement +// target_ops: pto.vsqz +// scenarios: predicate-driven-rearrangement, stable-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsqz_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsqz_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/compare.py b/test/vpto/cases/micro-op/rearrangement/vsunpack/compare.py new file mode 100755 index 000000000..85f2fa8b2 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsunpack +# family: rearrangement +# target_ops: pto.vsunpack +# scenarios: pack-unpack, sign-extend +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.int32, 0.0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/golden.py b/test/vpto/cases/micro-op/rearrangement/vsunpack/golden.py new file mode 100755 index 000000000..8ceca3b5e --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vsunpack +# family: rearrangement +# target_ops: pto.vsunpack +# scenarios: pack-unpack, sign-extend +# NOTE: sign-extending unpack of the lower half of each 128-lane i16 chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 2048 +OUTPUT_ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(np.iinfo(np.int16).min, np.iinfo(np.int16).max + 1, size=INPUT_ELEMS, dtype=np.int16) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.int32) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.int32) + for src_base in range(0, INPUT_ELEMS, 128): + dst_base = (src_base // 128) * 64 + golden_v2[dst_base : dst_base + 64] = v1[src_base : src_base + 64].astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsunpack validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto new file mode 100644 index 000000000..bd98ba521 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/kernel.pto @@ -0,0 +1,72 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsunpack +// family: rearrangement +// target_ops: pto.vsunpack +// scenarios: pack-unpack, sign-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsunpack_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c2 = arith.constant 2 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %part = arith.constant 0 : index + + %gm_in = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %gm_in, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %src_offset = arith.muli %offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vsunpack %vec, %part : !pto.vreg<128xi16> -> !pto.vreg<64xi32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vsunpack/launch.cpp new file mode 100644 index 000000000..938b0ee13 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsunpack +// family: rearrangement +// target_ops: pto.vsunpack +// scenarios: pack-unpack, sign-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsunpack_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsunpack_kernel_2d(float *v1, float *v2, void *stream) { + vsunpack_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vsunpack/main.cpp b/test/vpto/cases/micro-op/rearrangement/vsunpack/main.cpp new file mode 100644 index 000000000..8b0e546ac --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vsunpack/main.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vsunpack +// family: rearrangement +// target_ops: pto.vsunpack +// scenarios: pack-unpack, sign-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsunpack_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 2048; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int32_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int32_t *v2Host = nullptr; + int32_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsunpack_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), + stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/compare.py b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/compare.py new file mode 100644 index 000000000..c5d68b8e4 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import sys +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.int32) + output = np.fromfile("v3.bin", dtype=np.int32) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + sys.exit(2) + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch at idx={idx}: golden={int(golden[idx])} out={int(output[idx])}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/golden.py b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/golden.py new file mode 100644 index 000000000..81fc36cbc --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/golden.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz-nontrivial-mask +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +BLOCKS = ROWS * COLS // LANES +ACTIVE_POSITIONS = [1, 4, 5, 9, 12, 16, 21, 24, 29, 33, 36, 40, 45, 49, 54, 60] +SEED = 19 + + +def build_case() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + src = np.zeros((BLOCKS, LANES), dtype=np.int32) + mask_seed = np.full((BLOCKS, LANES), -1.0, dtype=np.float32) + out = np.zeros((BLOCKS, LANES), dtype=np.int32) + + for block in range(BLOCKS): + src[block] = np.arange(block * 1000 + 7, block * 1000 + 7 + LANES, dtype=np.int32) + for pos in ACTIVE_POSITIONS: + mask_seed[block, pos] = 1.0 + active_count = 0 + out[block, 0] = 0 + for lane in range(1, LANES): + if mask_seed[block, lane - 1] > 0.0: + active_count += 1 + out[block, lane] = active_count + + return src.reshape(ROWS, COLS), mask_seed.reshape(ROWS, COLS), out.reshape(ROWS, COLS) + + +def generate(output_dir: Path) -> None: + src, mask_seed, out = build_case() + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mask_seed.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "golden_v3.bin") + np.zeros_like(out.reshape(-1)).tofile(output_dir / "v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate vusqz nontrivial prefix-count inputs/golden." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + del args.seed + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto new file mode 100644 index 000000000..440685449 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vusqz_nontrivial_mask_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %zero_f32 = arith.constant 0.0 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mask_seed = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_mask_seed, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c64 { + %store_mask, %unused = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %mask_seed = pto.vlds %ub_mask_seed[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %place = pto.vcmps %mask_seed, %zero_f32, %store_mask, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + %out = pto.vusqz %src, %place : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/launch.cpp new file mode 100644 index 000000000..e9edd85e3 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vusqz_nontrivial_mask_kernel_2d(__gm__ int32_t *v1, + __gm__ float *v2, + __gm__ int32_t *v3); + +void LaunchVusqz_nontrivial_mask_kernel_2d(int32_t *v1, + float *v2, + int32_t *v3, + void *stream) { + vusqz_nontrivial_mask_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ int32_t *)v1, (__gm__ float *)v2, (__gm__ int32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/main.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/main.cpp new file mode 100644 index 000000000..50190e7f3 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz-nontrivial-mask/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz-nontrivial-mask +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, \ + __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVusqz_nontrivial_mask_kernel_2d(int32_t *v1, + float *v2, + int32_t *v3, + void *stream); + +int main() { + constexpr size_t elemCount = 1024; + size_t fileSizeV1 = elemCount * sizeof(int32_t); + size_t fileSizeV2 = elemCount * sizeof(float); + size_t fileSizeV3 = elemCount * sizeof(int32_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int32_t *v3Host = nullptr; + int32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSizeV1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSizeV2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSizeV3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSizeV1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSizeV2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSizeV3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSizeV1, v1Host, fileSizeV1); + ReadFile("./v2.bin", fileSizeV2, v2Host, fileSizeV2); + std::fill_n(v3Host, elemCount, 0); + ACL_CHECK(aclrtMemcpy(v1Device, fileSizeV1, v1Host, fileSizeV1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSizeV2, v2Host, fileSizeV2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSizeV3, v3Host, fileSizeV3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVusqz_nontrivial_mask_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSizeV3, v3Device, fileSizeV3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSizeV3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + (void)aclrtDestroyStream(stream); + if (deviceSet) + (void)aclrtResetDevice(deviceId); + if (aclInited) + (void)aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/compare.py b/test/vpto/cases/micro-op/rearrangement/vusqz/compare.py new file mode 100644 index 000000000..6f3603aab --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import sys +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.int32) + output = np.fromfile("v3.bin", dtype=np.int32) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + sys.exit(2) + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch at idx={idx}: golden={int(golden[idx])} out={int(output[idx])}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/golden.py b/test/vpto/cases/micro-op/rearrangement/vusqz/golden.py new file mode 100644 index 000000000..94c38565a --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/golden.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vusqz +# family: rearrangement +# target_ops: pto.vusqz +# scenarios: predicate-driven-rearrangement, prefix-count + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +BLOCKS = ROWS * COLS // LANES +ACTIVE_PER_BLOCK = 16 +SEED = 19 + + +def build_case() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + src = np.zeros((BLOCKS, LANES), dtype=np.int32) + mask_seed = np.full((BLOCKS, LANES), -1.0, dtype=np.float32) + out = np.zeros((BLOCKS, LANES), dtype=np.int32) + + for block in range(BLOCKS): + src[block] = np.arange(block * 100 - 31, block * 100 - 31 + LANES, dtype=np.int32) + mask_seed[block, :ACTIVE_PER_BLOCK] = 1.0 + active_count = 0 + out[block, 0] = 0 + for lane in range(1, LANES): + if mask_seed[block, lane - 1] > 0.0: + active_count += 1 + out[block, lane] = active_count + + return src.reshape(ROWS, COLS), mask_seed.reshape(ROWS, COLS), out.reshape(ROWS, COLS) + + +def generate(output_dir: Path) -> None: + src, mask_seed, out = build_case() + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mask_seed.reshape(-1).tofile(output_dir / "v2.bin") + out.reshape(-1).tofile(output_dir / "golden_v3.bin") + np.zeros_like(out.reshape(-1)).tofile(output_dir / "v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate vusqz prefix-count inputs/golden.") + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + del args.seed + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto new file mode 100644 index 000000000..7fa0e91db --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/kernel.pto @@ -0,0 +1,59 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vusqz_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %zero_f32 = arith.constant 0.0 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mask_seed = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_src, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_mask_seed, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c64 { + %store_mask, %unused = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %src = pto.vlds %ub_src[%offset] : !pto.ptr -> !pto.vreg<64xi32> + %mask_seed = pto.vlds %ub_mask_seed[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %place = pto.vcmps %mask_seed, %zero_f32, %store_mask, "gt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + %out = pto.vusqz %src, %place : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xi32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz/launch.cpp new file mode 100644 index 000000000..3684fe9b2 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vusqz_kernel_2d(__gm__ int32_t *v1, + __gm__ float *v2, + __gm__ int32_t *v3); + +void LaunchVusqz_kernel_2d(int32_t *v1, float *v2, int32_t *v3, void *stream) { + vusqz_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ int32_t *)v1, (__gm__ float *)v2, (__gm__ int32_t *)v3); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vusqz/main.cpp b/test/vpto/cases/micro-op/rearrangement/vusqz/main.cpp new file mode 100644 index 000000000..9da958163 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vusqz/main.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vusqz +// family: rearrangement +// target_ops: pto.vusqz +// scenarios: predicate-driven-rearrangement, prefix-count +// ----------------------------------------------------------------------------- +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, \ + __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVusqz_kernel_2d(int32_t *v1, float *v2, int32_t *v3, void *stream); + +int main() { + constexpr size_t elemCount = 1024; + size_t fileSizeV1 = elemCount * sizeof(int32_t); + size_t fileSizeV2 = elemCount * sizeof(float); + size_t fileSizeV3 = elemCount * sizeof(int32_t); + int32_t *v1Host = nullptr; + int32_t *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int32_t *v3Host = nullptr; + int32_t *v3Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSizeV1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSizeV2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSizeV3)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSizeV1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSizeV2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSizeV3, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSizeV1, v1Host, fileSizeV1); + ReadFile("./v2.bin", fileSizeV2, v2Host, fileSizeV2); + std::fill_n(v3Host, elemCount, 0); + ACL_CHECK(aclrtMemcpy(v1Device, fileSizeV1, v1Host, fileSizeV1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSizeV2, v2Host, fileSizeV2, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSizeV3, v3Host, fileSizeV3, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVusqz_kernel_2d(v1Device, v2Device, v3Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSizeV3, v3Device, fileSizeV3, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSizeV3); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + if (stream != nullptr) + (void)aclrtDestroyStream(stream); + if (deviceSet) + (void)aclrtResetDevice(deviceId); + if (aclInited) + (void)aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/compare.py b/test/vpto/cases/micro-op/rearrangement/vzunpack/compare.py new file mode 100755 index 000000000..0dc97cc35 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vzunpack +# family: rearrangement +# target_ops: pto.vzunpack +# scenarios: pack-unpack, zero-extend +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint32, 0.0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/golden.py b/test/vpto/cases/micro-op/rearrangement/vzunpack/golden.py new file mode 100755 index 000000000..e6014e397 --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/rearrangement/vzunpack +# family: rearrangement +# target_ops: pto.vzunpack +# scenarios: pack-unpack, zero-extend +# NOTE: zero-extending unpack of the lower half of each 128-lane ui16 chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 2048 +OUTPUT_ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, np.iinfo(np.uint16).max + 1, size=INPUT_ELEMS, dtype=np.uint16) + v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint32) + golden_v2 = np.zeros(OUTPUT_ELEMS, dtype=np.uint32) + for src_base in range(0, INPUT_ELEMS, 128): + dst_base = (src_base // 128) * 64 + golden_v2[dst_base : dst_base + 64] = v1[src_base : src_base + 64].astype(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vzunpack validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto new file mode 100644 index 000000000..47870c2ec --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/kernel.pto @@ -0,0 +1,72 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vzunpack +// family: rearrangement +// target_ops: pto.vzunpack +// scenarios: pack-unpack, zero-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vzunpack_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c2 = arith.constant 2 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %part = arith.constant 0 : index + + %gm_in = pto.castptr %arg0 : !pto.ptr -> !pto.ptr + %gm_out = pto.castptr %arg1 : !pto.ptr -> !pto.ptr + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %gm_in, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %store_mask = pto.pset_b32 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c64 { + %src_offset = arith.muli %offset, %c2 : index + %vec = pto.vlds %ub_in[%src_offset] : !pto.ptr -> !pto.vreg<128xui16> + %out = pto.vzunpack %vec, %part : !pto.vreg<128xui16> -> !pto.vreg<64xui32> + pto.vsts %out, %ub_out[%offset], %store_mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %gm_out, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/launch.cpp b/test/vpto/cases/micro-op/rearrangement/vzunpack/launch.cpp new file mode 100644 index 000000000..7fa2a6c4b --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vzunpack +// family: rearrangement +// target_ops: pto.vzunpack +// scenarios: pack-unpack, zero-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vzunpack_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVzunpack_kernel_2d(float *v1, float *v2, void *stream) { + vzunpack_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/rearrangement/vzunpack/main.cpp b/test/vpto/cases/micro-op/rearrangement/vzunpack/main.cpp new file mode 100644 index 000000000..e3693855f --- /dev/null +++ b/test/vpto/cases/micro-op/rearrangement/vzunpack/main.cpp @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/rearrangement/vzunpack +// family: rearrangement +// target_ops: pto.vzunpack +// scenarios: pack-unpack, zero-extend +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVzunpack_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 2048; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVzunpack_kernel_2d(reinterpret_cast(v1Device), + reinterpret_cast(v2Device), + stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/compare.py b/test/vpto/cases/micro-op/reduction/vcadd-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/golden.py b/test/vpto/cases/micro-op/reduction/vcadd-tail/golden.py new file mode 100644 index 000000000..9ea041d65 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +LOGICAL_ELEMS = 1000 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, LOGICAL_ELEMS, LANES): + chunk = flat_in[offset:min(offset + LANES, LOGICAL_ELEMS)] + flat_out[offset] = np.sum(chunk, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto new file mode 100644 index 000000000..ee7ad51fe --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/launch.cpp b/test/vpto/cases/micro-op/reduction/vcadd-tail/launch.cpp new file mode 100644 index 000000000..494bc5bf3 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream) { + vabs_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd-tail/main.cpp b/test/vpto/cases/micro-op/reduction/vcadd-tail/main.cpp new file mode 100644 index 000000000..cf25e5dff --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd-tail/main.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd/compare.py b/test/vpto/cases/micro-op/reduction/vcadd/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd/golden.py b/test/vpto/cases/micro-op/reduction/vcadd/golden.py new file mode 100644 index 000000000..906f71e66 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + flat_out[offset] = np.sum(chunk, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto new file mode 100644 index 000000000..a660a8be0 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd/launch.cpp b/test/vpto/cases/micro-op/reduction/vcadd/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcadd/main.cpp b/test/vpto/cases/micro-op/reduction/vcadd/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcadd/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/compare.py b/test/vpto/cases/micro-op/reduction/vcgadd-tail/compare.py new file mode 100755 index 000000000..fb57e856a --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/compare.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgadd-tail +# family: reduction +# target_ops: pto.vcgadd +# scenarios: group-reduction, tail-mask, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 1000 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, + LOGICAL_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/golden.py b/test/vpto/cases/micro-op/reduction/vcgadd-tail/golden.py new file mode 100755 index 000000000..282927eff --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +LOGICAL_ELEMS = 1000 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, LOGICAL_ELEMS, LANES): + chunk = flat_in[offset:min(offset + LANES, LOGICAL_ELEMS)] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGADD writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.sum(chunk[group:group + group_elems], dtype=np.float32) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto new file mode 100644 index 000000000..34122605e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd-tail +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, tail-mask, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcgadd_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgadd-tail/launch.cpp new file mode 100644 index 000000000..e35c2b363 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd-tail +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, tail-mask, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgaddTail_kernel_2d(float *v1, float *v2, void *stream) { + vcgadd_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd-tail/main.cpp b/test/vpto/cases/micro-op/reduction/vcgadd-tail/main.cpp new file mode 100644 index 000000000..29bdc23bf --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd-tail +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, tail-mask, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgaddTail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgaddTail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/compare.py b/test/vpto/cases/micro-op/reduction/vcgadd/compare.py new file mode 100755 index 000000000..2c8e5f087 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgadd +# family: reduction +# target_ops: pto.vcgadd +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/golden.py b/test/vpto/cases/micro-op/reduction/vcgadd/golden.py new file mode 100755 index 000000000..efa021477 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGADD writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.sum(chunk[group:group + group_elems], dtype=np.float32) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto new file mode 100644 index 000000000..72c89af37 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcgadd_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgadd/launch.cpp new file mode 100644 index 000000000..16a1993e8 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgadd_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgadd_kernel_2d(float *v1, float *v2, void *stream) { + vcgadd_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgadd/main.cpp b/test/vpto/cases/micro-op/reduction/vcgadd/main.cpp new file mode 100644 index 000000000..712f0755a --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgadd/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgadd +// family: reduction +// target_ops: pto.vcgadd +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgadd_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgadd_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/compare.py b/test/vpto/cases/micro-op/reduction/vcgmax-tie/compare.py new file mode 100755 index 000000000..a4a5c50c3 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmax-tie +# family: reduction +# target_ops: pto.vcgmax +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/golden.py b/test/vpto/cases/micro-op/reduction/vcgmax-tie/golden.py new file mode 100755 index 000000000..a4d414312 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat_seed = v1.reshape(-1) + for offset in range(0, flat_seed.size, LANES): + for group in range(0, LANES, 8): + base = offset + group + flat_seed[base:base + 8] = np.array([7.0, 7.0, -3.0, 1.0, 0.5, -2.0, 4.0, 6.0], dtype=np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMAX writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.max(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto new file mode 100644 index 000000000..cffd43031 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax-tie +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcgmax_tie_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmax %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmax-tie/launch.cpp new file mode 100644 index 000000000..35e5a63b3 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax-tie +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmax_tie_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgmaxTie_kernel_2d(float *v1, float *v2, void *stream) { + vcgmax_tie_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax-tie/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmax-tie/main.cpp new file mode 100644 index 000000000..79ff13b2e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax-tie/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax-tie +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgmaxTie_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgmaxTie_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/compare.py b/test/vpto/cases/micro-op/reduction/vcgmax/compare.py new file mode 100755 index 000000000..f1f037986 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmax +# family: reduction +# target_ops: pto.vcgmax +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/golden.py b/test/vpto/cases/micro-op/reduction/vcgmax/golden.py new file mode 100755 index 000000000..d807ff1e0 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMAX writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.max(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto new file mode 100644 index 000000000..12f289720 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcgmax_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmax %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmax/launch.cpp new file mode 100644 index 000000000..33855f496 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmax_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgmax_kernel_2d(float *v1, float *v2, void *stream) { + vcgmax_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmax/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmax/main.cpp new file mode 100644 index 000000000..f51aa0ebe --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmax/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmax +// family: reduction +// target_ops: pto.vcgmax +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgmax_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgmax_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/compare.py b/test/vpto/cases/micro-op/reduction/vcgmin-tie/compare.py new file mode 100755 index 000000000..05b8ee45c --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmin-tie +# family: reduction +# target_ops: pto.vcgmin +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/golden.py b/test/vpto/cases/micro-op/reduction/vcgmin-tie/golden.py new file mode 100755 index 000000000..62a18cd0d --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/golden.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat_seed = v1.reshape(-1) + for offset in range(0, flat_seed.size, LANES): + for group in range(0, LANES, 8): + base = offset + group + flat_seed[base:base + 8] = np.array([-7.0, -7.0, 3.0, -1.0, 0.5, 2.0, -4.0, -6.0], dtype=np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMIN writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.min(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto new file mode 100644 index 000000000..af0430119 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin-tie +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcgmin_tie_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmin %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmin-tie/launch.cpp new file mode 100644 index 000000000..35f95d660 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin-tie +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmin_tie_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgminTie_kernel_2d(float *v1, float *v2, void *stream) { + vcgmin_tie_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin-tie/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmin-tie/main.cpp new file mode 100644 index 000000000..3a940457b --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin-tie/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin-tie +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgminTie_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgminTie_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/compare.py b/test/vpto/cases/micro-op/reduction/vcgmin/compare.py new file mode 100755 index 000000000..57ac3a528 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcgmin +# family: reduction +# target_ops: pto.vcgmin +# scenarios: group-reduction, result-placement +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/golden.py b/test/vpto/cases/micro-op/reduction/vcgmin/golden.py new file mode 100755 index 000000000..5f2413af5 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + group_elems = 8 + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + for gi, group in enumerate(range(0, chunk.size, group_elems)): + # VCGMIN writes one reduced value per 32B block continuously to low lanes. + flat_out[offset + gi] = np.min(chunk[group:group + group_elems]) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto new file mode 100644 index 000000000..6d7c9d14c --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcgmin_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcgmin %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/launch.cpp b/test/vpto/cases/micro-op/reduction/vcgmin/launch.cpp new file mode 100644 index 000000000..b6787415c --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcgmin_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcgmin_kernel_2d(float *v1, float *v2, void *stream) { + vcgmin_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcgmin/main.cpp b/test/vpto/cases/micro-op/reduction/vcgmin/main.cpp new file mode 100644 index 000000000..1c4fc7676 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcgmin/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcgmin +// family: reduction +// target_ops: pto.vcgmin +// scenarios: group-reduction, result-placement +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcgmin_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcgmin_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcmax/compare.py b/test/vpto/cases/micro-op/reduction/vcmax/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmax/golden.py b/test/vpto/cases/micro-op/reduction/vcmax/golden.py new file mode 100644 index 000000000..739d372e6 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out_u32 = flat_out.view(np.uint32) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + idx = int(np.argmax(chunk)) + flat_out[offset] = chunk[idx] + flat_out_u32[offset + 1] = np.uint32(idx) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto new file mode 100644 index 000000000..93aced72d --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcmax %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcmax/launch.cpp b/test/vpto/cases/micro-op/reduction/vcmax/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcmax/main.cpp b/test/vpto/cases/micro-op/reduction/vcmax/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmax/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcmin/compare.py b/test/vpto/cases/micro-op/reduction/vcmin/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmin/golden.py b/test/vpto/cases/micro-op/reduction/vcmin/golden.py new file mode 100644 index 000000000..bbbfe8d57 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out_u32 = flat_out.view(np.uint32) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + idx = int(np.argmin(chunk)) + flat_out[offset] = chunk[idx] + flat_out_u32[offset + 1] = np.uint32(idx) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto new file mode 100644 index 000000000..9d0e34332 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcmin %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcmin/launch.cpp b/test/vpto/cases/micro-op/reduction/vcmin/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcmin/main.cpp b/test/vpto/cases/micro-op/reduction/vcmin/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcmin/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/compare.py b/test/vpto/cases/micro-op/reduction/vcpadd-tail/compare.py new file mode 100755 index 000000000..59d97b87c --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/compare.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcpadd-tail +# family: reduction +# target_ops: pto.vcpadd +# scenarios: prefix-op, tail-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 1000 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, + LOGICAL_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/golden.py b/test/vpto/cases/micro-op/reduction/vcpadd-tail/golden.py new file mode 100755 index 000000000..08dc83922 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +LOGICAL_ELEMS = 1000 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, LOGICAL_ELEMS, LANES): + chunk = flat_in[offset:min(offset + LANES, LOGICAL_ELEMS)] + pair_count = (chunk.size + 1) // 2 + for i in range(pair_count): + a = chunk[2 * i] + b = chunk[2 * i + 1] if (2 * i + 1) < chunk.size else np.float32(0.0) + # VCPADD writes pair-reduction results to low half lanes. + flat_out[offset + i] = np.float32(a + b) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto new file mode 100644 index 000000000..84f25e634 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd-tail +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcpadd_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcpadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/launch.cpp b/test/vpto/cases/micro-op/reduction/vcpadd-tail/launch.cpp new file mode 100644 index 000000000..08c0b9ad5 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd-tail +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcpadd_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcpaddTail_kernel_2d(float *v1, float *v2, void *stream) { + vcpadd_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd-tail/main.cpp b/test/vpto/cases/micro-op/reduction/vcpadd-tail/main.cpp new file mode 100644 index 000000000..d571471dc --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd-tail +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, tail-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcpaddTail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcpaddTail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/compare.py b/test/vpto/cases/micro-op/reduction/vcpadd/compare.py new file mode 100755 index 000000000..8094ed94e --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/reduction/vcpadd +# family: reduction +# target_ops: pto.vcpadd +# scenarios: prefix-op, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/golden.py b/test/vpto/cases/micro-op/reduction/vcpadd/golden.py new file mode 100755 index 000000000..eb41c69f0 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LANES = 64 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + for offset in range(0, flat_in.size, LANES): + chunk = flat_in[offset:offset + LANES] + pair_count = (chunk.size + 1) // 2 + for i in range(pair_count): + a = chunk[2 * i] + b = chunk[2 * i + 1] if (2 * i + 1) < chunk.size else np.float32(0.0) + # VCPADD writes pair-reduction results to low half lanes. + flat_out[offset + i] = np.float32(a + b) + + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto new file mode 100644 index 000000000..3b4bf2f9d --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/kernel.pto @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vcpadd_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vcpadd %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/launch.cpp b/test/vpto/cases/micro-op/reduction/vcpadd/launch.cpp new file mode 100644 index 000000000..ad26d59b2 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vcpadd_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVcpadd_kernel_2d(float *v1, float *v2, void *stream) { + vcpadd_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/reduction/vcpadd/main.cpp b/test/vpto/cases/micro-op/reduction/vcpadd/main.cpp new file mode 100644 index 000000000..7f62d2606 --- /dev/null +++ b/test/vpto/cases/micro-op/reduction/vcpadd/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/reduction/vcpadd +// family: reduction +// target_ops: pto.vcpadd +// scenarios: prefix-op, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVcpadd_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVcpadd_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/compare.py b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/compare.py new file mode 100644 index 000000000..87edcafae --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/golden.py b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/golden.py new file mode 100644 index 000000000..a4fb8dd28 --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-20000, 20000, size=ELEMS, dtype=np.int16) + v2 = np.zeros(ELEMS, dtype=np.int16) + golden_v2 = (v1.astype(np.int32) + 4).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO UB scalar load/store validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto new file mode 100644 index 000000000..5957546cc --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/scalar-load-store/load-store-scalar-ub +// family: scalar-load-store +// target_ops: pto.load_scalar, pto.store_scalar +// scenarios: core-i16, ub-roundtrip, scalar-rw +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @load_store_scalar_ub_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.aicore} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c7_i16 = arith.constant 7 : i16 + %c3_i16 = arith.constant 3 : i16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c1024 step %c1 { + %loaded = pto.load_scalar %ub_in[%offset] : !pto.ptr -> i16 + %biased = arith.addi %loaded, %c7_i16 : i16 + pto.store_scalar %biased, %ub_out[%offset] : !pto.ptr, i16 + %echo = pto.load_scalar %ub_out[%offset] : !pto.ptr -> i16 + %result = arith.subi %echo, %c3_i16 : i16 + pto.store_scalar %result, %ub_out[%offset] : !pto.ptr, i16 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/launch.cpp b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/launch.cpp new file mode 100644 index 000000000..dbfc19cfa --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void load_store_scalar_ub_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2); + +void LaunchLoad_store_scalar_ub_kernel(int16_t *v1, int16_t *v2, void *stream) { + load_store_scalar_ub_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/main.cpp b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/main.cpp new file mode 100644 index 000000000..ff7a3e98b --- /dev/null +++ b/test/vpto/cases/micro-op/scalar-load-store/load-store-scalar-ub/main.cpp @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchLoad_store_scalar_ub_kernel(int16_t *v1, int16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v2Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v1, v2Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v1, v2Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchLoad_store_scalar_ub_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v1, v2Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v1); + +cleanup: + aclrtFree(v2Device); + aclrtFree(v1Device); + aclrtFreeHost(v2Host); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/simt/simt-store-tid/compare.py b/test/vpto/cases/micro-op/simt/simt-store-tid/compare.py new file mode 100644 index 000000000..99a157ea3 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/simt-store-tid/compare.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + golden = np.fromfile("golden_v1.bin", dtype=np.int32) + out = np.fromfile("v1.bin", dtype=np.int32) + ok = golden.shape == out.shape and np.array_equal(golden, out) + if not ok: + idxs = np.nonzero(golden != out)[0] + idx = int(idxs[0]) if idxs.size else 0 + print( + f"[ERROR] mismatch at idx={idx}, golden={int(golden[idx])}, out={int(out[idx])}" + ) + if strict: + sys.exit(2) + print("[INFO] compare passed" if ok else "[WARN] compare failed (non-gating)") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/simt-store-tid/golden.py b/test/vpto/cases/micro-op/simt/simt-store-tid/golden.py new file mode 100644 index 000000000..7726b14ba --- /dev/null +++ b/test/vpto/cases/micro-op/simt/simt-store-tid/golden.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + v1 = np.full(ELEMS, -1, dtype=np.int32) + xs = np.arange(32, dtype=np.int32) + ys = np.arange(32, dtype=np.int32)[:, None] + golden_v1 = (xs[None, :] | (ys << 8)).reshape(ELEMS) + v1.tofile(output_dir / "v1.bin") + golden_v1.tofile(output_dir / "golden_v1.bin") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/simt/simt-store-tid/kernel.pto b/test/vpto/cases/micro-op/simt/simt-store-tid/kernel.pto new file mode 100644 index 000000000..f2766cadf --- /dev/null +++ b/test/vpto/cases/micro-op/simt/simt-store-tid/kernel.pto @@ -0,0 +1,41 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @simt_store_tid_kernel(%arg0: !pto.ptr) attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %dim_z = arith.constant 1 : i32 + %dim_y = arith.constant 32 : i32 + %dim_x = arith.constant 32 : i32 + + %ub_out = pto.castptr %c0_i64 : i64 -> !pto.ptr + + pto.store_vfsimt_info %dim_z, %dim_y, %dim_x : i32, i32, i32 + func.call @simt_write(%ub_out) : (!pto.ptr) -> () + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg0, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @simt_write(%dst: !pto.ptr) attributes {pto.simt_entry} { + %tx = pto.get_tid_x : i32 + %ty = pto.get_tid_y : i32 + %tz = pto.get_tid_z : i32 + %c8_i32 = arith.constant 8 : i32 + %c16_i32 = arith.constant 16 : i32 + %c32_i32 = arith.constant 32 : i32 + %ty_shift = arith.shli %ty, %c8_i32 : i32 + %tz_shift = arith.shli %tz, %c16_i32 : i32 + %xy = arith.ori %tx, %ty_shift : i32 + %xyz = arith.ori %xy, %tz_shift : i32 + %lane_base = arith.muli %ty, %c32_i32 : i32 + %tid = arith.addi %lane_base, %tx : i32 + %tid_idx = arith.index_castui %tid : i32 to index + pto.store %xyz, %dst[%tid_idx] : !pto.ptr, i32 + return + } +} diff --git a/test/vpto/cases/micro-op/simt/simt-store-tid/launch.cpp b/test/vpto/cases/micro-op/simt/simt-store-tid/launch.cpp new file mode 100644 index 000000000..130e4a4fc --- /dev/null +++ b/test/vpto/cases/micro-op/simt/simt-store-tid/launch.cpp @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif +extern "C" __global__ [aicore] void simt_store_tid_kernel(__gm__ int *v1); +void LaunchSimt_store_tid_kernel(int *v1, void *stream) { + simt_store_tid_kernel<<<1, nullptr, stream>>>((__gm__ int *)v1); +} diff --git a/test/vpto/cases/micro-op/simt/simt-store-tid/main.cpp b/test/vpto/cases/micro-op/simt/simt-store-tid/main.cpp new file mode 100644 index 000000000..f1cc9b6c4 --- /dev/null +++ b/test/vpto/cases/micro-op/simt/simt-store-tid/main.cpp @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) do { const aclError _ret = (expr); if (_ret != ACL_SUCCESS) { std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); rc = 1; goto cleanup; } } while (0) + +void LaunchSimt_store_tid_kernel(int *v1, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int); + int *v1Host = nullptr; + int *v1Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchSimt_store_tid_kernel(v1Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/compare.py b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/compare.py new file mode 100644 index 000000000..561c706e1 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/compare.py @@ -0,0 +1,54 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden[idx])}, out={int(output[idx])}, dtype={dtype_np})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v1.bin", "v1.bin", np.int64) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/golden.py b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/golden.py new file mode 100644 index 000000000..785d7bc11 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +SEED = 19 +VALUES = np.full(64, -1, dtype=np.int64) +VALUES[0] = 0 +VALUES[1] = 0 +VALUES[2] = 2 +VALUES[3] = 1 +VALUES[32] = 1 +VALUES[33] = 0 +VALUES[34] = 2 +VALUES[35] = 1 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + v1 = np.full(VALUES.shape, -1, dtype=np.int64) + golden_v1 = VALUES.copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + golden_v1.tofile(output_dir / "golden_v1.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO runtime query validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/kernel.pto b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/kernel.pto new file mode 100644 index 000000000..2f83b8d49 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/kernel.pto @@ -0,0 +1,35 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/system-runtime-query/get-block-subblock-id +// family: system-runtime-query +// target_ops: pto.get_block_idx, pto.get_subblock_idx, pto.get_block_num, +// pto.get_subblock_num, pto.store_scalar +// scenarios: multi-block, runtime-query +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @get_block_subblock_id_kernel(%arg0: !pto.ptr) attributes {pto.kernel} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c32 = arith.constant 32 : index + + pto.vecscope { + } + + %block = pto.get_block_idx + %subblock = pto.get_subblock_idx + %block_num = pto.get_block_num + %subblock_num = pto.get_subblock_num + + %block_idx = arith.index_cast %block : i64 to index + %slot_base = arith.muli %block_idx, %c32 : index + + pto.store_scalar %block, %arg0[%slot_base] : !pto.ptr, i64 + %slot_1 = arith.addi %slot_base, %c1 : index + pto.store_scalar %subblock, %arg0[%slot_1] : !pto.ptr, i64 + %slot_2 = arith.addi %slot_base, %c2 : index + pto.store_scalar %block_num, %arg0[%slot_2] : !pto.ptr, i64 + %slot_3 = arith.addi %slot_base, %c3 : index + pto.store_scalar %subblock_num, %arg0[%slot_3] : !pto.ptr, i64 + return + } +} diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/launch.cpp b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/launch.cpp new file mode 100644 index 000000000..c3d8bcdbe --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void get_block_subblock_id_kernel(__gm__ int64_t *v1); + +void LaunchGet_block_subblock_id_kernel(int64_t *v1, void *stream) { + get_block_subblock_id_kernel<<<2, nullptr, stream>>>((__gm__ int64_t *)v1); +} diff --git a/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/main.cpp b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/main.cpp new file mode 100644 index 000000000..f9ed27342 --- /dev/null +++ b/test/vpto/cases/micro-op/system-runtime-query/get-block-subblock-id/main.cpp @@ -0,0 +1,111 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchGet_block_subblock_id_kernel(int64_t *v1, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(int64_t); + int64_t *v1Host = nullptr; + int64_t *v1Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchGet_block_subblock_id_kernel(v1Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v1Host, fileSize_v1, v1Device, fileSize_v1, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v1.bin", v1Host, fileSize_v1); + +cleanup: + aclrtFree(v1Device); + aclrtFreeHost(v1Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-f16/compare.py new file mode 100755 index 000000000..77d269686 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-f16 +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-f16/golden.py new file mode 100755 index 000000000..b90b097ce --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-f16 +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-f16, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto new file mode 100644 index 000000000..ae2af0a2e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-f16 +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f16/launch.cpp new file mode 100644 index 000000000..58cdd948a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-f16 +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f16/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f16/main.cpp new file mode 100644 index 000000000..76001407d --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-f16 +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/golden.py new file mode 100644 index 000000000..95f77e83a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -7.5, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto new file mode 100644 index 000000000..e6eb1c661 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/launch.cpp new file mode 100644 index 000000000..806579491 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vabs_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/main.cpp new file mode 100644 index 000000000..b3312f7e2 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-f32-exceptional/main.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/compare.py new file mode 100644 index 000000000..672b2df43 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/golden.py new file mode 100644 index 000000000..e8562fdac --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + data = rng.integers(-30000, 30000, size=ELEMS, dtype=np.int16) + edge = np.array( + [-32768, -32767, -12345, -1, 0, 1, 12345, 32767, + -32768, -2, 2, -32766, 32766, -1024, 1024, -17], + dtype=np.int16, + ) + data[:edge.size] = edge + golden = np.abs(data).astype(np.int16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + data.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto new file mode 100644 index 000000000..199d3ebc6 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_i16_signed_overflow_edge_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vabs %vec, %mask : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/launch.cpp new file mode 100644 index 000000000..be3498f7e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/launch.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_i16_signed_overflow_edge_kernel( + __gm__ int16_t *v1, __gm__ int16_t *v2); + +void LaunchVabs_i16_signed_overflow_edge_kernel(int16_t *v1, int16_t *v2, + void *stream) { + vabs_i16_signed_overflow_edge_kernel<<<1, nullptr, stream>>>( + (__gm__ int16_t *)v1, (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/main.cpp new file mode 100644 index 000000000..55de29f79 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed-overflow-edge/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed-overflow-edge +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_i16_signed_overflow_edge_kernel(int16_t *v1, int16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_i16_signed_overflow_edge_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/compare.py new file mode 100755 index 000000000..eca2ddc70 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/golden.py new file mode 100755 index 000000000..ae05da408 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-signed +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-signed, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto new file mode 100644 index 000000000..79d01c359 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/launch.cpp new file mode 100644 index 000000000..1d4dc5556 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/main.cpp new file mode 100644 index 000000000..565d8c357 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-signed/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-signed +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/compare.py new file mode 100755 index 000000000..a6d5c46f7 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-unsigned +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/golden.py new file mode 100755 index 000000000..6f75723de --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vabs-i16-unsigned +# family: unary-vector +# target_ops: pto.vabs +# scenarios: core-i16-unsigned, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto new file mode 100644 index 000000000..d68983556 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-unsigned +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/launch.cpp new file mode 100644 index 000000000..8ae1a4350 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-unsigned +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/main.cpp new file mode 100644 index 000000000..d54791913 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-i16-unsigned/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vabs-i16-unsigned +// family: unary-vector +// target_ops: pto.vabs +// scenarios: core-i16-unsigned, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/compare.py new file mode 100644 index 000000000..6098dd82c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/compare.py @@ -0,0 +1,198 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +ACTIVE_ELEMS = 1000 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, ACTIVE_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/golden.py new file mode 100644 index 000000000..7448a6a1c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/golden.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE_ELEMS = 1000 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + edge_values = np.array( + [ + -0.0, + 0.0, + -1.0, + 1.0, + -8.0, + 8.0, + -1.0e-30, + 1.0e-30, + -1.0e10, + 1.0e10, + -3.5, + 3.5, + -7.25, + 7.25, + -2.0, + 2.0, + ], + dtype=np.float32, + ) + v1.reshape(-1)[: edge_values.size] = edge_values + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_v1 = v1.reshape(-1) + flat_golden_v2 = golden_v2.reshape(-1) + flat_golden_v2[:ACTIVE_ELEMS] = np.abs(flat_v1[:ACTIVE_ELEMS]).astype( + np.float32, copy=False + ) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs loop-carried vreg validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto new file mode 100644 index 000000000..476939d7a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/kernel.pto @@ -0,0 +1,50 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_loop_carried_vreg_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1000 = arith.constant 1000 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1000 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = scf.for %iter = %c0 to %c2 step %c1 + iter_args(%carry = %vec) -> (!pto.vreg<64xf32>) { + %abs = pto.vabs %carry, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + scf.yield %abs : !pto.vreg<64xf32> + } + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/launch.cpp new file mode 100644 index 000000000..2663b7625 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/launch.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vabs_loop_carried_vreg_kernel_2d(__gm__ float *v1, __gm__ float *v2); + +void LaunchVabs_loop_carried_vreg_kernel_2d(float *v1, float *v2, void *stream) { + vabs_loop_carried_vreg_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/main.cpp new file mode 100644 index 000000000..4d4bd221b --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-loop-carried-vreg/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_loop_carried_vreg_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_loop_carried_vreg_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs-tail/golden.py new file mode 100644 index 000000000..03fd9a768 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.abs( + v1.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto new file mode 100644 index 000000000..00f93abc2 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-tail/launch.cpp new file mode 100644 index 000000000..494bc5bf3 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream) { + vabs_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs-tail/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs-tail/main.cpp new file mode 100644 index 000000000..cf25e5dff --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs-tail/main.cpp @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_tail_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/compare.py b/test/vpto/cases/micro-op/unary-vector/vabs/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/golden.py b/test/vpto/cases/micro-op/unary-vector/vabs/golden.py new file mode 100644 index 000000000..5b04d20bb --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.abs(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vabs validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto new file mode 100644 index 000000000..75486b4fe --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/kernel.pto @@ -0,0 +1,62 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vabs %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vabs/launch.cpp new file mode 100644 index 000000000..9002bcd67 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vabs/main.cpp b/test/vpto/cases/micro-op/unary-vector/vabs/main.cpp new file mode 100644 index 000000000..29454461f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vabs/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-f16/compare.py new file mode 100755 index 000000000..1971de729 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vexp-f16 +# family: unary-vector +# target_ops: pto.vexp +# scenarios: core-f16, full-mask +# NOTE: f16 vector exp baseline. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 0.01) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-f16/golden.py new file mode 100755 index 000000000..aa2de48ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vexp-f16 +# family: unary-vector +# target_ops: pto.vexp +# scenarios: core-f16, full-mask +# NOTE: f16 vector exp baseline. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float16) + v2 = np.zeros((ROWS, COLS), dtype=np.float16) + golden_v2 = np.exp(v1.astype(np.float32)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vexp f16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto new file mode 100644 index 000000000..7bd07d6e9 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vexp-f16 +// family: unary-vector +// target_ops: pto.vexp +// scenarios: core-f16, full-mask +// NOTE: f16 vector exp baseline. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_f16_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %out = pto.vexp %vec, %mask : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f16/launch.cpp new file mode 100644 index 000000000..4530d8cea --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vexp-f16 +// family: unary-vector +// target_ops: pto.vexp +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_f16_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_f16_kernel_2d(float *v1, float *v2, void *stream) { + vexp_f16_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f16/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f16/main.cpp new file mode 100644 index 000000000..c41afb75e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vexp-f16 +// family: unary-vector +// target_ops: pto.vexp +// scenarios: core-f16, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_f16_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_f16_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/golden.py new file mode 100644 index 000000000..fd76b39a9 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.exp(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto new file mode 100644 index 000000000..d98e71ba8 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_f32_exceptional_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/launch.cpp new file mode 100644 index 000000000..f96f1fc2e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vexp_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/main.cpp new file mode 100644 index 000000000..2a6824d9f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-exceptional/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/golden.py new file mode 100644 index 000000000..11cde41fe --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-120.0, -104.0, -88.0, 0.0, 40.0, 88.0, 90.0, 104.0], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.exp(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto new file mode 100644 index 000000000..8d3e0f31f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_f32_over_underflow_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/launch.cpp new file mode 100644 index 000000000..219a407d0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_f32_over_underflow_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vexp_f32_over_underflow_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/main.cpp new file mode 100644 index 000000000..2a6824d9f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-f32-over-underflow/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp-tail/golden.py new file mode 100644 index 000000000..b77b49528 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.exp( + v1.reshape(-1)[:LOGICAL_ELEMS] + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto new file mode 100644 index 000000000..d8719f305 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_tail_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-tail/launch.cpp new file mode 100644 index 000000000..723fee5d5 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_tail_kernel_2d(float *v1, float *v2, void *stream) { + vexp_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp-tail/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp-tail/main.cpp new file mode 100644 index 000000000..19f1b06f2 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/compare.py b/test/vpto/cases/micro-op/unary-vector/vexp/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/golden.py b/test/vpto/cases/micro-op/unary-vector/vexp/golden.py new file mode 100644 index 000000000..1df0d6853 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.exp(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vexp validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto new file mode 100644 index 000000000..b724b268c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/kernel.pto @@ -0,0 +1,62 @@ +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vexp %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vexp/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vexp/main.cpp b/test/vpto/cases/micro-op/unary-vector/vexp/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vexp/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/compare.py b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/compare.py new file mode 100755 index 000000000..afee62a98 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vln-domain-boundary +# family: unary-vector +# target_ops: pto.vln +# scenarios: core-f32, domain-positive, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/golden.py b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/golden.py new file mode 100755 index 000000000..64f82ec2d --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vln-domain-boundary +# family: unary-vector +# target_ops: pto.vln +# scenarios: core-f32, domain-positive, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(0.125, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat = v1.reshape(-1) + flat[:8] = np.array( + [ + np.float32(np.finfo(np.float32).tiny), + np.float32(np.finfo(np.float32).tiny * 2.0), + np.float32(1.0), + np.float32(2.0), + np.float32(16.0), + np.float32(1024.0), + np.float32(np.finfo(np.float32).max), + np.float32(0.5), + ], + dtype=np.float32, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.log(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vln domain-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto new file mode 100644 index 000000000..b841980ba --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vln-domain-boundary +// family: unary-vector +// target_ops: pto.vln +// scenarios: core-f32, domain-positive, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vln %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/launch.cpp new file mode 100644 index 000000000..6aeeded6c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vln-domain-boundary +// family: unary-vector +// target_ops: pto.vln +// scenarios: core-f32, domain-positive, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/main.cpp b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/main.cpp new file mode 100644 index 000000000..ab31f79d8 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln-domain-boundary/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vln-domain-boundary +// family: unary-vector +// target_ops: pto.vln +// scenarios: core-f32, domain-positive, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln/compare.py b/test/vpto/cases/micro-op/unary-vector/vln/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln/golden.py b/test/vpto/cases/micro-op/unary-vector/vln/golden.py new file mode 100644 index 000000000..b7a53856f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = np.exp(rng.uniform(-4.0, 2.0, size=(ROWS, COLS)).astype(np.float32)) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.log(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vln validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto new file mode 100644 index 000000000..b3105ea1f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vln %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vln/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vln/main.cpp b/test/vpto/cases/micro-op/unary-vector/vln/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vln/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/compare.py b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/compare.py new file mode 100755 index 000000000..1030c959f --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg-f32-exceptional +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/golden.py b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/golden.py new file mode 100755 index 000000000..0f394e5ed --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg-f32-exceptional +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + flat = v1.reshape(-1) + flat[:8] = np.array( + [ + np.float32(0.0), + np.float32(-0.0), + np.float32(np.inf), + np.float32(-np.inf), + np.float32(np.nan), + np.float32(1.0), + np.float32(-1.0), + np.float32(3.5), + ], + dtype=np.float32, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.negative(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vneg exceptional validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto new file mode 100644 index 000000000..9efd8ff9b --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg-f32-exceptional +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vneg %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/launch.cpp new file mode 100644 index 000000000..2614d8040 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg-f32-exceptional +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/main.cpp new file mode 100644 index 000000000..de8fba973 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg-f32-exceptional/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg-f32-exceptional +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/compare.py b/test/vpto/cases/micro-op/unary-vector/vneg/compare.py new file mode 100755 index 000000000..0ce4e18b6 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/golden.py b/test/vpto/cases/micro-op/unary-vector/vneg/golden.py new file mode 100755 index 000000000..a7e86608a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vneg +# family: unary-vector +# target_ops: pto.vneg +# scenarios: core-f32, full-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.negative(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vneg validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto new file mode 100644 index 000000000..5926ca4ba --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vneg %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vneg/launch.cpp new file mode 100644 index 000000000..65504cb9e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vneg/main.cpp b/test/vpto/cases/micro-op/unary-vector/vneg/main.cpp new file mode 100644 index 000000000..134aa5b2c --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vneg/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vneg +// family: unary-vector +// target_ops: pto.vneg +// scenarios: core-f32, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/compare.py b/test/vpto/cases/micro-op/unary-vector/vnot/compare.py new file mode 100755 index 000000000..cdecd8075 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vnot +# family: unary-vector +# target_ops: pto.vnot +# scenarios: core-i16-signed, full-mask +# NOTE: lane-wise bitwise inversion on signed i16 source lanes. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16, 0.0) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/golden.py b/test/vpto/cases/micro-op/unary-vector/vnot/golden.py new file mode 100755 index 000000000..c0d048c27 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vnot +# family: unary-vector +# target_ops: pto.vnot +# scenarios: core-i16-signed, full-mask +# NOTE: lane-wise bitwise inversion on signed i16 source lanes. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers( + low=np.iinfo(np.int16).min, + high=np.iinfo(np.int16).max + 1, + size=(ROWS, COLS), + dtype=np.int16, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.int16) + golden_v2 = np.bitwise_not(v1).astype(np.int16, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vnot validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto new file mode 100644 index 000000000..eaa0826ff --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vnot +// family: unary-vector +// target_ops: pto.vnot +// scenarios: core-i16-signed, full-mask +// NOTE: lane-wise bitwise inversion on signed i16 source lanes. +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vnot_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %out = pto.vnot %vec, %mask : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vnot/launch.cpp new file mode 100644 index 000000000..c2b22293e --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vnot +// family: unary-vector +// target_ops: pto.vnot +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vnot_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVnot_kernel_2d(float *v1, float *v2, void *stream) { + vnot_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vnot/main.cpp b/test/vpto/cases/micro-op/unary-vector/vnot/main.cpp new file mode 100644 index 000000000..3b97a8523 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vnot/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vnot +// family: unary-vector +// target_ops: pto.vnot +// scenarios: core-i16-signed, full-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVnot_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVnot_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/compare.py b/test/vpto/cases/micro-op/unary-vector/vrelu/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/golden.py b/test/vpto/cases/micro-op/unary-vector/vrelu/golden.py new file mode 100644 index 000000000..d481d99d4 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.maximum(v1, np.float32(0.0)).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vrelu validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto new file mode 100644 index 000000000..5407bf8cb --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vrelu %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vrelu/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vrelu/main.cpp b/test/vpto/cases/micro-op/unary-vector/vrelu/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vrelu/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/compare.py b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/compare.py new file mode 100755 index 000000000..71d8b50c2 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vsqrt-domain-boundary +# family: unary-vector +# target_ops: pto.vsqrt +# scenarios: core-f32, domain-nonnegative, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/golden.py b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/golden.py new file mode 100755 index 000000000..9607fbcc5 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/unary-vector/vsqrt-domain-boundary +# family: unary-vector +# target_ops: pto.vsqrt +# scenarios: core-f32, domain-nonnegative, exceptional-values +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(0.0, 16.0, size=(ROWS, COLS)).astype(np.float32) + flat = v1.reshape(-1) + flat[:8] = np.array( + [ + np.float32(0.0), + np.nextafter(np.float32(0.0), np.float32(1.0), dtype=np.float32), + np.float32(1.0), + np.float32(4.0), + np.float32(9.0), + np.float32(16.0), + np.float32(1024.0), + np.float32(np.finfo(np.float32).max), + ], + dtype=np.float32, + ) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.sqrt(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsqrt domain-boundary validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto new file mode 100644 index 000000000..f7c39b33a --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vsqrt-domain-boundary +// family: unary-vector +// target_ops: pto.vsqrt +// scenarios: core-f32, domain-nonnegative, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vsqrt %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/launch.cpp new file mode 100644 index 000000000..1070e35b1 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vsqrt-domain-boundary +// family: unary-vector +// target_ops: pto.vsqrt +// scenarios: core-f32, domain-nonnegative, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/main.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/main.cpp new file mode 100644 index 000000000..95bed1286 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt-domain-boundary/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/unary-vector/vsqrt-domain-boundary +// family: unary-vector +// target_ops: pto.vsqrt +// scenarios: core-f32, domain-nonnegative, exceptional-values +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/compare.py b/test/vpto/cases/micro-op/unary-vector/vsqrt/compare.py new file mode 100644 index 000000000..962985a24 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/compare.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/golden.py b/test/vpto/cases/micro-op/unary-vector/vsqrt/golden.py new file mode 100644 index 000000000..a5739a6b3 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + base = rng.uniform(0.0, 4.0, size=(ROWS, COLS)).astype(np.float32) + v1 = np.square(base).astype(np.float32, copy=False) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.sqrt(v1).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsqrt validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto new file mode 100644 index 000000000..af8858995 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/kernel.pto @@ -0,0 +1,43 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vexp_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %out = pto.vsqrt %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/launch.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt/launch.cpp new file mode 100644 index 000000000..b6d8cdbf0 --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/launch.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vexp_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream) { + vexp_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/unary-vector/vsqrt/main.cpp b/test/vpto/cases/micro-op/unary-vector/vsqrt/main.cpp new file mode 100644 index 000000000..f864622ca --- /dev/null +++ b/test/vpto/cases/micro-op/unary-vector/vsqrt/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVexp_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVexp_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/compare.py new file mode 100644 index 000000000..b5dd9902e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs-carry-boundary +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/golden.py new file mode 100644 index 000000000..ddf74542c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/golden.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs-carry-boundary +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LHS_PATTERN = np.array( + [0x00000000, 0x00000001, 0xFFFFFFFE, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 0xAAAAAAAA, 0x55555555], + dtype=np.uint32, +) +RHS_PATTERN = np.array( + [0x00000000, 0xFFFFFFFF, 0x00000001, 0x00000000, 0x80000000, 0x7FFFFFFF, 0x55555555, 0xAAAAAAAA], + dtype=np.uint32, +) + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = LANES // LHS_PATTERN.size + lhs = np.tile(LHS_PATTERN, repeats) + rhs = np.tile(RHS_PATTERN, repeats) + total = lhs.astype(np.uint64) + rhs.astype(np.uint64) + np.uint64(1) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto new file mode 100644 index 000000000..cd5a5748d --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs-carry-boundary +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddcs_carry_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %carry_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_carry, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/launch.cpp new file mode 100644 index 000000000..209c04c6a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs-carry-boundary +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vaddcs_carry_boundary_kernel( + __gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVaddcsCarryBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vaddcs_carry_boundary_kernel<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/main.cpp new file mode 100644 index 000000000..6addb079d --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs-carry-boundary/main.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs-carry-boundary +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddcsCarryBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddcsCarryBoundaryKernel(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs/compare.py new file mode 100644 index 000000000..130b06567 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_carry(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_carry() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vaddcs/golden.py new file mode 100644 index 000000000..35f57f535 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vaddcs +# family: vec-scalar +# target_ops: pto.vaddcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + rhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + total = lhs.astype(np.uint64) + rhs.astype(np.uint64) + np.uint64(1) + result = (total & np.uint64(0xFFFFFFFF)).astype(np.uint32) + carry = (total >> np.uint64(32)) != 0 + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(carry).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto new file mode 100644 index 000000000..bf71e2b04 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vaddcs_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_carry = pto.castptr %c12288_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %carry_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %sum, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %sum, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %carry, %ub_carry[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_carry, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs/launch.cpp new file mode 100644 index 000000000..d25fa38e7 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vaddcs_kernel(__gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVaddcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream) { + vaddcs_kernel<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vaddcs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vaddcs/main.cpp new file mode 100644 index 000000000..ea0899688 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vaddcs/main.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vaddcs +// family: vec-scalar +// target_ops: pto.vaddcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVaddcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVaddcs_kernel(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/compare.py new file mode 100644 index 000000000..896c992b2 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16, 0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/golden.py new file mode 100644 index 000000000..0efc4ec38 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALE = np.float32(1.5) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + wide = values.astype(np.float32, copy=False).view(np.uint32) + rounding = np.uint32(0x7FFF) + ((wide >> 16) & np.uint32(1)) + return ((wide + rounding) >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray: + return (bits.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1_f32 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + v1 = f32_to_bf16_bits(v1_f32) + v2 = np.zeros(ELEMS, dtype=np.uint16) + scalar_bits = f32_to_bf16_bits(np.array([SCALE], dtype=np.float32))[0] + scalar = bf16_bits_to_f32(np.array([scalar_bits], dtype=np.uint16))[0] + golden_v2 = f32_to_bf16_bits(bf16_bits_to_f32(v1) + scalar) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto new file mode 100644 index 000000000..3ff454ee1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_bf16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %cst = arith.constant 1.500000e+00 : bf16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xbf16> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<128xbf16>, bf16, !pto.mask -> !pto.vreg<128xbf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xbf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/launch.cpp new file mode 100644 index 000000000..734c5e95a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_bf16_kernel(__gm__ bfloat16_t *v1, + __gm__ bfloat16_t *v2); + +void LaunchVadds_bf16_kernel(uint16_t *v1, uint16_t *v2, void *stream) { + vadds_bf16_kernel<<<1, nullptr, stream>>>((__gm__ bfloat16_t *)v1, + (__gm__ bfloat16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/main.cpp new file mode 100644 index 000000000..ce2c7d7bf --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-bf16/main.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_bf16_kernel(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_bf16_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/compare.py new file mode 100644 index 000000000..1b47ca433 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float16, 5e-3) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/golden.py new file mode 100644 index 000000000..019cf6980 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALE = np.float16(1.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float16) + v2 = np.zeros(ELEMS, dtype=np.float16) + golden_v2 = (v1.astype(np.float32) + np.float32(SCALE)).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto new file mode 100644 index 000000000..798b9e92a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/kernel.pto @@ -0,0 +1,42 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_f16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %cst = arith.constant 1.500000e+00 : f16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xf16> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<128xf16>, f16, !pto.mask -> !pto.vreg<128xf16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/launch.cpp new file mode 100644 index 000000000..e964f1539 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_f16_kernel(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVadds_f16_kernel(uint16_t *v1, uint16_t *v2, void *stream) { + vadds_f16_kernel<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f16/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/main.cpp new file mode 100644 index 000000000..0e9c0076c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f16/main.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_f16_kernel(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_f16_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/golden.py new file mode 100644 index 000000000..c101038fb --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(0.5) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + specials = np.array( + [-np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan, 3.5], + dtype=np.float32, + ) + v1 = np.resize(specials, ROWS * COLS).reshape(ROWS, COLS).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto new file mode 100644 index 000000000..7c074f34e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_f32_exceptional_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 5.000000e-01 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/launch.cpp new file mode 100644 index 000000000..915da93e1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_f32_exceptional_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream) { + vadds_f32_exceptional_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/main.cpp new file mode 100644 index 000000000..9ba910e1c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-f32-exceptional/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_f32_exceptional_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_f32_exceptional_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/compare.py new file mode 100644 index 000000000..3402d0a6d --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/golden.py new file mode 100644 index 000000000..fcf4c5afb --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.int16(1024) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-16000, 16000, size=ELEMS, dtype=np.int16) + v1[:12] = np.array( + [ + 32767, + 32766, + 32760, + 32000, + 0, + 1, + -1, + -32768, + -32767, + -32000, + 12345, + -12345, + ], + dtype=np.int16, + ) + v2 = np.zeros(ELEMS, dtype=np.int16) + golden_v2 = (v1.astype(np.int32) + int(SCALAR)).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto new file mode 100644 index 000000000..1fa3b62d1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-signed-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_i16_signed_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 1024 : i16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/launch.cpp new file mode 100644 index 000000000..4ac3e8a59 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-signed-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-signed, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vadds_i16_signed_overflow_kernel(__gm__ int16_t *v1, __gm__ int16_t *v2); + +void LaunchVadds_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, + void *stream) { + vadds_i16_signed_overflow_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/main.cpp new file mode 100644 index 000000000..b8c63d7a8 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed-overflow/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_signed_overflow_kernel(int16_t *v1, int16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_signed_overflow_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/compare.py new file mode 100644 index 000000000..421ac84f5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str, dtype) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.int16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/golden.py new file mode 100644 index 000000000..23d6855f0 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-signed +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-signed, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.int16(37) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-12000, 12000, size=ELEMS, dtype=np.int16) + v2 = np.zeros(ELEMS, dtype=np.int16) + golden_v2 = (v1.astype(np.int32) + int(SCALAR)).astype(np.int16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto new file mode 100644 index 000000000..9ebf5fa9f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/kernel.pto @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-signed +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-signed, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_i16_signed_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %scalar = arith.constant 37 : i16 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.vreg<128xi16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/launch.cpp new file mode 100644 index 000000000..fd275ede0 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_i16_signed_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2); + +void LaunchVadds_i16_signed_kernel(int16_t *v1, int16_t *v2, void *stream) { + vadds_i16_signed_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/main.cpp new file mode 100644 index 000000000..3cbb6afba --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-signed/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_signed_kernel(int16_t *v1, int16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(int16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(int16_t); + int16_t *v1Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Host = nullptr; + int16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_signed_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/compare.py new file mode 100755 index 000000000..a1b852540 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/golden.py new file mode 100755 index 000000000..813ec0287 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.uint16(4096) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 65535, size=ELEMS, dtype=np.uint16) + v1[:12] = np.array( + [ + 65535, + 65534, + 65500, + 65000, + 4096, + 2048, + 1024, + 1, + 0, + 32768, + 12345, + 54321, + ], + dtype=np.uint16, + ) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = (v1.astype(np.uint32) + int(SCALAR)).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto new file mode 100644 index 000000000..5918a0212 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_i16_unsigned_overflow_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 4096 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/launch.cpp new file mode 100644 index 000000000..003c7556b --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned-overflow +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vadds_i16_unsigned_overflow_kernel(__gm__ uint16_t *v1, __gm__ uint16_t *v2); + +void LaunchVadds_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, + void *stream) { + vadds_i16_unsigned_overflow_kernel<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)v1, (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/main.cpp new file mode 100644 index 000000000..8f8b2ebe6 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned-overflow/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_unsigned_overflow_kernel(uint16_t *v1, uint16_t *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_unsigned_overflow_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/compare.py new file mode 100755 index 000000000..437f48ad7 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/golden.py new file mode 100755 index 000000000..df317a729 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vadds-i16-unsigned +# family: vec-scalar +# target_ops: pto.vadds +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SCALAR = np.uint16(37) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 60000, size=ELEMS, dtype=np.uint16) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = (v1.astype(np.uint32) + int(SCALAR)).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto new file mode 100644 index 000000000..232b318d6 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_i16_unsigned_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 37 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %sum = pto.vadds %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/launch.cpp new file mode 100644 index 000000000..61ff83045 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_i16_unsigned_kernel(__gm__ uint16_t *v1, + __gm__ uint16_t *v2); + +void LaunchVadds_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, void *stream) { + vadds_i16_unsigned_kernel<<<1, nullptr, stream>>>((__gm__ uint16_t *)v1, + (__gm__ uint16_t *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/main.cpp new file mode 100644 index 000000000..de7c50e82 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-i16-unsigned/main.cpp @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vadds-i16-unsigned +// family: vec-scalar +// target_ops: pto.vadds +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_i16_unsigned_kernel(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_i16_unsigned_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/golden.py new file mode 100644 index 000000000..2f06c22fa --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] + SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto new file mode 100644 index 000000000..c6784e700 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vadds_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/launch.cpp new file mode 100644 index 000000000..b4cd46470 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vadds_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vadds_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/compare.py b/test/vpto/cases/micro-op/vec-scalar/vadds/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/golden.py b/test/vpto/cases/micro-op/vec-scalar/vadds/golden.py new file mode 100644 index 000000000..273a8d29f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 + SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto new file mode 100644 index 000000000..672c20e84 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_add_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %sum = pto.vadds %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %sum, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds/launch.cpp new file mode 100644 index 000000000..44c07c249 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_add_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_add_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vadds/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vadds/main.cpp new file mode 100644 index 000000000..fcb42331f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vadds/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_add_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_add_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/golden.py new file mode 100644 index 000000000..0b08cbcab --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.maximum( + v1.reshape(-1)[:LOGICAL_ELEMS], SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto new file mode 100644 index 000000000..de40999f8 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmaxs_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmaxs %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/launch.cpp new file mode 100644 index 000000000..d5ae524ce --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmaxs_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vmaxs_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmaxs/golden.py new file mode 100644 index 000000000..d4b379f33 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.maximum(v1, SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto new file mode 100644 index 000000000..f57b56001 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_max_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %maxv = pto.vmaxs %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %maxv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs/launch.cpp new file mode 100644 index 000000000..a08848672 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_max_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_max_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_max_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmaxs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmaxs/main.cpp new file mode 100644 index 000000000..47ce3f58b --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmaxs/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_max_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_max_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/golden.py new file mode 100644 index 000000000..e4e63235a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = np.minimum( + v1.reshape(-1)[:LOGICAL_ELEMS], SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto new file mode 100644 index 000000000..9a6f30010 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmins_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %minv = pto.vmins %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %minv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/launch.cpp new file mode 100644 index 000000000..2774c3f46 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmins_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vmins_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmins/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmins/golden.py new file mode 100644 index 000000000..7caa057f7 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.minimum(v1, SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto new file mode 100644 index 000000000..5fa5b3613 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_min_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %minv = pto.vmins %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %minv, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins/launch.cpp new file mode 100644 index 000000000..23603d652 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_min_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_min_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_min_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmins/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmins/main.cpp new file mode 100644 index 000000000..888a58876 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmins/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_min_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_min_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/compare.py new file mode 100644 index 000000000..c13d79273 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/compare.py @@ -0,0 +1,39 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype, count=count) + output = np.fromfile(output_path, dtype=dtype, count=count) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 1e-4, 1000) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/golden.py new file mode 100644 index 000000000..fdfd56fd5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) +LOGICAL_ELEMS = 1000 +OUT_SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.random((ROWS, COLS), dtype=np.float32) + v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2 = np.full((ROWS, COLS), OUT_SENTINEL, dtype=np.float32) + golden_v2.reshape(-1)[:LOGICAL_ELEMS] = ( + v1.reshape(-1)[:LOGICAL_ELEMS] * SCALE + ).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto new file mode 100644 index 000000000..92efaba1b --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/kernel.pto @@ -0,0 +1,44 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmuls_tail_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmuls %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/launch.cpp new file mode 100644 index 000000000..65f00d71a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vmuls_tail_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream) { + vmuls_tail_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/main.cpp new file mode 100644 index 000000000..ab77e6b1a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls-tail/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVadds_tail_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVadds_tail_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/compare.py b/test/vpto/cases/micro-op/vec-scalar/vmuls/compare.py new file mode 100644 index 000000000..15b793fac --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/golden.py b/test/vpto/cases/micro-op/vec-scalar/vmuls/golden.py new file mode 100644 index 000000000..5233be0ed --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +SCALE = np.float32(3.14) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = (v1 * SCALE).astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto new file mode 100644 index 000000000..4e6461f2f --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/kernel.pto @@ -0,0 +1,45 @@ +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vec_mul_scalar_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant 3.140000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %prod = pto.vmuls %vec, %cst, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %prod, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls/launch.cpp new file mode 100644 index 000000000..0146a24b5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vec_mul_scalar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVec_mul_scalar_kernel_2d(float *v1, float *v2, void *stream) { + vec_mul_scalar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vmuls/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vmuls/main.cpp new file mode 100644 index 000000000..e99b6c097 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vmuls/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVec_mul_scalar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVec_mul_scalar_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/compare.py new file mode 100644 index 000000000..f07dd1f4c --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls-shift-boundary +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/golden.py new file mode 100644 index 000000000..e0952743a --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls-shift-boundary +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SHIFT = 15 +PATTERN = np.array( + [0x0000, 0x0001, 0x0002, 0x0003, 0x7FFF, 0x8000, 0x8001, 0xFFFF], + dtype=np.uint16, +) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = ELEMS // PATTERN.size + v1 = np.tile(PATTERN, repeats) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.left_shift(v1.astype(np.uint32), SHIFT).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto new file mode 100644 index 000000000..e25159f5e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls-shift-boundary +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshls_shift_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 15 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshls %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/launch.cpp new file mode 100644 index 000000000..ee7141d19 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls-shift-boundary +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshls_shift_boundary_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshls_shift_boundary_kernel(float *v1, float *v2, void *stream) { + vshls_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/main.cpp new file mode 100644 index 000000000..3b51b0c33 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls-shift-boundary/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls-shift-boundary +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshls_shift_boundary_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshls_shift_boundary_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshls/compare.py new file mode 100644 index 000000000..d5e9ad835 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshls/golden.py new file mode 100644 index 000000000..5d4fe1763 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshls +# family: vec-scalar +# target_ops: pto.vshls +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SHIFT = 3 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 1 << 12, size=ELEMS, dtype=np.uint16) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.left_shift(v1.astype(np.uint32), SHIFT).astype(np.uint16) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto new file mode 100644 index 000000000..dd3ff545d --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshls_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 3 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshls %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls/launch.cpp new file mode 100644 index 000000000..ed048246e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshls_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshls_kernel(float *v1, float *v2, void *stream) { + vshls_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshls/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshls/main.cpp new file mode 100644 index 000000000..f5cec4212 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshls/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshls +// family: vec-scalar +// target_ops: pto.vshls +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshls_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshls_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/compare.py new file mode 100644 index 000000000..65d6e8920 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs-shift-boundary +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/golden.py new file mode 100644 index 000000000..c1f36dae0 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs-shift-boundary +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SHIFT = 15 +PATTERN = np.array( + [0x0000, 0x0001, 0x0002, 0x0003, 0x7FFF, 0x8000, 0x8001, 0xFFFF], + dtype=np.uint16, +) + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = ELEMS // PATTERN.size + v1 = np.tile(PATTERN, repeats) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.right_shift(v1, SHIFT) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto new file mode 100644 index 000000000..c122287fc --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs-shift-boundary +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand, shift-boundary +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshrs_shift_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 15 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshrs %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/launch.cpp new file mode 100644 index 000000000..b108e4ba5 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/launch.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs-shift-boundary +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshrs_shift_boundary_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshrs_shift_boundary_kernel(float *v1, float *v2, void *stream) { + vshrs_shift_boundary_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/main.cpp new file mode 100644 index 000000000..4f5611378 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs-shift-boundary/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs-shift-boundary +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshrs_shift_boundary_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshrs_shift_boundary_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vshrs/compare.py new file mode 100644 index 000000000..3c2384aff --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.uint16) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vshrs/golden.py new file mode 100644 index 000000000..82b2a9e07 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vshrs +# family: vec-scalar +# target_ops: pto.vshrs +# scenarios: core-i16-unsigned, full-mask, scalar-operand + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 19 +SHIFT = 3 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, np.iinfo(np.uint16).max + 1, size=ELEMS, dtype=np.uint16) + v2 = np.zeros(ELEMS, dtype=np.uint16) + golden_v2 = np.right_shift(v1, SHIFT) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto new file mode 100644 index 000000000..0c01f3047 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vshrs_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %scalar = arith.constant 3 : i16 + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xui16> + %shifted = pto.vshrs %vec, %scalar, %mask : !pto.vreg<128xui16>, i16, !pto.mask -> !pto.vreg<128xui16> + pto.vsts %shifted, %ub_out[%offset], %mask : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs/launch.cpp new file mode 100644 index 000000000..ebf9902d1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/launch.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vshrs_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVshrs_kernel(float *v1, float *v2, void *stream) { + vshrs_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vshrs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vshrs/main.cpp new file mode 100644 index 000000000..81790da59 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vshrs/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vshrs +// family: vec-scalar +// target_ops: pto.vshrs +// scenarios: core-i16-unsigned, full-mask, scalar-operand +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVshrs_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVshrs_kernel(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/compare.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/compare.py new file mode 100644 index 000000000..87847b721 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs-borrow-boundary +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/golden.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/golden.py new file mode 100644 index 000000000..d20ebafc3 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/golden.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs-borrow-boundary +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +LHS_PATTERN = np.array( + [0x00000000, 0x00000001, 0x00000000, 0xFFFFFFFF, 0x80000000, 0x7FFFFFFF, 0xAAAAAAAA, 0x55555555], + dtype=np.uint32, +) +RHS_PATTERN = np.array( + [0x00000000, 0x00000000, 0x00000001, 0xFFFFFFFF, 0x7FFFFFFF, 0x80000000, 0x55555555, 0xAAAAAAAA], + dtype=np.uint32, +) + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + del seed + repeats = LANES // LHS_PATTERN.size + lhs = np.tile(LHS_PATTERN, repeats) + rhs = np.tile(RHS_PATTERN, repeats) + lhs64 = lhs.astype(np.uint64) + rhs64 = rhs.astype(np.uint64) + no_borrow = lhs64 >= rhs64 + result = ((lhs64 - rhs64) & np.uint64(0xFFFFFFFF)).astype(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(no_borrow).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=19) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto new file mode 100644 index 000000000..5ef0438fe --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs-borrow-boundary +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsubcs_borrow_boundary_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: !pto.ptr, + %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %borrow_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_borrow, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/launch.cpp new file mode 100644 index 000000000..a1cb56e2e --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/launch.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs-borrow-boundary +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsubcs_borrow_boundary_kernel( + __gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVsubcsBorrowBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream) { + vsubcs_borrow_boundary_kernel<<<1, nullptr, stream>>>( + (__gm__ uint32_t *)v1, (__gm__ uint32_t *)v2, (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/main.cpp new file mode 100644 index 000000000..169bc4512 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs-borrow-boundary/main.cpp @@ -0,0 +1,116 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs-borrow-boundary +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain, integer-overflow +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubcsBorrowBoundaryKernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, + uint8_t *v4, void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubcsBorrowBoundaryKernel(v1Device, v2Device, v3Device, v4Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/compare.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs/compare.py new file mode 100644 index 000000000..047f6c245 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/compare.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +LOGICAL_ELEMS = 64 +SRC_ELEM_BYTES = 4 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + repeat_elems = REPEAT_BYTES // src_elem_bytes + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + +def compare_result(): + golden = np.fromfile("golden_v3.bin", dtype=np.uint32, count=64) + output = np.fromfile("v3.bin", dtype=np.uint32, count=64) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def compare_borrow(): + prefix_bytes = _packed_pred_storage_bytes(LOGICAL_ELEMS, SRC_ELEM_BYTES) + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.size < prefix_bytes or output.size < prefix_bytes: + return False + return np.array_equal(golden[:prefix_bytes], output[:prefix_bytes]) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_result() and compare_borrow() + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/golden.py b/test/vpto/cases/micro-op/vec-scalar/vsubcs/golden.py new file mode 100644 index 000000000..d9c1f2e8b --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vec-scalar/vsubcs +# family: vec-scalar +# target_ops: pto.vsubcs +# scenarios: core-u32-unsigned, full-mask, carry-chain + +import argparse +from pathlib import Path + +import numpy as np + + +LANES = 64 +SEED = 19 + + +def pack_mask_nibbles(bits): + out = np.zeros(256, dtype=np.uint8) + for idx, bit in enumerate(bits): + if not bit: + continue + byte = idx // 2 + if idx % 2 == 0: + out[byte] |= np.uint8(0x1) + else: + out[byte] |= np.uint8(0x10) + return out + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + lhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + rhs = rng.integers(0, 0xFFFFFFFF, size=LANES, dtype=np.uint32) + lhs64 = lhs.astype(np.uint64) + rhs64 = rhs.astype(np.uint64) + no_borrow = lhs64 >= rhs64 + result = ((lhs64 - rhs64) & np.uint64(0xFFFFFFFF)).astype(np.uint32) + + output_dir.mkdir(parents=True, exist_ok=True) + lhs.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + np.zeros(LANES, dtype=np.uint32).tofile(output_dir / "v3.bin") + np.zeros(256, dtype=np.uint8).tofile(output_dir / "v4.bin") + result.tofile(output_dir / "golden_v3.bin") + pack_mask_nibbles(no_borrow).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto new file mode 100644 index 000000000..a7d9e2962 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/kernel.pto @@ -0,0 +1,55 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsubcs_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, + %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %false = arith.constant false + + %ub_lhs = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_borrow = pto.castptr %c12288_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_lhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %borrow_in = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %ub_lhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %rhs = pto.vlds %ub_rhs[%c0] : !pto.ptr -> !pto.vreg<64xui32> + %diff, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask, !pto.mask -> !pto.vreg<64xui32>, !pto.mask + pto.vsts %diff, %ub_out[%c0], %mask : !pto.vreg<64xui32>, !pto.ptr, !pto.mask + pto.psti %borrow, %ub_borrow[%c0], "NORM" : !pto.mask, !pto.ptr, index + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg2, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_borrow, %arg3, %c128_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/launch.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs/launch.cpp new file mode 100644 index 000000000..534b84ab1 --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vsubcs_kernel(__gm__ uint32_t *v1, __gm__ uint32_t *v2, __gm__ uint32_t *v3, + __gm__ uint8_t *v4); + +void LaunchVsubcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream) { + vsubcs_kernel<<<1, nullptr, stream>>>((__gm__ uint32_t *)v1, + (__gm__ uint32_t *)v2, + (__gm__ uint32_t *)v3, + (__gm__ uint8_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vec-scalar/vsubcs/main.cpp b/test/vpto/cases/micro-op/vec-scalar/vsubcs/main.cpp new file mode 100644 index 000000000..5bcad0fcc --- /dev/null +++ b/test/vpto/cases/micro-op/vec-scalar/vsubcs/main.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vec-scalar/vsubcs +// family: vec-scalar +// target_ops: pto.vsubcs +// scenarios: core-u32-unsigned, full-mask, carry-chain +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsubcs_kernel(uint32_t *v1, uint32_t *v2, uint32_t *v3, uint8_t *v4, + void *stream); + +int main() { + size_t elemCount_v1 = 64; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint32_t); + size_t elemCount_v2 = 64; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint32_t); + size_t elemCount_v3 = 64; + size_t fileSize_v3 = elemCount_v3 * sizeof(uint32_t); + size_t elemCount_v4 = 256; + size_t fileSize_v4 = elemCount_v4 * sizeof(uint8_t); + uint32_t *v1Host = nullptr; + uint32_t *v1Device = nullptr; + uint32_t *v2Host = nullptr; + uint32_t *v2Device = nullptr; + uint32_t *v3Host = nullptr; + uint32_t *v3Device = nullptr; + uint8_t *v4Host = nullptr; + uint8_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize_v3)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize_v4)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize_v3, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize_v4, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ReadFile("./v3.bin", fileSize_v3, v3Host, fileSize_v3); + ReadFile("./v4.bin", fileSize_v4, v4Host, fileSize_v4); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize_v3, v3Host, fileSize_v3, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize_v4, v4Host, fileSize_v4, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsubcs_kernel(v1Device, v2Device, v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v3Host, fileSize_v3, v3Device, fileSize_v3, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize_v4, v4Device, fileSize_v4, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", v3Host, fileSize_v3); + WriteFile("./v4.bin", v4Host, fileSize_v4); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/compare.py b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/compare.py new file mode 100644 index 000000000..968ea8ed8 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/compare.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={int(golden[idx])}, out={int(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/golden.py b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/golden.py new file mode 100644 index 000000000..d86fd060e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/golden.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 16 +COLS = 16 + + +def generate(output_dir: Path) -> None: + src = ((np.arange(ROWS * COLS, dtype=np.int32).reshape(ROWS, COLS) * 3) % 97 - 48).astype(np.int16) + dst = np.zeros((ROWS, COLS), dtype=np.int16) + golden_dst = src.copy() + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.reshape(-1).tofile(output_dir / "v2.bin") + golden_dst.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/kernel.pto new file mode 100644 index 000000000..8eebb3db9 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/kernel.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5"} { + func.func @cbuf_ubuf_roundtrip_mixed_kernel( + %src_gm: !pto.ptr, %dst_gm: !pto.ptr) + attributes {pto.kernel} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + + %l1_rt = pto.castptr %c512_i64 : i64 -> !pto.ptr + %ub0 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub1 = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.section.vector { + pto.mte_gm_ub %src_gm, %ub0, %c0_i64, %c32_i64 + nburst(%c16_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + // Match PTO-ISA local ordering: wait for gm->ub before consuming UB in ub->cbuf. + pto.set_flag["PIPE_MTE2", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_MTE3", "EVENT_ID0"] + + pto.copy_ubuf_to_cbuf %ub0, %l1_rt, %c0_i64, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.copy_ubuf_to_cbuf %ub0, %l1_rt, %c1_i64, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.sync.set , 0 + + pto.sync.wait , 1 + + pto.mte_ub_gm %ub1, %dst_gm, %c32_i64 + nburst(%c16_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + } + + pto.section.cube { + // Follow the working PTO-ISA tmov_ub2l1 handoff: + // vec produces L1, cube waits on MTE1 before reading cbuf. + pto.sync.wait , 0 + pto.sync.wait , 16 + + pto.copy_cbuf_to_ubuf %l1_rt, %ub1, %c0_i64, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.copy_cbuf_to_ubuf %l1_rt, %ub1, %c1_i64, %c1_i64, %c16_i64, %c0_i64, %c0_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + // cube produces UB, vec waits on MTE3 for the return handoff. + pto.sync.set , 1 + pto.sync.set , 17 + } + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/launch.cpp new file mode 100644 index 000000000..6c24c978e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/launch.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void cbuf_ubuf_roundtrip_mixed_kernel( + __gm__ int16_t *src, __gm__ int16_t *dst); + +void LaunchCbuf_ubuf_roundtrip_mixed_kernel(int16_t *src, int16_t *dst, + void *stream) { + cbuf_ubuf_roundtrip_mixed_kernel<<<1, nullptr, stream>>>( + (__gm__ int16_t *)src, (__gm__ int16_t *)dst); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/main.cpp b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/main.cpp new file mode 100644 index 000000000..d9cb29f6a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/cbuf-ubuf-roundtrip-mixed/main.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" + +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchCbuf_ubuf_roundtrip_mixed_kernel(int16_t *src, int16_t *dst, + void *stream); + +int main() { + constexpr size_t elemCount = 16 * 16; + constexpr size_t bufSize = elemCount * sizeof(int16_t); + + int16_t *srcHost = nullptr; + int16_t *dstHost = nullptr; + int16_t *srcDevice = nullptr; + int16_t *dstDevice = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + size_t inputSize = 0; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)&srcHost, bufSize)); + ACL_CHECK(aclrtMallocHost((void **)&dstHost, bufSize)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, bufSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, bufSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + inputSize = bufSize; + FILE_CHECK(ReadFile("./v1.bin", inputSize, srcHost, bufSize) && inputSize == bufSize, + "./v1.bin"); + inputSize = bufSize; + FILE_CHECK(ReadFile("./v2.bin", inputSize, dstHost, bufSize) && inputSize == bufSize, + "./v2.bin"); + + ACL_CHECK(aclrtMemcpy(srcDevice, bufSize, srcHost, bufSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, bufSize, dstHost, bufSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchCbuf_ubuf_roundtrip_mixed_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(dstHost, bufSize, dstDevice, bufSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + FILE_CHECK(WriteFile("./v2.bin", dstHost, bufSize), "./v2.bin"); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/compare.py b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/compare.py new file mode 100644 index 000000000..3b08bd2b0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/dma-copy-rearrange +# family: micro-op/vector-load-store +# target_ops: pto.copy_gm_to_ubuf, pto.mte_ub_ub, pto.copy_ubuf_to_gm +# scenarios: i16, ub-rearrange, permute-4x16-rows + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch: {golden.shape} vs {output.shape}") + return False + if np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print(f"[ERROR] first mismatch at idx={idx}: golden={int(golden[idx])}, out={int(output[idx])}") + return False + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/golden.py b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/golden.py new file mode 100644 index 000000000..1056d10e1 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/dma-copy-rearrange +# family: micro-op/vector-load-store +# target_ops: pto.copy_gm_to_ubuf, pto.mte_ub_ub, pto.copy_ubuf_to_gm +# scenarios: i16, ub-rearrange, permute-4x16-rows + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 4 +COLS = 16 + + +def generate(output_dir: Path) -> None: + v1 = np.arange(ROWS * COLS, dtype=np.int16).reshape(ROWS, COLS) + v2 = np.zeros((ROWS, COLS), dtype=np.int16) + golden_v2 = v1[[2, 0, 3, 1], :].copy() + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/kernel.pto new file mode 100644 index 000000000..ae7db0d7a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/kernel.pto @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/dma-copy-rearrange +// family: micro-op/vector-load-store +// target_ops: pto.copy_gm_to_ubuf, pto.mte_ub_ub, pto.copy_ubuf_to_gm +// scenarios: i16, ub-rearrange, permute-4x16-rows +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @dma_copy_rearrange_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr, + %arg2: i64, + %arg3: i64, + %arg4: i64, + %arg5: i64) attributes {pto.kernel} { + %false = arith.constant false + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c96_i64 = arith.constant 96 : i64 + %c128_i64 = arith.constant 128 : i64 + %c160_i64 = arith.constant 160 : i64 + %c192_i64 = arith.constant 192 : i64 + %c224_i64 = arith.constant 224 : i64 + %c256_i64 = arith.constant 256 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c128_i64 : i64 -> !pto.ptr + + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c1_i64, %c128_i64, + %c0_i64, %c0_i64, %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.barrier #pto.pipe + + %src_row0 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %src_row1 = pto.castptr %c32_i64 : i64 -> !pto.ptr + %src_row2 = pto.castptr %c64_i64 : i64 -> !pto.ptr + %src_row3 = pto.castptr %c96_i64 : i64 -> !pto.ptr + + %dst_row0 = pto.castptr %c128_i64 : i64 -> !pto.ptr + %dst_row1 = pto.castptr %c160_i64 : i64 -> !pto.ptr + %dst_row2 = pto.castptr %c192_i64 : i64 -> !pto.ptr + %dst_row3 = pto.castptr %c224_i64 : i64 -> !pto.ptr + + pto.mte_ub_ub %src_row2, %dst_row0, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_ub %src_row0, %dst_row1, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_ub %src_row3, %dst_row2, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_ub %src_row1, %dst_row3, %arg3 + nburst(%arg2, %arg4, %arg5) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + + pto.barrier #pto.pipe + + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c1_i64, %c128_i64, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/launch.cpp new file mode 100644 index 000000000..e418c4f9a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void dma_copy_rearrange_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2, + int64_t n_burst, + int64_t len_burst, + int64_t src_gap, + int64_t dst_gap); + +void LaunchDma_copy_rearrange_kernel(int16_t *v1, int16_t *v2, + int64_t n_burst, int64_t len_burst, + int64_t src_gap, int64_t dst_gap, + void *stream) { + dma_copy_rearrange_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2, + n_burst, len_burst, + src_gap, dst_gap); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/main.cpp b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/main.cpp new file mode 100644 index 000000000..6ab7b0c13 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/dma-copy-rearrange/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/dma-copy-rearrange +// family: micro-op/vector-load-store +// target_ops: pto.copy_gm_to_ubuf, pto.mte_ub_ub, pto.copy_ubuf_to_gm +// scenarios: i16, ub-rearrange, permute-4x16-rows +// ----------------------------------------------------------------------------- + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchDma_copy_rearrange_kernel(int16_t *v1, int16_t *v2, + int64_t n_burst, int64_t len_burst, + int64_t src_gap, int64_t dst_gap, + void *stream); + +int main() { + constexpr size_t elemCount = 64; + constexpr size_t fileSize = elemCount * sizeof(int16_t); + size_t inputFileSize = fileSize; + + int16_t *v1Host = nullptr; + int16_t *v2Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Device = nullptr; + const int64_t nBurst = 1; + const int64_t lenBurst = 1; + const int64_t srcGap = 0; + const int64_t dstGap = 0; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + FILE_CHECK(ReadFile("./v1.bin", inputFileSize, v1Host, fileSize) && + inputFileSize == fileSize, + "./v1.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v2.bin", inputFileSize, v2Host, fileSize) && + inputFileSize == fileSize, + "./v2.bin"); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSize, v1Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize, v2Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchDma_copy_rearrange_kernel(v1Device, v2Device, nBurst, lenBurst, + srcGap, dstGap, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(v2Host, fileSize, v2Device, fileSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + FILE_CHECK(WriteFile("./v2.bin", v2Host, fileSize), "./v2.bin"); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/compare.py b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/compare.py new file mode 100644 index 000000000..0b84e2102 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/compare.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/issue-173-vsts-signed-signless +# family: micro-op/vector-load-store +# target_ops: pto.vlds, pto.vsts +# scenarios: signed-i16, signless-i16, same-module, issue-173-regression + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path: str, output_path: str) -> bool: + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=np.int16) + output = np.fromfile(output_path, dtype=np.int16) + return golden.shape == output.shape and np.array_equal(golden, output) + + +def main() -> None: + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + ok = compare_bin("golden_v4.bin", "v4.bin") and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/golden.py b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/golden.py new file mode 100644 index 000000000..7a92aee52 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/issue-173-vsts-signed-signless +# family: micro-op/vector-load-store +# target_ops: pto.vlds, pto.vsts +# scenarios: signed-i16, signless-i16, same-module, issue-173-regression + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMS = 1024 +SEED = 173 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + signed = rng.integers(-32768, 32768, size=ELEMS, dtype=np.int16) + signless = rng.integers(-32768, 32768, size=ELEMS, dtype=np.int16) + + signed[:16] = np.array( + [-32768, -30000, -12345, -1, 0, 1, 2, 3, 7, 15, 127, 255, 1024, 12345, 30000, 32767], + dtype=np.int16, + ) + signless[:16] = np.array( + [32767, 30000, 12345, 1024, 255, 127, 15, 7, 3, 2, 1, 0, -1, -12345, -30000, -32768], + dtype=np.int16, + ) + + output_dir.mkdir(parents=True, exist_ok=True) + signed.tofile(output_dir / "v1.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v2.bin") + signless.tofile(output_dir / "v3.bin") + np.zeros(ELEMS, dtype=np.int16).tofile(output_dir / "v4.bin") + signed.tofile(output_dir / "golden_v2.bin") + signless.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/kernel.pto new file mode 100644 index 000000000..ae350dcc5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/kernel.pto @@ -0,0 +1,81 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/issue-173-vsts-signed-signless +// family: micro-op/vector-load-store +// target_ops: pto.vlds, pto.vsts +// scenarios: signed-i16, signless-i16, same-module, issue-173-regression +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @copy_signed_i16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xsi16> + pto.vsts %vec, %ub_out[%offset], %mask : !pto.vreg<128xsi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @copy_signless_i16_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + pto.copy_gm_to_ubuf %arg0, %ub_in, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %offset = %c0 to %c1024 step %c128 { + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %vec, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + pto.copy_ubuf_to_gm %ub_out, %arg1, %c0_i64, %c32_i64, %c64_i64, %c0_i64, %c64_i64, %c64_i64 : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/launch.cpp new file mode 100644 index 000000000..3153aaff6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/launch.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void copy_signed_i16_kernel(__gm__ int16_t *v1, + __gm__ int16_t *v2); +extern "C" __global__ [aicore] void copy_signless_i16_kernel( + __gm__ int16_t *v3, __gm__ int16_t *v4); + +void LaunchCopySignedI16Kernel(int16_t *v1, int16_t *v2, void *stream) { + copy_signed_i16_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v1, + (__gm__ int16_t *)v2); +} + +void LaunchCopySignlessI16Kernel(int16_t *v3, int16_t *v4, void *stream) { + copy_signless_i16_kernel<<<1, nullptr, stream>>>((__gm__ int16_t *)v3, + (__gm__ int16_t *)v4); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/main.cpp b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/main.cpp new file mode 100644 index 000000000..4500f5748 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/issue-173-vsts-signed-signless/main.cpp @@ -0,0 +1,150 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/issue-173-vsts-signed-signless +// family: micro-op/vector-load-store +// target_ops: pto.vlds, pto.vsts +// scenarios: signed-i16, signless-i16, same-module, issue-173-regression +// ----------------------------------------------------------------------------- + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +#define FILE_CHECK(expr, path) \ + do { \ + if (!(expr)) { \ + std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ + path, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchCopySignedI16Kernel(int16_t *v1, int16_t *v2, void *stream); +void LaunchCopySignlessI16Kernel(int16_t *v3, int16_t *v4, void *stream); + +int main() { + constexpr size_t elemCount = 1024; + constexpr size_t fileSize = elemCount * sizeof(int16_t); + size_t inputFileSize = fileSize; + + int16_t *v1Host = nullptr; + int16_t *v2Host = nullptr; + int16_t *v3Host = nullptr; + int16_t *v4Host = nullptr; + int16_t *v1Device = nullptr; + int16_t *v2Device = nullptr; + int16_t *v3Device = nullptr; + int16_t *v4Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v3Host), fileSize)); + ACL_CHECK(aclrtMallocHost((void **)(&v4Host), fileSize)); + + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v3Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v4Device, fileSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + FILE_CHECK(ReadFile("./v1.bin", inputFileSize, v1Host, fileSize) && + inputFileSize == fileSize, + "./v1.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v2.bin", inputFileSize, v2Host, fileSize) && + inputFileSize == fileSize, + "./v2.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v3.bin", inputFileSize, v3Host, fileSize) && + inputFileSize == fileSize, + "./v3.bin"); + inputFileSize = fileSize; + FILE_CHECK(ReadFile("./v4.bin", inputFileSize, v4Host, fileSize) && + inputFileSize == fileSize, + "./v4.bin"); + + ACL_CHECK(aclrtMemcpy(v1Device, fileSize, v1Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize, v2Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v3Device, fileSize, v3Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v4Device, fileSize, v4Host, fileSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchCopySignedI16Kernel(v1Device, v2Device, stream); + LaunchCopySignlessI16Kernel(v3Device, v4Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + + ACL_CHECK(aclrtMemcpy(v2Host, fileSize, v2Device, fileSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(v4Host, fileSize, v4Device, fileSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + + FILE_CHECK(WriteFile("./v2.bin", v2Host, fileSize), "./v2.bin"); + FILE_CHECK(WriteFile("./v4.bin", v4Host, fileSize), "./v4.bin"); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFree(v3Device); + aclrtFree(v4Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + aclrtFreeHost(v3Host); + aclrtFreeHost(v4Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/compare.py new file mode 100755 index 000000000..bc0d8fc41 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus-state-chain +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, stream-state, state-update +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 128 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/golden.py new file mode 100755 index 000000000..926db9342 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus-state-chain +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, repeated-no-post +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out[:LANES] = flat_in[1 : 1 + LANES] + flat_out[LANES : 2 * LANES] = flat_in[65 : 65 + LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldas-vldus state-chain validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto new file mode 100644 index 000000000..e460e1308 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/kernel.pto @@ -0,0 +1,61 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus-state-chain +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, repeated-no-post +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// Validate repeated no-post unaligned loads. Each `pto.vldus` is paired with +// its own `pto.vldas` and uses an explicit unaligned source pointer; the second +// load does not depend on state returned from the first one. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vldas_vldus_state_chain_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c64 step %c64 { + %src0 = pto.addptr %ub_in, %c1 : !pto.ptr -> !pto.ptr + %src1 = pto.addptr %src0, %c64 : !pto.ptr -> !pto.ptr + %align0 = pto.vldas %src0 : !pto.ptr -> !pto.align + %mask, %next_remaining = pto.plt_b32 %c1024_i32 : i32 -> !pto.mask, i32 + %out0, %align1 = pto.vldus %src0, %align0 + : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + %align2 = pto.vldas %src1 : !pto.ptr -> !pto.align + %out1, %align3 = pto.vldus %src1, %align2 + : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + pto.vsts %out0, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %out1, %ub_out[%c64], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/launch.cpp new file mode 100644 index 000000000..6a077aa65 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus-state-chain +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldas_vldus_state_chain_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldasVldusStateChain_kernel_2d(float *v1, float *v2, void *stream) { + vldas_vldus_state_chain_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/main.cpp new file mode 100644 index 000000000..95044a646 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus-state-chain/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus-state-chain +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldasVldusStateChain_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldasVldusStateChain_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/compare.py new file mode 100755 index 000000000..f4bfe43cc --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, stream-state +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 64 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/golden.py new file mode 100755 index 000000000..77d537906 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldas-vldus +# family: vector-load-store +# target_ops: pto.vldas, pto.vldus +# scenarios: core-f32, full-mask, unaligned, stream-state +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_out = golden_v2.reshape(-1) + flat_out[:LANES] = flat_in[1 : 1 + LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldas-vldus validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto new file mode 100644 index 000000000..c5e6da2b5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/kernel.pto @@ -0,0 +1,71 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c64 step %c64 { + %src0 = pto.addptr %ub_in, %c1 : !pto.ptr -> !pto.ptr + %align0 = pto.vldas %src0 : !pto.ptr -> !pto.align + %mask, %next_remaining = pto.plt_b32 %c1024_i32 : i32 -> !pto.mask, i32 + %out, %next_align = pto.vldus %src0, %align0 + : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/launch.cpp new file mode 100644 index 000000000..b25715db0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/main.cpp new file mode 100644 index 000000000..7bf75309b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldas-vldus/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldas-vldus +// family: vector-load-store +// target_ops: pto.vldas, pto.vldus +// scenarios: core-f32, full-mask, unaligned, stream-state +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/compare.py new file mode 100644 index 000000000..4c19eb038 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/golden.py new file mode 100644 index 000000000..8cc8dbe42 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b16-f32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b16, width-agnostic-dist + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + + src_bytes = v1.view(np.uint8) + golden_bytes = np.zeros_like(src_bytes) + chunk_bytes = LANES * 4 + for offset in range(0, src_bytes.size, chunk_bytes): + pattern = src_bytes[offset : offset + 2] + tiled = np.tile(pattern, chunk_bytes // 2) + golden_bytes[offset : offset + chunk_bytes] = tiled + golden_v2 = golden_bytes.view(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds BRC_B16 on f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto new file mode 100644 index 000000000..e94fe5b01 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16-f32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b16, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vlds_brc_b16_f32_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B16"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/launch.cpp new file mode 100644 index 000000000..530496dba --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_brc_b16_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVlds_brc_b16_f32_kernel(float *v1, float *v2, void *stream) { + vlds_brc_b16_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} + diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/main.cpp new file mode 100644 index 000000000..661e47152 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16-f32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_brc_b16_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVlds_brc_b16_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/compare.py new file mode 100755 index 000000000..51c671b74 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-brc-b16 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float16, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/golden.py new file mode 100755 index 000000000..1f593f79d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-brc-b16 +# NOTE: BRC on b16 broadcasts the first f16 element of each 128-lane chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float16) + v2 = np.zeros((ELEMENTS,), dtype=np.float16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.float16) + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = v1[offset] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b16 broadcast validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto new file mode 100644 index 000000000..95fe2fb62 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-brc-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `BRC_B16` load on `b16`. +// The case keeps the structure minimal: +// 1. DMA one input tile into UB +// 2. issue `pto.vlds` with `dist = "BRC_B16"` inside `pto.vecscope` +// 3. store the resulting vector back through `pto.vsts` + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B16"} : !pto.ptr -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/launch.cpp new file mode 100644 index 000000000..1d8ac9f5d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-brc-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/main.cpp new file mode 100644 index 000000000..cbec16893 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-brc-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/compare.py new file mode 100755 index 000000000..16f027ac0 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b32 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/golden.py new file mode 100755 index 000000000..541e3d770 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b32 +# NOTE: BRC on b32 broadcasts the first f32 element of each 64-lane chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2 = np.empty((ELEMENTS,), dtype=np.float32) + for offset in range(0, ELEMENTS, LANES): + golden_v2[offset : offset + LANES] = v1[offset] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b32 broadcast validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto new file mode 100644 index 000000000..fac81676f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/kernel.pto @@ -0,0 +1,69 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b32 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/launch.cpp new file mode 100644 index 000000000..bc599a4ac --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b32 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/main.cpp new file mode 100644 index 000000000..567ff23cf --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b32/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b32 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/compare.py new file mode 100644 index 000000000..4c19eb038 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/golden.py new file mode 100644 index 000000000..be66bd820 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-b8-f32 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, full-mask, aligned, dist-brc-b8, width-agnostic-dist + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + + src_bytes = v1.view(np.uint8) + golden_bytes = np.zeros_like(src_bytes) + chunk_bytes = LANES * 4 + for offset in range(0, src_bytes.size, chunk_bytes): + pattern = src_bytes[offset] + golden_bytes[offset : offset + chunk_bytes] = pattern + golden_v2 = golden_bytes.view(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds BRC_B8 on f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto new file mode 100644 index 000000000..fac3510de --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-b8-f32 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, full-mask, aligned, dist-brc-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vlds_brc_b8_f32_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_B8"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/launch.cpp new file mode 100644 index 000000000..4628e9dc1 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_brc_b8_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVlds_brc_b8_f32_kernel(float *v1, float *v2, void *stream) { + vlds_brc_b8_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} + diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/main.cpp new file mode 100644 index 000000000..bf2d99510 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-b8-f32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_brc_b8_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVlds_brc_b8_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/compare.py new file mode 100755 index 000000000..db7576dcf --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-blk +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-u8, full-mask, aligned, dist-brc-blk +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.uint8, 0.0, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/golden.py new file mode 100755 index 000000000..4b9610b1e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-brc-blk +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-u8, full-mask, aligned, dist-brc-blk +# NOTE: BRC_BLK repeats the first 32-byte block across each 256-byte vector chunk. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 4096 +ACTIVE_ELEMS = 1024 +LANES = 256 +BLOCK_BYTES = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(0, 256, size=(ELEMENTS,), dtype=np.uint8) + v2 = np.zeros((ELEMENTS,), dtype=np.uint8) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.uint8) + repeats = LANES // BLOCK_BYTES + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = np.tile(v1[offset : offset + BLOCK_BYTES], repeats) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds block-broadcast validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto new file mode 100644 index 000000000..d3c29c19c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-blk +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-u8, full-mask, aligned, dist-brc-blk +// ----------------------------------------------------------------------------- +// Validate one representative `BRC_BLK` load. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b8 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "BRC_BLK"} : !pto.ptr -> !pto.vreg<256xui8> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/launch.cpp new file mode 100644 index 000000000..0299a6136 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-blk +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-u8, full-mask, aligned, dist-brc-blk +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/main.cpp new file mode 100644 index 000000000..7e62df365 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-brc-blk/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-brc-blk +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-u8, full-mask, aligned, dist-brc-blk +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/compare.py new file mode 100644 index 000000000..81cfc5edb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/compare.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + abs_diff = np.abs(golden.astype(np.float64) - output.astype(np.float64)) + idx = int(np.argmax(abs_diff)) + print( + f"[ERROR] Mismatch at idx={idx}: golden={golden[idx]}, out={output[idx]}, " + f"diff={abs_diff[idx]}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 1e-4) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/golden.py new file mode 100644 index 000000000..ee6929e7f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 8 +INPUT_COLS = 56 +OUTPUT_COLS = 64 +PAD_VALUE = 1.0 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, INPUT_COLS)).astype(np.float32) + v2 = np.zeros((ROWS, OUTPUT_COLS), dtype=np.float32) + golden_v2 = np.full((ROWS, OUTPUT_COLS), PAD_VALUE, dtype=np.float32) + golden_v2[:, :INPUT_COLS] = v1 + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate inputs/golden for VPTO micro-op vlds dma loop validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/kernel.pto new file mode 100644 index 000000000..e452ea596 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/kernel.pto @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-dma-loop +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, dma-loop-load-store, sw-loop-plus-hw-loop, full-mask, aligned, dist-norm +// ----------------------------------------------------------------------------- + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vlds_dma_loop_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c224_i64 = arith.constant 224 : i64 + %c256_i64 = arith.constant 256 : i64 + %c448_i64 = arith.constant 448 : i64 + %c512_i64 = arith.constant 512 : i64 + %c896_i64 = arith.constant 896 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c512_i32 = arith.constant 512 : i32 + %pad = arith.constant 1.000000e+00 : f32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c224_i64 + nburst(%c1_i64, %c224_i64, %c256_i64) + loop(%c2_i64, %c224_i64, %c256_i64) + loop(%c2_i64, %c448_i64, %c512_i64) + loop(%c2_i64, %c896_i64, %c1024_i64) + pad(%pad, %c0_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + loop i64, i64, i64, loop i64, i64, i64, loop i64, i64, i64, pad f32, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c64 iter_args(%remaining = %c512_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %value = pto.vlds %ub_in[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %value, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + + pto.mte_ub_gm %ub_out, %arg1, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + loop(%c2_i64, %c256_i64, %c256_i64) + loop(%c2_i64, %c512_i64, %c512_i64) + loop(%c2_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, + loop i64, i64, i64, loop i64, i64, i64, loop i64, i64, i64 + + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/launch.cpp new file mode 100644 index 000000000..3f59702eb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_dma_loop_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVlds_dma_loop_kernel(float *v1, float *v2, void *stream) { + vlds_dma_loop_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/main.cpp new file mode 100644 index 000000000..42d510325 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-dma-loop/main.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "test_common.h" +#include "acl/acl.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_dma_loop_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 448; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 512; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVlds_dma_loop_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/compare.py new file mode 100755 index 000000000..e558d22f2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-ds-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-ds-b16 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.int16, 0.0, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/golden.py new file mode 100755 index 000000000..63da9f605 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-ds-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-ds-b16 +# NOTE: DS on b16 keeps every other i16 element from a 256-element source window. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SOURCE_WINDOW = 256 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(ELEMENTS,), dtype=np.int16) + v2 = np.zeros((ELEMENTS,), dtype=np.int16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.int16) + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = v1[offset : offset + SOURCE_WINDOW : 2][:LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b16 downsample validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto new file mode 100644 index 000000000..c9dedb3c7 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-ds-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-ds-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `DS_B16` load on `b16`. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "DS_B16"} : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/launch.cpp new file mode 100644 index 000000000..07ccb8b8d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-ds-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-ds-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/main.cpp new file mode 100644 index 000000000..951256acb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-ds-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-ds-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-ds-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/compare.py new file mode 100755 index 000000000..1f3503d9b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-tail +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1000 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/golden.py new file mode 100755 index 000000000..7eca6a3ad --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-tail +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: tail-mask case writes the first 1000 f32 lanes and leaves the +# remaining lanes zero. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LOGICAL_ELEMS = 1000 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2[:LOGICAL_ELEMS] = v1[:LOGICAL_ELEMS] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds tail validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto new file mode 100644 index 000000000..8627110bb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/kernel.pto @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-tail +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1000_i32 = arith.constant 1000 : i32 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1000_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/launch.cpp new file mode 100644 index 000000000..dfbca2f61 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-tail +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-tail/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/main.cpp new file mode 100644 index 000000000..a9f049135 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-tail +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/compare.py new file mode 100644 index 000000000..3c394fd2d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/compare.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-unpk-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-unpk-b16 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + golden = np.fromfile(golden_path, dtype=np.uint16) + output = np.fromfile(output_path, dtype=np.uint16) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + if not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden=0x{int(golden[idx]):04x}, out=0x{int(output[idx]):04x})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/golden.py new file mode 100644 index 000000000..5d7850f1b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-unpk-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f16, full-mask, aligned, dist-unpk-b16 + +import argparse +from pathlib import Path + +import numpy as np + + +INPUT_ELEMS = 1024 +OUTPUT_ELEMS = 2048 +SRC_CHUNK = 64 +DST_CHUNK = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=INPUT_ELEMS).astype(np.float16) + dst = np.zeros((OUTPUT_ELEMS,), dtype=np.float16) + golden = np.zeros((OUTPUT_ELEMS,), dtype=np.float16) + + for src_base in range(0, INPUT_ELEMS, SRC_CHUNK): + dst_base = src_base * 2 + golden[dst_base : dst_base + DST_CHUNK : 2] = src[src_base : src_base + SRC_CHUNK] + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + dst.view(np.uint16).tofile(output_dir / "v2.bin") + golden.view(np.uint16).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds UNPK_B16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto new file mode 100644 index 000000000..7a37d4a1e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/kernel.pto @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-unpk-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-unpk-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `UNPK_B16` load on `b16`. +// Installed A5 `TCvt.hpp::cast16to32` uses `vlds(..., UNPK_B16)` before +// `vcvt(..., PART_EVEN)`, so this case probes the resulting 128-lane `f16` +// layout directly through `vsts`. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vlds_unpk_b16_kernel_2d(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %full_mask = pto.pset_b16 "PAT_ALL" : !pto.mask + scf.for %src_offset = %c0 to %c1024 step %c64 { + %dst_offset = arith.muli %src_offset, %c2 : index + %out = pto.vlds %ub_in[%src_offset] {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> + pto.vsts %out, %ub_out[%dst_offset], %full_mask : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/launch.cpp new file mode 100644 index 000000000..b8544e34a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-unpk-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-unpk-b16 +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vlds_unpk_b16_kernel_2d(__gm__ half *v1, + __gm__ half *v2); + +void LaunchVlds_unpk_b16_kernel_2d(uint16_t *v1, uint16_t *v2, void *stream) { + vlds_unpk_b16_kernel_2d<<<1, nullptr, stream>>>((__gm__ half *)v1, + (__gm__ half *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/main.cpp new file mode 100644 index 000000000..83260e90f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-unpk-b16/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-unpk-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f16, full-mask, aligned, dist-unpk-b16 +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVlds_unpk_b16_kernel_2d(uint16_t *v1, uint16_t *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(uint16_t); + size_t elemCount_v2 = 2048; + size_t fileSize_v2 = elemCount_v2 * sizeof(uint16_t); + uint16_t *v1Host = nullptr; + uint16_t *v1Device = nullptr; + uint16_t *v2Host = nullptr; + uint16_t *v2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVlds_unpk_b16_kernel_2d(v1Device, v2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/compare.py new file mode 100755 index 000000000..5ccc50a39 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-us-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-us-b16 +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 1024 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.int16, 0.0, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/golden.py new file mode 100755 index 000000000..214b0269c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds-us-b16 +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-i16, full-mask, aligned, dist-us-b16 +# NOTE: US on b16 duplicates each source i16 element into two consecutive lanes. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(ELEMENTS,), dtype=np.int16) + v2 = np.zeros((ELEMENTS,), dtype=np.int16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.int16) + half_lanes = LANES // 2 + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset : offset + LANES] = np.repeat(v1[offset : offset + half_lanes], 2) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds b16 upsample validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto new file mode 100644 index 000000000..df43097fb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-us-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-us-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `US_B16` load on `b16`. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "US_B16"} : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/launch.cpp new file mode 100644 index 000000000..4ecd1586e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-us-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-us-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/main.cpp new file mode 100644 index 000000000..2c4c9b679 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds-us-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds-us-b16 +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-i16, full-mask, aligned, dist-us-b16 +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/compare.py b/test/vpto/cases/micro-op/vector-load-store/vlds/compare.py new file mode 100755 index 000000000..1c07e2d7c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/golden.py b/test/vpto/cases/micro-op/vector-load-store/vlds/golden.py new file mode 100755 index 000000000..21c58baab --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vlds +# family: vector-load-store +# target_ops: pto.vlds +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vlds validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto new file mode 100644 index 000000000..3ba1b9aa5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/kernel.pto @@ -0,0 +1,69 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vabs_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %out = pto.vlds %ub_in[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %out, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds/launch.cpp new file mode 100644 index 000000000..2e2fa02fb --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vabs_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vabs_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vlds/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vlds/main.cpp new file mode 100644 index 000000000..ab816737d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vlds/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vlds +// family: vector-load-store +// target_ops: pto.vlds +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/compare.py new file mode 100755 index 000000000..a4c5fae81 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-layout-check +# family: vector-load-store +# target_ops: pto.vldsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/golden.py new file mode 100755 index 000000000..8c481d96b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-layout-check +# family: vector-load-store +# target_ops: pto.vldsx2 +# scenarios: core-f32, full-mask, dintlv, lane-order, split-observation +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32).reshape(-1) + flat = v1.reshape(-1) + + # DINTLV_B32 exposes the two deinterleaved 64-lane results independently. + # Observe them through two plain NORM_B32 stores: + # low -> output[offset : offset + 64] + # high -> output[offset + 64 : offset + 128] + for base in range(0, ROWS * COLS, ACTIVE): + chunk = flat[base : base + ACTIVE] + golden_v2[base : base + 64] = chunk[0::2] + golden_v2[base + 64 : base + 128] = chunk[1::2] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldsx2 layout validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto new file mode 100644 index 000000000..3bb2b29c6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/kernel.pto @@ -0,0 +1,57 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-layout-check +// family: vector-load-store +// target_ops: pto.vldsx2 +// scenarios: core-f32, full-mask, dintlv, lane-order, split-observation +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vldx2_layout_check_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %group = %c0 to %c8 step %c1 { + %group_base = arith.muli %group, %c128 : index + scf.for %chunk = %c0 to %c128 step %c128 { + %offset = arith.addi %group_base, %chunk : index + %high_offset = arith.addi %offset, %c64 : index + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %x, %y = pto.vldsx2 %ub_in[%offset], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vsts %x, %ub_out[%offset], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + pto.vsts %y, %ub_out[%high_offset], %mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/launch.cpp new file mode 100644 index 000000000..d06dda18c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-layout-check +// family: vector-load-store +// target_ops: pto.vldsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldx2_layout_check_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldx2_layout_check_kernel(float *v1, float *v2, void *stream) { + vldx2_layout_check_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/main.cpp new file mode 100644 index 000000000..45e578b57 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-layout-check/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-layout-check +// family: vector-load-store +// target_ops: pto.vldsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldx2_layout_check_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldx2_layout_check_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/compare.py new file mode 100644 index 000000000..af950320b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/compare.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/golden.py new file mode 100644 index 000000000..6732c8799 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.array(v1, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldsx2-vstsx2-b8-f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto new file mode 100644 index 000000000..c4346777f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vldx2_vstsx2_b8_f32_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %group = %c0 to %c8 step %c1 { + %group_base = arith.muli %group, %c128 : index + scf.for %chunk = %c0 to %c128 step %c128 { + %offset = arith.addi %group_base, %chunk : index + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %low, %high = pto.vldsx2 %ub_in[%offset], "DINTLV_B8" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vstsx2 %low, %high, %ub_out[%offset], "INTLV_B8", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/launch.cpp new file mode 100644 index 000000000..beadc1f7e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/launch.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldx2_vstsx2_b8_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldx2_vstsx2_b8_f32_kernel(float *v1, float *v2, void *stream) { + vldx2_vstsx2_b8_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/main.cpp new file mode 100644 index 000000000..61686d35b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2-b8-f32/main.cpp @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2-b8-f32 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv-b8-intlv-b8, width-agnostic-dist +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldx2_vstsx2_b8_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldx2_vstsx2_b8_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/compare.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/compare.py new file mode 100755 index 000000000..b28c98567 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/golden.py b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/golden.py new file mode 100755 index 000000000..14d41a9a3 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vldsx2-vstsx2 +# family: vector-load-store +# target_ops: pto.vldsx2, pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.array(v1, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vldsx2-vstsx2 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto new file mode 100644 index 000000000..a51a84105 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/kernel.pto @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vldx2_vstsx2_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %group = %c0 to %c8 step %c1 { + %group_base = arith.muli %group, %c128 : index + scf.for %chunk = %c0 to %c128 step %c128 { + %offset = arith.addi %group_base, %chunk : index + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %low, %high = pto.vldsx2 %ub_in[%offset], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + pto.vstsx2 %low, %high, %ub_out[%offset], "INTLV_B32", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/launch.cpp new file mode 100644 index 000000000..d7e4c3fed --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vldx2_vstsx2_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVldx2_vstsx2_kernel(float *v1, float *v2, void *stream) { + vldx2_vstsx2_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/main.cpp new file mode 100644 index 000000000..ec0f59491 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vldsx2-vstsx2/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vldsx2-vstsx2 +// family: vector-load-store +// target_ops: pto.vldsx2, pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVldx2_vstsx2_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVldx2_vstsx2_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsldb/compare.py new file mode 100755 index 000000000..c47755b60 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsldb +# family: vector-load-store +# target_ops: pto.vsldb +# scenarios: core-f32, full-mask, block-strided-load, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 64 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsldb/golden.py new file mode 100755 index 000000000..3d8dad137 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsldb +# family: vector-load-store +# target_ops: pto.vsldb +# scenarios: core-f32, full-mask, block-strided-load, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +BLOCK_STRIDE = 2 +REPEAT_STRIDE = 4 +BLOCK_ELEMS = 8 +BLOCK_COUNT = 8 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_golden = golden_v2.reshape(-1) + for blk in range(BLOCK_COUNT): + src_blk = REPEAT_STRIDE + blk * BLOCK_STRIDE + flat_golden[blk * BLOCK_ELEMS:(blk + 1) * BLOCK_ELEMS] = flat_in[ + src_blk * BLOCK_ELEMS:(src_blk + 1) * BLOCK_ELEMS + ] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsldb validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto new file mode 100644 index 000000000..839f56139 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/kernel.pto @@ -0,0 +1,47 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsldb +// family: vector-load-store +// target_ops: pto.vsldb +// scenarios: core-f32, full-mask, block-strided-load, block-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsldb_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c2_i16 = arith.constant 2 : i16 + %c4_i16 = arith.constant 4 : i16 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %next_remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %loaded = pto.vsldb %ub_in, %c2_i16, %c4_i16, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %loaded, %ub_out[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsldb/launch.cpp new file mode 100644 index 000000000..fe71cc2b6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsldb +// family: vector-load-store +// target_ops: pto.vsldb +// scenarios: core-f32, full-mask, block-strided-load, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsldb_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsldb_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsldb/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsldb/main.cpp new file mode 100644 index 000000000..f0b21ff83 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsldb/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsldb +// family: vector-load-store +// target_ops: pto.vsldb +// scenarios: core-f32, full-mask, block-strided-load, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsstb/compare.py new file mode 100755 index 000000000..cffa8ea8b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsstb +# family: vector-load-store +# target_ops: pto.vsstb +# scenarios: core-f32, full-mask, block-strided-store, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsstb/golden.py new file mode 100755 index 000000000..033a8b030 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsstb +# family: vector-load-store +# target_ops: pto.vsstb +# scenarios: core-f32, full-mask, block-strided-store, block-mask +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 +BLOCK_STRIDE = 2 +REPEAT_STRIDE = 4 +BLOCK_ELEMS = 8 +BLOCK_COUNT = 8 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.array(v1, copy=True) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsstb validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto new file mode 100644 index 000000000..df23d2b44 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsstb +// family: vector-load-store +// target_ops: pto.vsstb +// scenarios: core-f32, full-mask, block-strided-store, block-mask +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsstb_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c2_i16 = arith.constant 2 : i16 + %c4_i16 = arith.constant 4 : i16 + %c32_i64 = arith.constant 32 : i64 + %c64_i32 = arith.constant 64 : i32 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %next_remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %value = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsstb %value, %ub_out, %c2_i16, %c4_i16, %mask : !pto.vreg<64xf32>, !pto.ptr, i16, i16, !pto.mask + pto.mem_bar "VST_VLD" + %roundtrip = pto.vsldb %ub_out, %c2_i16, %c4_i16, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %roundtrip, %ub_in[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_in, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsstb/launch.cpp new file mode 100644 index 000000000..95d2a57bd --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsstb +// family: vector-load-store +// target_ops: pto.vsstb +// scenarios: core-f32, full-mask, block-strided-store, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsstb_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsstb_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsstb/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsstb/main.cpp new file mode 100644 index 000000000..72c683928 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsstb/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsstb +// family: vector-load-store +// target_ops: pto.vsstb +// scenarios: core-f32, full-mask, block-strided-store, block-mask +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstar/compare.py new file mode 100755 index 000000000..3f233f6e6 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/compare.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstar +# family: vector-load-store +# target_ops: pto.vstar +# scenarios: core-f32, full-mask, aligned, state-update +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +CHECK_OFFSET = 1 +CHECK_COUNT = 8 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + golden = np.fromfile("golden_v2.bin", dtype=np.float32) if os.path.exists("golden_v2.bin") else None + output = np.fromfile("v2.bin", dtype=np.float32) if os.path.exists("v2.bin") else None + lo = CHECK_OFFSET + hi = CHECK_OFFSET + CHECK_COUNT + if output is None: + ok = False + print("[ERROR] Output missing: v2.bin") + elif golden is None: + ok = False + print("[ERROR] Golden missing: golden_v2.bin") + elif golden.size < hi or output.size < hi: + ok = False + print( + f"[ERROR] Flush slice too small: need={hi} elems, " + f"golden={golden.size}, out={output.size}" + ) + elif not np.allclose(golden[lo:hi], output[lo:hi], atol=0.0001, rtol=0.0001, equal_nan=True): + g = golden[lo:hi].astype(np.float64, copy=False) + o = output[lo:hi].astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + ok = False + print( + f"[ERROR] Mismatch (flush slice): golden_v2.bin vs v2.bin, max diff={float(abs_diff[idx])} " + f"at idx={lo + idx} (golden={g[idx]}, out={o[idx]}, dtype=float32)" + ) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstar/golden.py new file mode 100755 index 000000000..d1a1054ba --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstar +# family: vector-load-store +# target_ops: pto.vstar +# scenarios: core-f32, predicate-squeezed, unaligned, state-update +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[1:9] = v1.reshape(-1)[:8] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstar validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto new file mode 100644 index 000000000..8fa41490c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/kernel.pto @@ -0,0 +1,63 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstar +// family: vector-load-store +// target_ops: pto.vstar +// scenarios: core-f32, predicate-squeezed, unaligned, state-update +// ----------------------------------------------------------------------------- +// Validate the final flush step of a stateful store chain. +// The case keeps `pto.vstar` as the target op and uses the minimal required +// setup: +// 1. load one aligned vector from `%ub_in` +// 2. squeeze a small active prefix to prime `SPR SQZN` +// 3. prime one store-state carrier from unaligned `%ub_out` +// 4. issue one `pto.vstur ... "POST_UPDATE"` to create residual state +// 5. flush that residual state with `pto.vstar` +// This makes the observable payload come from `vstar` while keeping the chain +// contract valid per docs. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vstar_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i32 = arith.constant 8 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1_elem = arith.constant 1 : index + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out1 = pto.addptr %ub_out, %c1_elem : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + pto.sprclr "AR" + scf.for %iter = %c0 to %c1 step %c1 { + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %mask, %unused = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %sqz = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %align0 = pto.init_align : !pto.align + %align1 = pto.vstur %align0, %sqz, %ub_out1, "POST_UPDATE" + : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + pto.vstar %align1, %ub_out1 : !pto.align, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstar/launch.cpp new file mode 100644 index 000000000..6d4789d7e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstar +// family: vector-load-store +// target_ops: pto.vstar +// scenarios: core-f32, full-mask, aligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstar_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstar_kernel_2d(float *v1, float *v2, void *stream) { + vstar_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstar/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstar/main.cpp new file mode 100644 index 000000000..0f316a695 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstar/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstar +// family: vector-load-store +// target_ops: pto.vstar +// scenarios: core-f32, full-mask, aligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstar_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstar_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/compare.py new file mode 100755 index 000000000..0916b067e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/compare.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstas-vstus-offset-update +# family: vector-load-store +# target_ops: pto.vstas, pto.vstus +# scenarios: core-f32, full-mask, immediate-offset, state-update +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, 69) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/golden.py new file mode 100755 index 000000000..b1b68f800 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstas-vstus-offset-update +# family: vector-load-store +# target_ops: pto.vstas, pto.vstus +# scenarios: core-f32, full-mask, immediate-offset, state-update +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +VECTOR_LANES = 64 +POST_UPDATE_OFFSET_ELEMENTS = 3 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.float32) + golden_v2[:POST_UPDATE_OFFSET_ELEMENTS] = v1[:POST_UPDATE_OFFSET_ELEMENTS] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstas/vstus chain validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto new file mode 100644 index 000000000..ac04a7d26 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/kernel.pto @@ -0,0 +1,61 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstas-vstus-offset-update +// family: vector-load-store +// target_ops: pto.vstas, pto.vstus +// scenarios: core-f32, full-mask, immediate-offset, state-update +// ----------------------------------------------------------------------------- +// Validate the state chain required by the plan: +// 1. prime a store-state carrier +// 2. issue one no-post `vstus` with a non-zero explicit offset +// 3. flush the residual state with `vstas` using the same explicit flush point +// The observable effect should match an unaligned store stream where `vstus` +// advances the stream by 3 f32 elements and leaves the buffered tail in +// `!pto.align`, then `vstas` commits that pending tail at the matching flush +// point identified by the original base plus the same scalar offset. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vstas_vstus_offset_update_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c3_i32 = arith.constant 3 : i32 + %c0_i32 = arith.constant 0 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_out, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %offset = %c0 to %c64 step %c64 { + %align0 = pto.init_align : !pto.align + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %align1 = pto.vstus %align0, %c3_i32, %vec, %ub_out + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + pto.vstas %align1, %ub_out, %c3_i32 : !pto.align, !pto.ptr, i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/launch.cpp new file mode 100644 index 000000000..b395937e5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstas-vstus-offset-update +// family: vector-load-store +// target_ops: pto.vstas, pto.vstus +// scenarios: core-f32, full-mask, immediate-offset, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstas_vstus_offset_update_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstasVstusOffsetUpdate_kernel_2d(float *v1, float *v2, void *stream) { + vstas_vstus_offset_update_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/main.cpp new file mode 100644 index 000000000..2d2a9469b --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstas-vstus-offset-update/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstas-vstus-offset-update +// family: vector-load-store +// target_ops: pto.vstas, pto.vstus +// scenarios: core-f32, full-mask, immediate-offset, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstasVstusOffsetUpdate_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstasVstusOffsetUpdate_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/compare.py new file mode 100755 index 000000000..d6d773550 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/compare.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-1pt-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-1pt-b16 +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +ACTIVE_ELEMS = 1024 +LANES = 128 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def compare_1pt_positions(golden_path, output_path, dtype, active_elems, lanes): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + active_elems = int(active_elems) + lanes = int(lanes) + except Exception: + print(f"[ERROR] Invalid 1PT compare arguments: active_elems={active_elems} lanes={lanes}") + return False + if active_elems <= 0 or lanes <= 0: + print(f"[ERROR] Invalid 1PT compare arguments: active_elems={active_elems} lanes={lanes}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + + positions = np.arange(0, active_elems, lanes, dtype=np.int64) + if positions.size == 0: + print("[ERROR] No 1PT positions selected") + return False + if positions[-1] >= golden.size: + print( + f"[ERROR] 1PT positions out of range: last={int(positions[-1])} size={golden.size}" + ) + return False + + golden_sel = golden[positions] + output_sel = output[positions] + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + pos = int(positions[idx]) + print( + f"[ERROR] Mismatch (1PT positions): idx={pos} " + f"golden={int(golden_sel[idx])} out={int(output_sel[idx])}" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_1pt_positions("golden_v2.bin", "v2.bin", np.int16, ACTIVE_ELEMS, LANES) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/golden.py new file mode 100755 index 000000000..1ed2947c7 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-1pt-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-1pt-b16 +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 2048 +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(ELEMENTS,), dtype=np.int16) + v2 = np.zeros((ELEMENTS,), dtype=np.int16) + golden_v2 = np.zeros((ELEMENTS,), dtype=np.int16) + for offset in range(0, ACTIVE_ELEMS, LANES): + golden_v2[offset] = v1[offset] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts 1PT validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto new file mode 100644 index 000000000..8bf7ce3f9 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-1pt-b16 +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-i16, full-mask, aligned, dist-1pt-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `1PT_B16` store distribution on `b16`. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsts_1pt_b16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "1PT_B16"} : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/launch.cpp new file mode 100644 index 000000000..3514bffb8 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_1pt_b16_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_1pt_b16_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/main.cpp new file mode 100644 index 000000000..6bc7026e2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-1pt-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/compare.py new file mode 100755 index 000000000..058c478a5 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/compare.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-pk-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-pk-b16 +# coding=utf-8 + +import os +import sys +import numpy as np + +OUTPUT_BUFFER_BYTES = 4096 +# Keep this aligned with kernel.pto loop bound (offset: 0..1024 step 128 on i16). +ACTIVE_ELEMS = 1024 +LANES = 128 +BYTES_PER_ELEM = 2 + + +def build_checked_mask(total_bytes): + # For this case kernel: + # - loop offset: 0..1024 step 128 (i16 elements) + # - dist=PK_B16 stores 1 byte per active i16 element + # So each iteration writes 128 bytes at dst_byte_base = offset * 2. + mask = np.zeros((total_bytes,), dtype=bool) + for offset in range(0, ACTIVE_ELEMS, LANES): + dst_byte_base = offset * BYTES_PER_ELEM + mask[dst_byte_base : dst_byte_base + LANES] = True + return mask + + +def compare_bin(golden_path, output_path): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden.shape} vs {output.shape}") + return False + + if golden.size != OUTPUT_BUFFER_BYTES: + print( + f"[ERROR] Unexpected byte size for this case: got {golden.size}, expected {OUTPUT_BUFFER_BYTES}" + ) + return False + + checked = build_checked_mask(golden.size) + checked_golden = golden[checked] + checked_output = output[checked] + if not np.array_equal(checked_golden, checked_output): + diff = np.nonzero(checked_golden != checked_output)[0] + idx = int(diff[0]) if diff.size else 0 + global_idx = int(np.nonzero(checked)[0][idx]) if diff.size else 0 + print( + f"[ERROR] Mismatch (checked footprint): {golden_path} vs {output_path}, " + f"first diff at checked_idx={idx}, global_idx={global_idx} " + f"(golden=0x{int(checked_golden[idx]):02x}, out=0x{int(checked_output[idx]):02x})" + ) + return False + print( + f"[INFO] compared writable footprint only: {int(np.count_nonzero(checked))}/{golden.size} bytes" + ) + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin") + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/golden.py new file mode 100755 index 000000000..b0ebf667c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/golden.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-pk-b16 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-i16, full-mask, aligned, dist-pk-b16 +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +OUTPUT_BUFFER_BYTES = 4096 +TOTAL_ELEMS_I16 = OUTPUT_BUFFER_BYTES // 2 +# This case kernel only iterates 0..1024 on i16 lanes, so only 1024 packed bytes +# are semantically writable by vsts(pk_b16) in this testcase. +ACTIVE_ELEMS = 1024 +LANES = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.integers(-(2**15), 2**15, size=(TOTAL_ELEMS_I16,), dtype=np.int16) + v2 = rng.integers(0, 256, size=(OUTPUT_BUFFER_BYTES,), dtype=np.uint8) + golden_v2 = v2.copy() + + # PK_B16: write low 8 bits of each active b16 element as a compact byte stream. + # Destination address is unchanged for non-post-update form; within each 256B + # lane chunk only the first 128B are overwritten. + v1_u16 = v1.view(np.uint16) + packed_bytes_per_chunk = LANES + for offset in range(0, ACTIVE_ELEMS, LANES): + src = v1_u16[offset : offset + LANES] + packed = (src & 0x00FF).astype(np.uint8) + dst_byte_base = offset * 2 + golden_v2[dst_byte_base : dst_byte_base + packed_bytes_per_chunk] = packed + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts PK_B16 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto new file mode 100644 index 000000000..abe05b2d3 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-pk-b16 +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-i16, full-mask, aligned, dist-pk-b16 +// ----------------------------------------------------------------------------- +// Validate one representative `PK_B16` store distribution on `b16`. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsts_pk_b16_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b16 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<128xi16> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "PK_B16"} : !pto.vreg<128xi16>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/launch.cpp new file mode 100644 index 000000000..9a902908c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_pk_b16_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_pk_b16_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/main.cpp new file mode 100644 index 000000000..6bc7026e2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b16/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/compare.py new file mode 100644 index 000000000..4c19eb038 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + + +import os +import sys + +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(golden_path) or not os.path.exists(output_path): + return False + golden = np.fromfile(golden_path, dtype=dtype) + output = np.fromfile(output_path, dtype=dtype) + return golden.shape == output.shape and np.allclose( + golden, output, atol=eps, rtol=eps, equal_nan=True + ) + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/golden.py new file mode 100644 index 000000000..c5635db7c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-pk-b64-f32 +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, full-mask, aligned, dist-pk-b64, width-agnostic-dist + +import argparse +from pathlib import Path + +import numpy as np + + +ELEMENTS = 1024 +LANES = 64 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + v2 = rng.uniform(-8.0, 8.0, size=(ELEMENTS,)).astype(np.float32) + golden_v2 = np.array(v2, copy=True) + + for offset in range(0, ELEMENTS, LANES): + chunk = v1[offset : offset + LANES] + packed = chunk[0::2] + golden_v2[offset : offset + packed.size] = packed + + output_dir.mkdir(parents=True, exist_ok=True) + v1.tofile(output_dir / "v1.bin") + v2.tofile(output_dir / "v2.bin") + golden_v2.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts PK_B64 on f32 validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto new file mode 100644 index 000000000..315724f96 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/kernel.pto @@ -0,0 +1,50 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-pk-b64-f32 +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, full-mask, aligned, dist-pk-b64, width-agnostic-dist +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsts_pk_b64_f32_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %arg1, %ub_out, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "PK_B64"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/launch.cpp new file mode 100644 index 000000000..039f11e40 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_pk_b64_f32_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVsts_pk_b64_f32_kernel(float *v1, float *v2, void *stream) { + vsts_pk_b64_f32_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/main.cpp new file mode 100644 index 000000000..707a88e05 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-pk-b64-f32/main.cpp @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVsts_pk_b64_f32_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVsts_pk_b64_f32_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/compare.py new file mode 100755 index 000000000..1821ec6aa --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/compare.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-tail +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_window(golden_path, output_path, dtype, eps, offset, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + offset = int(offset) + count = int(count) + except Exception: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + if offset < 0 or count <= 0: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + end = offset + count + if golden.size < end or output.size < end: + print( + f"[ERROR] Compare window out of range: offset={offset} count={count}, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[offset:end] + output_sel = output[offset:end] + if not np.allclose(golden_sel, output_sel, atol=eps, rtol=eps, equal_nan=True): + if golden_sel.size: + g = golden_sel.astype(np.float64, copy=False) + o = output_sel.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, max diff={diff} " + f"at idx={offset + idx} (golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, " + f"offset={offset}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, empty window, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_window("golden_v2.bin", "v2.bin", np.float32, 0.0001, 0, 13) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/golden.py new file mode 100755 index 000000000..73bd90f99 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts-tail +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 13 +SEED = 19 +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[:ACTIVE] = v1.reshape(-1)[:ACTIVE] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts-tail validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto new file mode 100644 index 000000000..901e975f4 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/kernel.pto @@ -0,0 +1,45 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-tail +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsts_tail_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c13_i32 = arith.constant 13 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %remaining = pto.plt_b32 %c13_i32 : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%c0], %mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/launch.cpp new file mode 100644 index 000000000..2a94e832d --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-tail +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_tail_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_tail_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts-tail/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/main.cpp new file mode 100644 index 000000000..f8da4c77a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts-tail/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts-tail +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, tail-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/compare.py b/test/vpto/cases/micro-op/vector-load-store/vsts/compare.py new file mode 100755 index 000000000..dc064cb22 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/compare.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin("golden_v2.bin", "v2.bin", np.float32, 0.0001) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/golden.py b/test/vpto/cases/micro-op/vector-load-store/vsts/golden.py new file mode 100755 index 000000000..9eb6e0453 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vsts +# family: vector-load-store +# target_ops: pto.vsts +# scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = v1.astype(np.float32, copy=False) + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vsts validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto new file mode 100644 index 000000000..8133b4e5f --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/kernel.pto @@ -0,0 +1,69 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// ============================================================================= +// abs_kernel_2d: Element-wise absolute value on a 32x32 f32 tile +// ============================================================================= +// This kernel computes abs(input) for a 32x32 float32 matrix (1024 elements). +// +// Memory Layout: +// - Input: arg0 -> GM (Global Memory) +// - Output: arg1 -> GM (Global Memory) +// - UB (Unified Buffer) at offset 0: input tile (4096 bytes = 32*32*4) +// - UB at offset 4096: output tile (4096 bytes = 32*32*4) +// +// Pipeline: +// 1. DMA: GM -> UB (MTE2 pipe) - copy input tile to UB +// 2. Sync: wait for MTE2 -> V pipe handoff +// 3. Compute: vabs on 64-element vectors (V pipe) - 16 iterations for 1024 elements +// 4. Sync: wait for V -> MTE3 pipe handoff +// 5. DMA: UB -> GM (MTE3 pipe) - copy result tile back to GM +// ============================================================================= + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vsts_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c1024_i32 = arith.constant 1024 : i32 + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c64 iter_args(%remaining = %c1024_i32) -> (i32) { + %mask, %next_remaining = pto.plt_b32 %remaining : i32 -> !pto.mask, i32 + %vec = pto.vlds %ub_in[%offset] : !pto.ptr -> !pto.vreg<64xf32> + pto.vsts %vec, %ub_out[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %next_remaining : i32 + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts/launch.cpp new file mode 100644 index 000000000..851e10299 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vsts_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream) { + vsts_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vsts/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vsts/main.cpp new file mode 100644 index 000000000..6bc7026e2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vsts/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vsts +// family: vector-load-store +// target_ops: pto.vsts +// scenarios: core-f32, contiguous, full-mask, aligned, dist-norm +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVabs_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVabs_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/compare.py new file mode 100755 index 000000000..b2a31f90e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/compare.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstsx2-layout-check +# family: vector-load-store +# target_ops: pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 +PREFIX_ELEMS = 128 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_prefix("golden_v2.bin", "v2.bin", np.float32, 0.0001, PREFIX_ELEMS) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/golden.py new file mode 100755 index 000000000..24665fcf7 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/golden.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstsx2-layout-check +# family: vector-load-store +# target_ops: pto.vstsx2 +# scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE = 128 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + flat_in = v1.reshape(-1) + flat_golden = golden_v2.reshape(-1) + flat_golden[:ACTIVE:2] = flat_in[:64] + flat_golden[1:ACTIVE:2] = flat_in[64:128] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstsx2 layout validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto new file mode 100644 index 000000000..728a52839 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/kernel.pto @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstsx2-layout-check +// family: vector-load-store +// target_ops: pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vstsx2_layout_check_kernel(%arg0: !pto.ptr, + %arg1: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c64_i32 = arith.constant 64 : i32 + %false = arith.constant false + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %iv = %c0 to %c1 step %c1 { + %mask, %remaining = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %x = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %y = pto.vlds %ub_in[%c64] : !pto.ptr -> !pto.vreg<64xf32> + pto.vstsx2 %x, %y, %ub_out[%c0], "INTLV_B32", %mask + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/launch.cpp new file mode 100644 index 000000000..bf28f8403 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/launch.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstsx2-layout-check +// family: vector-load-store +// target_ops: pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstsx2_layout_check_kernel(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstsx2_layout_check_kernel(float *v1, float *v2, void *stream) { + vstsx2_layout_check_kernel<<<1, nullptr, stream>>>((__gm__ float *)v1, + (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/main.cpp new file mode 100644 index 000000000..1e380b6b2 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstsx2-layout-check/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstsx2-layout-check +// family: vector-load-store +// target_ops: pto.vstsx2 +// scenarios: core-f32, full-mask, paired-roundtrip, dintlv, lane-order +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstsx2_layout_check_kernel(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstsx2_layout_check_kernel(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/compare.py new file mode 100644 index 000000000..fde3a5229 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/compare.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur-init-align-outside-loop +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, full-mask, unaligned, state-update, init-align-outside-loop +# coding=utf-8 + +import os +import sys +import numpy as np + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_window(golden_path, output_path, dtype, eps, offset, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + offset = int(offset) + count = int(count) + except Exception: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + if offset < 0 or count <= 0: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + end = offset + count + if golden.size < end or output.size < end: + print( + f"[ERROR] Compare window out of range: offset={offset} count={count}, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[offset:end] + output_sel = output[offset:end] + if not np.allclose(golden_sel, output_sel, atol=eps, rtol=eps, equal_nan=True): + if golden_sel.size: + g = golden_sel.astype(np.float64, copy=False) + o = output_sel.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, max diff={diff} " + f"at idx={offset + idx} (golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, " + f"offset={offset}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, empty window, dtype={dtype_np}") + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = compare_bin_window("golden_v2.bin", "v2.bin", np.float32, 0.0001, 1, 8) + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/golden.py new file mode 100644 index 000000000..d13ca8097 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur-init-align-outside-loop +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, predicate-squeezed, unaligned, state-update, init-align-outside-loop +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE_LANES = 8 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[1 : 1 + ACTIVE_LANES] = v1.reshape(-1)[:ACTIVE_LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstur-init-align-outside-loop validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto new file mode 100644 index 000000000..0b07a956c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/kernel.pto @@ -0,0 +1,53 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur-init-align-outside-loop +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, predicate-squeezed, unaligned, state-update, init-align-outside-loop +// ----------------------------------------------------------------------------- +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vstur_init_align_outside_loop_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i32 = arith.constant 8 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + pto.sprclr "AR" + %align0 = pto.init_align : !pto.align + %align_final = scf.for %offset = %c0 to %c1 step %c1 + iter_args(%align_iter = %align0) -> (!pto.align) { + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %mask, %unused = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %sqz = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %align1 = pto.vstur %align_iter, %sqz, %ub_out1, "POST_UPDATE" + : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + scf.yield %align1 : !pto.align + } + pto.vstar %align_final, %ub_out1 : !pto.align, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/launch.cpp new file mode 100644 index 000000000..a56f83c6e --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/launch.cpp @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur-init-align-outside-loop +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vstur_init_align_outside_loop_kernel_2d(__gm__ float *v1, __gm__ float *v2); + +void LaunchVstur_init_align_outside_loop_kernel_2d(float *v1, float *v2, + void *stream) { + vstur_init_align_outside_loop_kernel_2d<<<1, nullptr, stream>>>( + (__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/main.cpp new file mode 100644 index 000000000..486ca9862 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur-init-align-outside-loop/main.cpp @@ -0,0 +1,129 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur-init-align-outside-loop +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update, init-align-outside-loop +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstur_init_align_outside_loop_kernel_2d(float *v1, float *v2, + void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstur_init_align_outside_loop_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + return rc; +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/compare.py b/test/vpto/cases/micro-op/vector-load-store/vstur/compare.py new file mode 100755 index 000000000..80b4dab8a --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/compare.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, full-mask, unaligned, state-update +# NOTE: bulk-generated coverage skeleton. +# coding=utf-8 + +import os +import sys +import numpy as np + + +REPEAT_BYTES = 256 + + +def _ceil_div(x, y): + return (x + y - 1) // y + + +def _packed_pred_storage_bytes(logical_elems, src_elem_bytes): + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + if logical_elems <= 0: + raise ValueError(f"logical_elems must be > 0, got {logical_elems}") + if src_elem_bytes not in (1, 2, 4): + raise ValueError(f"unsupported packed predicate source size: {src_elem_bytes}") + + repeat_elems = REPEAT_BYTES // src_elem_bytes + if src_elem_bytes == 4: + repeat_times = _ceil_div(logical_elems, repeat_elems) + 1 + loop_count = repeat_times // 2 + return loop_count * 16 + + repeat_times = _ceil_div(logical_elems, repeat_elems) + return repeat_times * (repeat_elems // 8) + + +def compare_bin(golden_path, output_path, dtype, eps): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + if golden.shape != output.shape: + print(f"[ERROR] Shape mismatch: {golden_path} {golden.shape} vs {output_path} {output.shape}") + return False + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch: {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np})" + ) + else: + print(f"[ERROR] Mismatch: {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_prefix(golden_path, output_path, dtype, eps, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + count = int(count) + except Exception: + print(f"[ERROR] Invalid prefix count: {count}") + return False + if count <= 0: + print(f"[ERROR] Invalid prefix count: {count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np, count=count) + output = np.fromfile(output_path, dtype=dtype_np, count=count) + + if golden.size != count or output.size != count: + print( + f"[ERROR] Prefix read too small: need={count} elems, " + f"golden={golden.size}, out={output.size}" + ) + return False + + if not np.allclose(golden, output, atol=eps, rtol=eps, equal_nan=True): + if golden.size: + if np.issubdtype(dtype_np, np.floating): + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + elif np.issubdtype(dtype_np, np.integer) or np.issubdtype(dtype_np, np.unsignedinteger): + g = golden.astype(np.int64, copy=False) + o = output.astype(np.int64, copy=False) + else: + g = golden.astype(np.float64, copy=False) + o = output.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, max diff={diff} at idx={idx} " + f"(golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (prefix): {golden_path} vs {output_path}, empty buffers, dtype={dtype_np}") + return False + return True + + +def compare_bin_window(golden_path, output_path, dtype, eps, offset, count): + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + offset = int(offset) + count = int(count) + except Exception: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + if offset < 0 or count <= 0: + print(f"[ERROR] Invalid compare window: offset={offset} count={count}") + return False + + dtype_np = np.dtype(dtype) + golden = np.fromfile(golden_path, dtype=dtype_np) + output = np.fromfile(output_path, dtype=dtype_np) + end = offset + count + if golden.size < end or output.size < end: + print( + f"[ERROR] Compare window out of range: offset={offset} count={count}, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[offset:end] + output_sel = output[offset:end] + if not np.allclose(golden_sel, output_sel, atol=eps, rtol=eps, equal_nan=True): + if golden_sel.size: + g = golden_sel.astype(np.float64, copy=False) + o = output_sel.astype(np.float64, copy=False) + abs_diff = np.abs(g - o) + idx = int(np.argmax(abs_diff)) + diff = float(abs_diff[idx]) + print( + f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, max diff={diff} " + f"at idx={offset + idx} (golden={g[idx]}, out={o[idx]}, dtype={dtype_np}, " + f"offset={offset}, count={count})" + ) + else: + print(f"[ERROR] Mismatch (window): {golden_path} vs {output_path}, empty window, dtype={dtype_np}") + return False + return True + + +def compare_packed_pred_mask(golden_path, output_path, logical_elems, src_elem_bytes): + """ + Compare outputs of pto.tcmp / pto.tcmps. + + PTO-ISA stores packed predicate results as a linear PK byte stream via + `psts`, with the exact written prefix length determined by the typed + TCMP/TCMPS repeat schedule. Compare only that semantic prefix. + """ + if not os.path.exists(output_path): + print(f"[ERROR] Output missing: {output_path}") + return False + if not os.path.exists(golden_path): + print(f"[ERROR] Golden missing: {golden_path}") + return False + try: + logical_elems = int(logical_elems) + src_elem_bytes = int(src_elem_bytes) + except Exception: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + if logical_elems <= 0 or src_elem_bytes <= 0: + print( + "[ERROR] Invalid packed mask compare arguments: " + f"logical_elems={logical_elems} src_elem_bytes={src_elem_bytes}" + ) + return False + + golden = np.fromfile(golden_path, dtype=np.uint8) + output = np.fromfile(output_path, dtype=np.uint8) + try: + prefix_bytes = _packed_pred_storage_bytes(logical_elems, src_elem_bytes) + except ValueError as exc: + print(f"[ERROR] {exc}") + return False + + if golden.size < prefix_bytes or output.size < prefix_bytes: + print( + f"[ERROR] Packed mask buffer too small: need={prefix_bytes} bytes, " + f"golden={golden.size}, out={output.size}" + ) + return False + + golden_sel = golden[:prefix_bytes] + output_sel = output[:prefix_bytes] + + if not np.array_equal(golden_sel, output_sel): + diff = np.nonzero(golden_sel != output_sel)[0] + idx = int(diff[0]) if diff.size else 0 + print( + f"[ERROR] Mismatch (packed mask): {golden_path} vs {output_path}, first diff at idx={idx} " + f"(golden={int(golden_sel[idx])}, out={int(output_sel[idx])}, " + f"logical_elems={logical_elems}, src_elem_bytes={src_elem_bytes}, prefix_bytes={prefix_bytes})" + ) + return False + return True + + +def main(): + strict = os.getenv("COMPARE_STRICT", "1") != "0" + ok = True + ok = compare_bin_window("golden_v2.bin", "v2.bin", np.float32, 0.0001, 1, 8) and ok + if not ok: + if strict: + print("[ERROR] compare failed") + sys.exit(2) + print("[WARN] compare failed (non-gating)") + return + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/golden.py b/test/vpto/cases/micro-op/vector-load-store/vstur/golden.py new file mode 100755 index 000000000..96b3c4030 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/golden.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# case: micro-op/vector-load-store/vstur +# family: vector-load-store +# target_ops: pto.vstur +# scenarios: core-f32, predicate-squeezed, unaligned, state-update +# coding=utf-8 + +import argparse +from pathlib import Path + +import numpy as np + + +ROWS = 32 +COLS = 32 +ACTIVE_LANES = 8 +SEED = 19 + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + v1 = rng.uniform(-8.0, 8.0, size=(ROWS, COLS)).astype(np.float32) + v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2 = np.zeros((ROWS, COLS), dtype=np.float32) + golden_v2.reshape(-1)[1 : 1 + ACTIVE_LANES] = v1.reshape(-1)[:ACTIVE_LANES] + + output_dir.mkdir(parents=True, exist_ok=True) + v1.reshape(-1).tofile(output_dir / "v1.bin") + v2.reshape(-1).tofile(output_dir / "v2.bin") + golden_v2.reshape(-1).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate numpy-based inputs/golden for VPTO micro-op vstur validation." + ) + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto new file mode 100644 index 000000000..52ffc13c3 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/kernel.pto @@ -0,0 +1,61 @@ +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, predicate-squeezed, unaligned, state-update +// ----------------------------------------------------------------------------- +// Validate the standalone `vstur` surface with its required SQZN producer. +// The case keeps the sequence minimal: +// 1. load one vector from `%ub_in` +// 2. generate a small predicate and squeeze the vector to prime `SPR SQZN` +// 3. prime one store-state carrier from `%ub_out` +// 4. issue one `pto.vstur ... "POST_UPDATE"` +// 5. flush the residual state with `pto.vstar` +// This preserves the testcase goal around unaligned store state update without +// fabricating extra semantics beyond the installed A5 wrapper contract. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vstur_kernel_2d(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel} { + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i32 = arith.constant 8 : i32 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + + %ub_in = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out1 = pto.addptr %ub_out, %c1 : !pto.ptr -> !pto.ptr + + %false = arith.constant false + pto.mte_gm_ub %arg0, %ub_in, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + pto.sprclr "AR" + scf.for %offset = %c0 to %c1 step %c1 { + %vec = pto.vlds %ub_in[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %mask, %unused = pto.plt_b32 %c8_i32 : i32 -> !pto.mask, i32 + %sqz = pto.vsqz %vec, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %align0 = pto.init_align : !pto.align + %align1 = pto.vstur %align0, %sqz, %ub_out1, "POST_UPDATE" + : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align + pto.vstar %align1, %ub_out1 : !pto.align, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out, %arg1, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/launch.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur/launch.cpp new file mode 100644 index 000000000..b0c69d79c --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/launch.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +// --------------------------------------------------------------------------- +// PTOAS compatibility layer +// +// The upstream pto-isa headers reference some FP8/FP4 types and the +// __VEC_SCOPE__ marker that are not available on every AICore arch/toolchain +// combination (e.g. __NPU_ARCH__==2201). +// +// For our PTOAS-generated kernels we don't rely on these types today, but the +// headers still mention them in templates/static_asserts. Provide minimal +// fallbacks to keep compilation working on dav-c220. +// --------------------------------------------------------------------------- +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif + +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include + +// AICore printf support is gated behind `--cce-enable-print` on some +// toolchains. When enabled, include the CCE print header so `cce::printf` +// resolves in device compilation. +#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) +#include +#endif + +// Some PTO-ISA types are only available in the __CCE_AICORE__ compilation +// path, but `bisheng -xcce` still performs a host-side parse pass. +// Provide minimal fallbacks only when the corresponding header wasn't +// pulled in by the selected arch implementation. +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void vstur_kernel_2d(__gm__ float *v1, + __gm__ float *v2); + +void LaunchVstur_kernel_2d(float *v1, float *v2, void *stream) { + vstur_kernel_2d<<<1, nullptr, stream>>>((__gm__ float *)v1, (__gm__ float *)v2); +} diff --git a/test/vpto/cases/micro-op/vector-load-store/vstur/main.cpp b/test/vpto/cases/micro-op/vector-load-store/vstur/main.cpp new file mode 100644 index 000000000..273fb30c4 --- /dev/null +++ b/test/vpto/cases/micro-op/vector-load-store/vstur/main.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// ----------------------------------------------------------------------------- +// case: micro-op/vector-load-store/vstur +// family: vector-load-store +// target_ops: pto.vstur +// scenarios: core-f32, full-mask, unaligned, state-update +// NOTE: bulk-generated coverage skeleton. Parser/verifier/lowering failure is +// still a valid test conclusion in the current coverage-first phase. +// ----------------------------------------------------------------------------- +/** +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include "test_common.h" +#include "acl/acl.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#ifndef TMRGSORT_HPP +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, (int)_ret, __FILE__, __LINE__); \ + const char *_recent = aclGetRecentErrMsg(); \ + if (_recent != nullptr && _recent[0] != '\0') { \ + std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ + } \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVstur_kernel_2d(float *v1, float *v2, void *stream); + +int main() { + size_t elemCount_v1 = 1024; + size_t fileSize_v1 = elemCount_v1 * sizeof(float); + size_t elemCount_v2 = 1024; + size_t fileSize_v2 = elemCount_v2 * sizeof(float); + float *v1Host = nullptr; + float *v1Device = nullptr; + float *v2Host = nullptr; + float *v2Device = nullptr; + + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) { + deviceId = std::atoi(envDevice); + } + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + + ACL_CHECK(aclrtMallocHost((void **)(&v1Host), fileSize_v1)); + ACL_CHECK(aclrtMallocHost((void **)(&v2Host), fileSize_v2)); + ACL_CHECK(aclrtMalloc((void **)&v1Device, fileSize_v1, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&v2Device, fileSize_v2, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", fileSize_v1, v1Host, fileSize_v1); + ReadFile("./v2.bin", fileSize_v2, v2Host, fileSize_v2); + ACL_CHECK(aclrtMemcpy(v1Device, fileSize_v1, v1Host, fileSize_v1, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(v2Device, fileSize_v2, v2Host, fileSize_v2, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVstur_kernel_2d(v1Device, v2Device, stream); + + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(v2Host, fileSize_v2, v2Device, fileSize_v2, ACL_MEMCPY_DEVICE_TO_HOST)); + + WriteFile("./v2.bin", v2Host, fileSize_v2); + +cleanup: + aclrtFree(v1Device); + aclrtFree(v2Device); + aclrtFreeHost(v1Host); + aclrtFreeHost(v2Host); + if (stream != nullptr) { + const aclError _ret = aclrtDestroyStream(stream); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtDestroyStream(stream)", (int)_ret, __FILE__, __LINE__); + } + stream = nullptr; + } + if (deviceSet) { + const aclError _ret = aclrtResetDevice(deviceId); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclrtResetDevice(deviceId)", (int)_ret, __FILE__, __LINE__); + } + } + if (aclInited) { + const aclError _ret = aclFinalize(); + if (_ret != ACL_SUCCESS) { + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", + "aclFinalize()", (int)_ret, __FILE__, __LINE__); + } + } + + return rc; +} diff --git a/test/vpto/npu_validation/common/test_common.h b/test/vpto/npu_validation/common/test_common.h new file mode 100644 index 000000000..3cbb7a3e3 --- /dev/null +++ b/test/vpto/npu_validation/common/test_common.h @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace PtoTestCommon { + +inline bool ReadFile(const std::string &filePath, size_t &fileSize, void *buffer, size_t bufferSize) { + struct stat sBuf; + if (stat(filePath.c_str(), &sBuf) == -1) { + return false; + } + if (!S_ISREG(sBuf.st_mode)) { + return false; + } + + std::ifstream file(filePath, std::ios::binary); + if (!file.is_open()) { + return false; + } + + std::filebuf *buf = file.rdbuf(); + size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (size == 0 || size > bufferSize) { + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), size); + fileSize = size; + return true; +} + +inline bool WriteFile(const std::string &filePath, const void *buffer, size_t size) { + if (buffer == nullptr) { + return false; + } + + int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + return false; + } + + ssize_t writeSize = write(fd, buffer, size); + (void)close(fd); + return writeSize == static_cast(size); +} + +} // namespace PtoTestCommon diff --git a/test/vpto/scripts/run_host_vpto_validation.sh b/test/vpto/scripts/run_host_vpto_validation.sh new file mode 100755 index 000000000..0edac3882 --- /dev/null +++ b/test/vpto/scripts/run_host_vpto_validation.sh @@ -0,0 +1,335 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +VPTO_ROOT="${VPTO_ROOT:-${ROOT_DIR}/test/vpto/cases}" +CASES_ROOT="${CASES_ROOT:-${VPTO_ROOT}}" +NPU_VALIDATION_COMMON_DIR="${NPU_VALIDATION_COMMON_DIR:-${ROOT_DIR}/test/vpto/npu_validation/common}" + +WORK_SPACE="${WORK_SPACE:-}" +ASCEND_HOME_PATH="${ASCEND_HOME_PATH:-}" +PTOAS_BIN="${PTOAS_BIN:-${ROOT_DIR}/build/tools/ptoas/ptoas}" +PTOAS_FLAGS="${PTOAS_FLAGS:---pto-arch a5 --pto-backend=vpto}" +# set he HOST_RUNNER to "ssh root@localhost" if must change user to root to access the device +HOST_RUNNER="${HOST_RUNNER:-}" +CASE_NAME="${CASE_NAME:-}" +DEVICE="${DEVICE:-SIM}" +SIM_LIB_DIR="${SIM_LIB_DIR:-}" +COMPILE_ONLY="${COMPILE_ONLY:-0}" + +log() { + echo "[$(date +'%F %T')] $*" +} + +die() { + echo "ERROR: $*" >&2 + exit 1 +} + +run_remote() { + local cmd="$1" + if [[ "${HOST_RUNNER}" == "ssh root@localhost" ]]; then + ssh -o StrictHostKeyChecking=no root@localhost "${cmd}" + else + bash -lc "${cmd}" + fi +} + +require_env() { + local name="$1" + local value="$2" + if [[ -z "${value}" ]]; then + die "${name} is required" + fi +} + +require_env "WORK_SPACE" "${WORK_SPACE}" +require_env "ASCEND_HOME_PATH" "${ASCEND_HOME_PATH}" +[[ -x "${PTOAS_BIN}" ]] || die "PTOAS_BIN is not executable: ${PTOAS_BIN}" +[[ -d "${CASES_ROOT}" ]] || die "missing cases root: ${CASES_ROOT}" + +if [[ -f "${ASCEND_HOME_PATH}/set_env.sh" ]]; then + set +u + source "${ASCEND_HOME_PATH}/set_env.sh" >/dev/null 2>&1 + set -u +fi + +resolve_sim_lib_dir() { + if [[ "${DEVICE}" != "SIM" ]]; then + return 0 + fi + + if [[ -n "${SIM_LIB_DIR}" ]]; then + [[ -d "${SIM_LIB_DIR}" ]] || + die "SIM_LIB_DIR is set but invalid: ${SIM_LIB_DIR}" + return 0 + fi + + local -a candidates=() + readarray -t candidates < <( + find "${ASCEND_HOME_PATH}" -type d -path '*/simulator/dav_3510/lib' | sort + ) + + if [[ "${#candidates[@]}" -eq 1 ]]; then + SIM_LIB_DIR="${candidates[0]}" + log "SIM_LIB_DIR is unset; auto-selected: ${SIM_LIB_DIR}" + return 0 + fi + + if [[ "${#candidates[@]}" -gt 1 ]]; then + SIM_LIB_DIR="${candidates[0]}" + log "SIM_LIB_DIR is unset; multiple dav_3510 simulator dirs found, using: ${SIM_LIB_DIR}" + return 0 + fi + + die "SIM_LIB_DIR is required for DEVICE=SIM and no dav_3510 simulator lib dir was found under: ${ASCEND_HOME_PATH}" +} + +resolve_sim_lib_dir + +BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}" + +command -v "${BISHENG_BIN}" >/dev/null 2>&1 || die "bisheng not found: ${BISHENG_BIN}" +command -v python3 >/dev/null 2>&1 || die "python3 not found" + +mkdir -p "${WORK_SPACE}" +WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" + +discover_cases() { + local required_files=( + launch.cpp + main.cpp + golden.py + compare.py + ) + + if [[ -n "${CASE_NAME}" ]]; then + [[ "${CASE_NAME}" != /* ]] || die "CASE_NAME must be relative to CASES_ROOT: ${CASE_NAME}" + local requested_dir="${CASES_ROOT}/${CASE_NAME}" + [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" + for f in "${required_files[@]}"; do + [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" + done + [[ -f "${requested_dir}/kernel.pto" ]] || + die "case ${CASE_NAME} must provide kernel.pto" + printf "%s\n" "${CASE_NAME}" + return 0 + fi + + find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do + local ok=1 + for f in "${required_files[@]}"; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + [[ -f "${dir}/kernel.pto" ]] || continue + local rel="${dir#${CASES_ROOT}/}" + printf "%s\n" "${rel}" + done +} + +readarray -t CASES < <(discover_cases) +[[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" + +case_output_token() { + printf '%s' "$1" | sed 's#[/[:space:]]#_#g' +} + +build_launch_object() { + local case_dir="$1" + local out_obj="$2" + + "${BISHENG_BIN}" \ + -c -fPIC -xcce -fenable-matrix --cce-aicore-enable-tl \ + -fPIC -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --cce-aicore-arch=dav-c310 \ + -DREGISTER_BASE \ + -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes \ + -I "${ASCEND_HOME_PATH}/include" \ + -I "${ASCEND_HOME_PATH}/pkg_inc" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/profiling" \ + -I "${ASCEND_HOME_PATH}/pkg_inc/runtime/runtime" \ + "${case_dir}/launch.cpp" \ + -o "${out_obj}" +} + +link_kernel_so() { + local case_name="$1" + local kernel_fatobj="$2" + local launch_obj="$3" + local kernel_so="$4" + local extra_lib_dirs=() + local extra_link_libs=() + + if [[ "${DEVICE}" == "SIM" ]]; then + [[ -n "${SIM_LIB_DIR}" && -d "${SIM_LIB_DIR}" ]] || + die "SIM_LIB_DIR is not set or invalid for DEVICE=SIM: ${SIM_LIB_DIR}" + extra_lib_dirs+=(-L "${SIM_LIB_DIR}" -Wl,-rpath,"${SIM_LIB_DIR}") + extra_link_libs+=(-Wl,--no-as-needed -lruntime_camodel) + else + extra_link_libs+=(-Wl,--no-as-needed -lruntime) + fi + + "${BISHENG_BIN}" \ + -fPIC -s -Wl,-z,relro -Wl,-z,now --cce-fatobj-link \ + -shared -Wl,-soname,"lib${case_name}_kernel.so" \ + -L "${ASCEND_HOME_PATH}/lib64" \ + "${extra_lib_dirs[@]}" \ + -Wl,-rpath,"${ASCEND_HOME_PATH}/lib64" \ + -o "${kernel_so}" \ + "${kernel_fatobj}" \ + "${launch_obj}" \ + "${extra_link_libs[@]}" +} + +build_host_executable() { + local case_token="$1" + local case_dir="$2" + local out_dir="$3" + local extra_ldflags=() + local extra_lib_dirs=() + if [[ "${DEVICE}" == "SIM" ]]; then + [[ -n "${SIM_LIB_DIR}" && -d "${SIM_LIB_DIR}" ]] || + die "SIM_LIB_DIR is not set or invalid for DEVICE=SIM: ${SIM_LIB_DIR}" + extra_lib_dirs+=(-L "${SIM_LIB_DIR}" -Wl,-rpath,"${SIM_LIB_DIR}") + extra_ldflags+=(-Wl,--allow-shlib-undefined -lruntime_camodel) + else + extra_ldflags+=(-Wl,--allow-shlib-undefined -lruntime) + fi + + "${BISHENG_BIN}" \ + -xc++ -include stdint.h -include stddef.h -std=c++17 \ + "${case_dir}/main.cpp" \ + -I "${case_dir}" \ + -I "${NPU_VALIDATION_COMMON_DIR}" \ + -I "${ASCEND_HOME_PATH}/include" \ + -L "${out_dir}" \ + -L "${ASCEND_HOME_PATH}/lib64" \ + "${extra_lib_dirs[@]}" \ + -Wl,-rpath,"${out_dir}" \ + -Wl,-rpath,"${ASCEND_HOME_PATH}/lib64" \ + -o "${out_dir}/${case_token}" \ + -l"${case_token}_kernel" \ + "${extra_ldflags[@]}" \ + -lstdc++ -lascendcl -lm -ltiling_api -lplatform -lc_sec -ldl -lnnopbase +} + +build_one_impl() { + local case_name="$1" + local case_dir="${CASES_ROOT}/${case_name}" + local case_token + case_token="$(case_output_token "${case_name}")" + local out_dir="${WORK_SPACE}/${case_token}" + local launch_obj="${out_dir}/launch.o" + local kernel_fatobj="${out_dir}/kernel.fatobj.o" + local kernel_so="${out_dir}/lib${case_token}_kernel.so" + + [[ -f "${case_dir}/main.cpp" ]] || die "missing main.cpp for ${case_name}" + [[ -f "${case_dir}/launch.cpp" ]] || die "missing launch.cpp for ${case_name}" + [[ -f "${case_dir}/golden.py" ]] || die "missing golden.py for ${case_name}" + [[ -f "${case_dir}/compare.py" ]] || die "missing compare.py for ${case_name}" + [[ -f "${case_dir}/kernel.pto" ]] || + die "missing kernel.pto for ${case_name}" + + log "[$case_name] step 1/4: emit kernel fatobj" + "${PTOAS_BIN}" ${PTOAS_FLAGS} \ + "${case_dir}/kernel.pto" -o "${kernel_fatobj}" + + log "[$case_name] step 2/4: build launch object" + build_launch_object "${case_dir}" "${launch_obj}" + + log "[$case_name] step 3/4: link kernel shared library" + link_kernel_so "${case_token}" "${kernel_fatobj}" "${launch_obj}" "${kernel_so}" + + if [[ "${COMPILE_ONLY}" == "1" ]]; then + log "[$case_name] compile-only mode: stop after kernel shared library" + log "[$case_name] output dir: ${out_dir}" + return 0 + fi + + log "[$case_name] step 4/4: build host executable and golden" + build_host_executable "${case_token}" "${case_dir}" "${out_dir}" + ( + cd "${out_dir}" + python3 "${case_dir}/golden.py" + ) + + log "[$case_name] run NPU validation" + local remote_run_cmd + remote_run_cmd=$(cat </dev/null 2>&1; fi && \ +LD_LIBRARY_PATH="${out_dir}:${SIM_LIB_DIR}:\$ASCEND_HOME_PATH/lib64:\${LD_LIBRARY_PATH:-}" "./${case_token}" +EOF +) + run_remote "${remote_run_cmd}" + + local remote_ldd_cmd + remote_ldd_cmd=$(cat </dev/null 2>&1; fi && \ +LD_LIBRARY_PATH="${out_dir}:${SIM_LIB_DIR}:\$ASCEND_HOME_PATH/lib64:\${LD_LIBRARY_PATH:-}" ldd "./${case_token}" | grep "lib${case_token}_kernel.so" +EOF +) + local ldd_output + ldd_output="$(run_remote "${remote_ldd_cmd}")" + [[ "${ldd_output}" == *"${kernel_so}"* || "${ldd_output}" == *"lib${case_token}_kernel.so"* ]] || \ + die "${case_name} did not load expected kernel so: ${ldd_output}" + + ( + cd "${out_dir}" + COMPARE_STRICT=1 python3 "${case_dir}/compare.py" + ) + + log "[$case_name] compare passed" + log "[$case_name] output dir: ${out_dir}" +} + +build_one() { + local case_name="$1" + local case_token + case_token="$(case_output_token "${case_name}")" + local out_dir="${WORK_SPACE}/${case_token}" + local case_log="${out_dir}/validation.log" + + rm -rf "${out_dir}" + mkdir -p "${out_dir}" + + ( + build_one_impl "${case_name}" + ) 2>&1 | tee "${case_log}" +} + +log "=== VPTO Host Validation ===" +log "WORK_SPACE=${WORK_SPACE}" +log "ASCEND_HOME_PATH=${ASCEND_HOME_PATH}" +log "PTOAS_BIN=${PTOAS_BIN}" +log "PTOAS_FLAGS=${PTOAS_FLAGS}" +log "COMPILE_ONLY=${COMPILE_ONLY}" +log "CASE_NAME=${CASE_NAME:-}" + +for case_name in "${CASES[@]}"; do + build_one "${case_name}" +done + +log "All ${#CASES[@]} VPTO case(s) passed" diff --git a/test/vpto/scripts/run_host_vpto_validation_parallel.sh b/test/vpto/scripts/run_host_vpto_validation_parallel.sh new file mode 100755 index 000000000..f706a5f43 --- /dev/null +++ b/test/vpto/scripts/run_host_vpto_validation_parallel.sh @@ -0,0 +1,190 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +VPTO_ROOT="${VPTO_ROOT:-${ROOT_DIR}/test/vpto/cases}" +CASES_ROOT="${CASES_ROOT:-${VPTO_ROOT}}" +SERIAL_SCRIPT="${SCRIPT_DIR}/run_host_vpto_validation.sh" + +WORK_SPACE="${WORK_SPACE:-}" +CASE_NAME="${CASE_NAME:-}" +CASE_PREFIX="${CASE_PREFIX:-}" +JOBS="${JOBS:-}" + +log() { + echo "[$(date +'%F %T')] $*" +} + +die() { + echo "ERROR: $*" >&2 + exit 1 +} + +clean_tmp_inode_hotspots() { + local -a targets=( + /tmp/pto-microop-full + /tmp/pto-microop-full-redownload + ) + + log "tmp inode usage before cleanup" + df -ih /tmp + + for dir in "${targets[@]}"; do + if [[ -e "${dir}" ]]; then + log "remove ${dir}" + rm -rf "${dir}" + fi + done + + log "tmp inode usage after cleanup" + df -ih /tmp +} + +clean_tmp_inode_hotspots + +[[ -x "${SERIAL_SCRIPT}" ]] || die "missing serial validation script: ${SERIAL_SCRIPT}" +[[ -d "${CASES_ROOT}" ]] || die "missing cases root: ${CASES_ROOT}" +[[ -n "${WORK_SPACE}" ]] || die "WORK_SPACE is required" + +if [[ -z "${JOBS}" ]]; then + if command -v nproc >/dev/null 2>&1; then + JOBS="$(nproc)" + else + JOBS=1 + fi + if [[ "${JOBS}" -gt 1 ]]; then + JOBS="$((JOBS / 2))" + fi +fi + +[[ "${JOBS}" =~ ^[0-9]+$ ]] || die "JOBS must be a positive integer, got: ${JOBS}" +[[ "${JOBS}" -ge 1 ]] || die "JOBS must be >= 1" + +mkdir -p "${WORK_SPACE}" +WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" +SUMMARY_FILE="${WORK_SPACE}/parallel-summary.tsv" +RUNNER_LOG="${WORK_SPACE}/parallel-runner.log" + +discover_cases() { + local required_files=( + launch.cpp + main.cpp + golden.py + compare.py + ) + + if [[ -n "${CASE_NAME}" ]]; then + local requested_dir="${CASES_ROOT}/${CASE_NAME}" + [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" + for f in "${required_files[@]}"; do + [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" + done + [[ -f "${requested_dir}/kernel.pto" ]] || + die "case ${CASE_NAME} must provide kernel.pto" + printf "%s\n" "${CASE_NAME}" + return 0 + fi + + find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do + local ok=1 + for f in "${required_files[@]}"; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + [[ -f "${dir}/kernel.pto" ]] || continue + local rel="${dir#${CASES_ROOT}/}" + if [[ -n "${CASE_PREFIX}" && "${rel}" != "${CASE_PREFIX}"* ]]; then + continue + fi + printf "%s\n" "${rel}" + done +} + +readarray -t CASES < <(discover_cases) +[[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" + +: > "${SUMMARY_FILE}" +: > "${RUNNER_LOG}" + +declare -A PID_TO_CASE=() + +launch_case() { + local case_name="$1" + + log "[${case_name}] launch" | tee -a "${RUNNER_LOG}" + ( + CASE_NAME="${case_name}" "${SERIAL_SCRIPT}" + ) & + + local pid=$! + PID_TO_CASE["${pid}"]="${case_name}" +} + +reap_one() { + local pid="$1" + local case_name="${PID_TO_CASE[${pid}]}" + local result="FAIL" + local detail="1" + + if wait "${pid}"; then + result="PASS" + detail="0" + fi + + printf '%s\t%s\t%s\n' "${case_name}" "${result}" "${detail}" >> "${SUMMARY_FILE}" + log "[${case_name}] ${result} (${detail})" | tee -a "${RUNNER_LOG}" + unset 'PID_TO_CASE['"${pid}"']' +} + +log "=== VPTO Host Validation Parallel ===" | tee -a "${RUNNER_LOG}" +log "WORK_SPACE=${WORK_SPACE}" | tee -a "${RUNNER_LOG}" +log "CASE_NAME=${CASE_NAME:-}" | tee -a "${RUNNER_LOG}" +log "CASE_PREFIX=${CASE_PREFIX:-}" | tee -a "${RUNNER_LOG}" +log "JOBS=${JOBS}" | tee -a "${RUNNER_LOG}" +log "TOTAL_CASES=${#CASES[@]}" | tee -a "${RUNNER_LOG}" + +next_index=0 +while [[ "${next_index}" -lt "${#CASES[@]}" || "${#PID_TO_CASE[@]}" -gt 0 ]]; do + while [[ "${next_index}" -lt "${#CASES[@]}" && "${#PID_TO_CASE[@]}" -lt "${JOBS}" ]]; do + launch_case "${CASES[${next_index}]}" + next_index="$((next_index + 1))" + done + + if [[ "${#PID_TO_CASE[@]}" -eq 0 ]]; then + continue + fi + + while true; do + for pid in "${!PID_TO_CASE[@]}"; do + if ! kill -0 "${pid}" 2>/dev/null; then + reap_one "${pid}" + break 2 + fi + done + sleep 1 + done +done + +pass_count="$(awk -F '\t' '$2 == "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" +fail_count="$(awk -F '\t' '$2 != "PASS" {count++} END {print count + 0}' "${SUMMARY_FILE}")" + +log "PASS=${pass_count} FAIL=${fail_count}" | tee -a "${RUNNER_LOG}" +log "summary: ${SUMMARY_FILE}" | tee -a "${RUNNER_LOG}" + +if [[ "${fail_count}" -ne 0 ]]; then + die "parallel validation finished with ${fail_count} failing case(s)" +fi + +log "All ${pass_count} case(s) passed" | tee -a "${RUNNER_LOG}" diff --git a/test/vpto_tilelang_inline_soft_divmod_fastpath.pto b/test/vpto_tilelang_inline_soft_divmod_fastpath.pto new file mode 100644 index 000000000..a6ea11748 --- /dev/null +++ b/test/vpto_tilelang_inline_soft_divmod_fastpath.pto @@ -0,0 +1,158 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto %s --emit-vpto -o - | FileCheck %s + +// CHECK-LABEL: func.func @kernel( +// CHECK: pto.vecscope { +// CHECK: pto.vlds +// CHECK: pto.vcmps +// CHECK: pto.vxor +// CHECK: pto.vdiv +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vsts +// CHECK-NOT: func.call @__tl_inline__tl_soft +// CHECK-NOT: func.func private @__tl_inline__tl_soft +// CHECK-NOT: __tl_inline__tl_soft_vdiv_ +// CHECK-NOT: __tl_inline__tl_soft_vmod_ + +// tilelang.target = a5 +// tilelang.op = dump_i16_divmod_lit_tmp +// tilelang.dtypes = (i16, i16) +// tilelang.verify = True +// tilelang.advanced = False +// tilelang.specialize dst shape=(8, 16) memory_space=ub config=None +// tilelang.specialize src shape=(8, 16) memory_space=ub config=None +module attributes {pto.target_arch = "a5"} { + func.func @kernel(%arg0: !pto.tile_buf, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance } { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %tmp_0 = pto.tile_buf_addr %arg0 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + %tmp_1 = pto.tile_buf_addr %arg1 : !pto.tile_buf -> memref<8x16xi16, #pto.address_space> + pto.vecscope { + %mask_0 = pto.pset_b16 "PAT_ALL" : !pto.mask + %vec_1 = pto.vlds %tmp_1[%c0] : memref<8x16xi16, #pto.address_space> -> !pto.vreg<128xi16> + %q_59 = func.call @__tl_inline__tl_soft_vdiv_2(%vec_1, %vec_1, %mask_0) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + %r_81 = func.call @__tl_inline__tl_soft_vmod_4(%vec_1, %vec_1, %mask_0) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + pto.vsts %q_59, %tmp_0[%c0], %mask_0 : !pto.vreg<128xi16>, memref<8x16xi16, #pto.address_space>, !pto.mask + pto.vsts %r_81, %tmp_0[%c1], %mask_0 : !pto.vreg<128xi16>, memref<8x16xi16, #pto.address_space>, !pto.mask + } + return + } + func.func private @__tl_inline__tl_soft_vdiv_u16_0(%arg0: !pto.vreg<128xui16>, %arg1: !pto.vreg<128xui16>, %arg2: !pto.mask) -> !pto.vreg<128xui16> attributes { pto.tilelang.inline_proc } { + %c0_i32 = arith.constant 0 : i32 + %c0_ui32 = builtin.unrealized_conversion_cast %c0_i32 : i32 to ui32 + %c65536_0_f32 = arith.constant 65536.0 : f32 + %c65535_i16 = arith.constant 65535 : i16 + %c65535_ui16 = builtin.unrealized_conversion_cast %c65535_i16 : i16 to ui16 + %tmp_0 = arith.constant 0 : i16 + %zero_10 = builtin.unrealized_conversion_cast %tmp_0 : i16 to ui16 + %tmp_1 = arith.constant 1 : i16 + %one_11 = builtin.unrealized_conversion_cast %tmp_1 : i16 to ui16 + %fp32_one_12 = arith.constant 1.0 : f32 + %full_mask_b16_13 = pto.pset_b16 "PAT_ALL" : !pto.mask + %full_mask_b32_14 = pto.pset_b32 "PAT_ALL" : !pto.mask + %zero_mask_15 = pto.vcmps %arg1, %zero_10, %arg2, "eq" : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.mask + %active_mask_16 = pto.pnot %zero_mask_15, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %zero_u16_17 = pto.vbr %zero_10 : ui16 -> !pto.vreg<128xui16> + %vy_lower_u16_18, %vy_higher_u16_19 = pto.vintlv %arg1, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vy_lower_u32_20 = pto.vcvt %vy_lower_u16_18, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vy_higher_u32_21 = pto.vcvt %vy_higher_u16_19, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %active_low_22 = pto.vcmps %vy_lower_u32_20, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %active_high_23 = pto.vcmps %vy_higher_u32_21, %c0_ui32, %full_mask_b32_14, "ne" : !pto.vreg<64xui32>, ui32, !pto.mask -> !pto.mask + %tmp_2 = pto.vbitcast %vy_lower_u32_20 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_lower_f32_24 = pto.vcvt %tmp_2, %active_low_22 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_3 = pto.vbitcast %vy_higher_u32_21 : !pto.vreg<64xui32> -> !pto.vreg<64xi32> + %vy_higher_f32_25 = pto.vcvt %tmp_3, %active_high_23 {rnd = "F"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_4 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_lower_26 = pto.vdiv %tmp_4, %vy_lower_f32_24, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_5 = pto.vbr %fp32_one_12 : f32 -> !pto.vreg<64xf32> + %vy_rec_higher_27 = pto.vdiv %tmp_5, %vy_higher_f32_25, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_6 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_lower_28 = pto.vmul %vy_rec_lower_26, %tmp_6, %active_low_22 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_7 = pto.vbr %c65536_0_f32 : f32 -> !pto.vreg<64xf32> + %vy_scale_higher_29 = pto.vmul %vy_rec_higher_27, %tmp_7, %active_high_23 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %v_lower_i32_30 = pto.vcvt %vy_scale_lower_28, %active_low_22 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> + %v_higher_i32_31 = pto.vcvt %vy_scale_higher_29, %active_high_23 {rnd = "F", sat = "NOSAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> + %v_lower_u32_32 = pto.vbitcast %v_lower_i32_30 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %v_higher_u32_33 = pto.vbitcast %v_higher_i32_31 : !pto.vreg<64xi32> -> !pto.vreg<64xui32> + %vx_lower_u16_34, %vx_higher_u16_35 = pto.vintlv %arg0, %zero_u16_17 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %vx_lower_u32_36 = pto.vcvt %vx_lower_u16_34, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %vx_higher_u32_37 = pto.vcvt %vx_higher_u16_35, %active_mask_16 {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> + %q_tmp_lower_38 = pto.vmul %v_lower_u32_32, %vx_lower_u32_36, %active_low_22 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %q_tmp_higher_39 = pto.vmul %v_higher_u32_33, %vx_higher_u32_37, %active_high_23 : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> + %tmp_8 = pto.vbitcast %q_tmp_lower_38 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %tmp_9 = pto.vbitcast %q_tmp_higher_39 : !pto.vreg<64xui32> -> !pto.vreg<128xui16> + %_q_lower_40, %q_tmp_41 = pto.vdintlv %tmp_8, %tmp_9 : !pto.vreg<128xui16>, !pto.vreg<128xui16> -> !pto.vreg<128xui16>, !pto.vreg<128xui16> + %yq_tmp_42 = pto.vmul %q_tmp_41, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_43 = pto.vsub %arg0, %yq_tmp_42, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_44 = pto.vcmp %r_tmp_43, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_45 = pto.vsub %r_tmp_43, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_46 = pto.vsel %refined_r_45, %r_tmp_43, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_47 = pto.vadds %q_tmp_41, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_48 = pto.vsel %q_inc_47, %q_tmp_41, %ge_mask_44 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %ge_mask_49 = pto.vcmp %r_tmp_46, %arg1, %active_mask_16, "ge" : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask + %refined_r_50 = pto.vsub %r_tmp_46, %arg1, %active_mask_16 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %r_tmp_51 = pto.vsel %refined_r_50, %r_tmp_46, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %q_inc_52 = pto.vadds %q_tmp_48, %one_11, %active_mask_16 : !pto.vreg<128xui16>, ui16, !pto.mask -> !pto.vreg<128xui16> + %q_tmp_53 = pto.vsel %q_inc_52, %q_tmp_48, %ge_mask_49 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + %zero_q_54 = pto.vbr %c65535_ui16 : ui16 -> !pto.vreg<128xui16> + %tmp_10 = pto.vsel %zero_q_54, %q_tmp_53, %zero_mask_15 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> + return %tmp_10 : !pto.vreg<128xui16> + } + func.func private @__tl_inline__tl_soft_vdiv_i16_1(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %zero_2 = arith.constant 0 : i16 + %neg_one_3 = arith.constant -1 : i16 + %zero_mask_4 = pto.vcmps %arg1, %zero_2, %arg2, "eq" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %active_mask_5 = pto.pnot %zero_mask_4, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %tmp_0 = pto.vabs %arg0, %active_mask_5 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_x_6 = pto.vbitcast %tmp_0 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %tmp_1 = pto.vabs %arg1, %active_mask_5 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_y_7 = pto.vbitcast %tmp_1 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %x_xor_y_8 = pto.vxor %arg0, %arg1, %active_mask_5 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %p_pos_9 = pto.vcmps %x_xor_y_8, %zero_2, %active_mask_5, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %q_abs_55 = func.call @__tl_inline__tl_soft_vdiv_u16_0(%abs_x_6, %abs_y_7, %active_mask_5) : (!pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask) -> !pto.vreg<128xui16> + %tmp_2 = pto.vbitcast %q_abs_55 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %neg_q_56 = pto.vneg %tmp_2, %active_mask_5 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_3 = pto.vbitcast %q_abs_55 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %q_57 = pto.vsel %tmp_3, %neg_q_56, %p_pos_9 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_5 = pto.vbr %neg_one_3 : i16 -> !pto.vreg<128xi16> + %tmp_4 = pto.vsel %tmp_5, %q_57, %zero_mask_4 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + return %tmp_4 : !pto.vreg<128xi16> + } + func.func private @__tl_inline__tl_soft_vdiv_2(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %result_58 = func.call @__tl_inline__tl_soft_vdiv_i16_1(%arg0, %arg1, %arg2) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + return %result_58 : !pto.vreg<128xi16> + } + func.func private @__tl_inline__tl_soft_vmod_i16_3(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %zero_60 = arith.constant 0 : i16 + %neg_one_61 = arith.constant -1 : i16 + %zero_mask_62 = pto.vcmps %arg1, %zero_60, %arg2, "eq" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %active_mask_63 = pto.pnot %zero_mask_62, %arg2 : !pto.mask, !pto.mask -> !pto.mask + %tmp_0 = pto.vabs %arg0, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_x_64 = pto.vbitcast %tmp_0 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %tmp_1 = pto.vabs %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %abs_y_65 = pto.vbitcast %tmp_1 : !pto.vreg<128xi16> -> !pto.vreg<128xui16> + %x_xor_y_66 = pto.vxor %arg0, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %p_pos_67 = pto.vcmps %x_xor_y_66, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %q_abs_68 = func.call @__tl_inline__tl_soft_vdiv_u16_0(%abs_x_64, %abs_y_65, %active_mask_63) : (!pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask) -> !pto.vreg<128xui16> + %tmp_2 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %neg_q_69 = pto.vneg %tmp_2, %active_mask_63 : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_3 = pto.vbitcast %q_abs_68 : !pto.vreg<128xui16> -> !pto.vreg<128xi16> + %q_70 = pto.vsel %tmp_3, %neg_q_69, %p_pos_67 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %qy_71 = pto.vmul %q_70, %arg1, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_72 = pto.vsub %arg0, %qy_71, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %nonzero_remainder_73 = pto.vcmps %remainder_72, %zero_60, %active_mask_63, "ne" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_x_74 = pto.vcmps %arg0, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_y_75 = pto.vcmps %arg1, %zero_60, %active_mask_63, "ge" : !pto.vreg<128xi16>, i16, !pto.mask -> !pto.mask + %sign_diff_76 = pto.pxor %sign_x_74, %sign_y_75, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %need_floor_fix_77 = pto.pand %sign_diff_76, %nonzero_remainder_73, %active_mask_63 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %amended_remainder_78 = pto.vadd %arg1, %remainder_72, %active_mask_63 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %remainder_79 = pto.vsel %amended_remainder_78, %remainder_72, %need_floor_fix_77 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + %tmp_5 = pto.vbr %neg_one_61 : i16 -> !pto.vreg<128xi16> + %tmp_4 = pto.vsel %tmp_5, %remainder_79, %zero_mask_62 : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> + return %tmp_4 : !pto.vreg<128xi16> + } + func.func private @__tl_inline__tl_soft_vmod_4(%arg0: !pto.vreg<128xi16>, %arg1: !pto.vreg<128xi16>, %arg2: !pto.mask) -> !pto.vreg<128xi16> attributes { pto.tilelang.inline_proc } { + %result_80 = func.call @__tl_inline__tl_soft_vmod_i16_3(%arg0, %arg1, %arg2) : (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask) -> !pto.vreg<128xi16> + return %result_80 : !pto.vreg<128xi16> + } +} diff --git a/tilelang-dsl/CMakeLists.txt b/tilelang-dsl/CMakeLists.txt new file mode 100644 index 000000000..2bbca7c3d --- /dev/null +++ b/tilelang-dsl/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# ========================================================= +# TileLang DSL package wiring +# ========================================================= + +set(TILELANG_DSL_PACKAGE_SRC_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/python/tilelang_dsl") +set(TILELANG_DSL_BUILD_ROOT "${CMAKE_BINARY_DIR}/python") +set(TILELANG_DSL_BUILD_PACKAGE_DIR + "${TILELANG_DSL_BUILD_ROOT}/tilelang_dsl") + +add_custom_target(TileLangDSLPackage ALL + COMMAND ${CMAKE_COMMAND} -E make_directory "${TILELANG_DSL_BUILD_ROOT}" + COMMAND ${CMAKE_COMMAND} -E remove_directory "${TILELANG_DSL_BUILD_PACKAGE_DIR}" + COMMAND ${CMAKE_COMMAND} -E copy_directory + "${TILELANG_DSL_PACKAGE_SRC_DIR}" + "${TILELANG_DSL_BUILD_PACKAGE_DIR}" + COMMENT "Staging tilelang_dsl package into build/python" + VERBATIM +) + +install( + DIRECTORY "${TILELANG_DSL_PACKAGE_SRC_DIR}" + DESTINATION "." + COMPONENT PTOAS_Runtime + PATTERN "__pycache__" EXCLUDE + PATTERN "*.pyc" EXCLUDE +) + +install( + DIRECTORY "${CMAKE_SOURCE_DIR}/lib/TileOps" + DESTINATION "share/ptoas" + COMPONENT PTOAS_Runtime + PATTERN "__pycache__" EXCLUDE + PATTERN "*.pyc" EXCLUDE +) diff --git a/tilelang-dsl/README.md b/tilelang-dsl/README.md new file mode 100644 index 000000000..37d13f015 --- /dev/null +++ b/tilelang-dsl/README.md @@ -0,0 +1,124 @@ +TileLang DSL v1 lives under this directory. + +This subtree is the source of truth for the new frontend introduced by +`add-tilelang-dsl-core-foundation`. + +Boundary with the existing `python/pto/dialects/pto.py` module: +- `tilelang-dsl/` owns new TileLang DSL v1 core implementation work +- `python/pto/dialects/pto.py` keeps PTO dialect bindings and the legacy + experimental VPTO Python DSL surface +- Root-level wiring into build/install/test is allowed, but TileLang DSL core + logic must not move back into `python/pto/dialects/pto.py` + +Layout: +- `python/tilelang_dsl/`: package sources +- `tests/`: TileLang DSL focused tests +- `examples/`: self-contained examples +- `docs/`: local documentation for this frontend + +## How To Generate MLIR From A `.py` + +Run the examples from the repository root. + +If you are developing against the in-tree Python sources, point `PYTHONPATH` +at `tilelang-dsl/python`: + +```bash +cd /home/zhangzhendong/ptoas-workspace/PTOAS +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_emit_mlir_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_emit_mlir_demo.py /tmp/tilelang_demo.mlir +``` + +If you already built and installed the Python package into the repo build tree, +you can also point `PYTHONPATH` at `build/python`: + +```bash +cd /home/zhangzhendong/ptoas-workspace/PTOAS +PYTHONPATH=$PWD/build/python python3 tilelang-dsl/examples/v1_emit_mlir_demo.py +``` + +Behavior: +- without an output path, the script prints MLIR to stdout +- with an output path, the script writes MLIR to that file through `emit(path)` + +Useful examples: + +```bash +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py /tmp/tilelang_v1_elementwise.mlir +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_verify_smoke.py /tmp/tilelang_v1_verify.mlir +``` + +## Advanced Mode + +The default v1 surface still requires explicit `pto.strict_vecscope`. + +If you want the follow-up advanced surface for: +- implicit `pto.vecscope` inference +- `pto.vlds(tile[row, col:])` +- `pto.vsts(vec, tile[row, col:], mask)` + +set `advanced=True` on `@pto.vkernel` and follow +[`tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py`](/home/zhangzhendong/ptoas-workspace/PTOAS/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py). + +## Minimal Script Pattern + +Your own `.py` only needs to: +- import `tilelang_dsl` +- define a `@pto.vkernel` +- call `specialize(...)` +- call `mlir_text()` or `emit(path)` + +Minimal example: + +```python +from pathlib import Path +import tilelang_dsl as pto + + +@pto.vkernel( + op="eltwise_with_tile", + dtypes=[(pto.f32, pto.f16, pto.i32)], + name="my_kernel", +) +def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + +specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + ) +) + +print(specialized.mlir_text()) +specialized.emit(Path("/tmp/my_kernel.mlir")) +``` + +If `python3 your_script.py` reports `ModuleNotFoundError: tilelang_dsl`, it +means the package import path is missing. Re-run with one of: + +```bash +PYTHONPATH=$PWD/tilelang-dsl/python python3 your_script.py +PYTHONPATH=$PWD/build/python python3 your_script.py +``` + +## Optional Verifier Check + +To check that the generated MLIR passes the current repo VPTO authoring-stage +legality path: + +```bash +source scripts/ptoas_env.sh +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_verify_smoke.py /tmp/tilelang_v1_verify.mlir +build/tools/ptoas/ptoas --pto-arch a5 --pto-backend=vpto --emit-vpto \ + /tmp/tilelang_v1_verify.mlir -o /tmp/tilelang_v1_verify.checked.mlir +``` + +For the implemented authoring-form VPTO lowering contract, support matrix, +examples, and minimal validation commands, see +`tilelang-dsl/docs/v1-lowering.md`. + +Root-level wiring belongs to follow-up tasks and must stay minimal. diff --git a/tilelang-dsl/docs/README.md b/tilelang-dsl/docs/README.md new file mode 100644 index 000000000..1a8f35a90 --- /dev/null +++ b/tilelang-dsl/docs/README.md @@ -0,0 +1,48 @@ +# TileLang DSL 文档 + +TileLang Python DSL 为面向 Ascend NPU 硬件的向量计算和矩阵乘法(Cube)内核提供高级的 Pythonic 接口。本指南适用于需要编写高效、硬件感知内核的库开发人员和性能工程师。 + +## 文档结构 + +### 入门指南 +- [简介](user_guide/01-introduction.md) - 语言概述、层级、基本vs高级模式 +- [快速开始](user_guide/02-quick-start.md) - 快速入门示例 + +### 核心概念 +- [内核声明](user_guide/03-kernel-declaration.md) - 内核声明、装饰器参数、约束系统 +- [模板内核](user_guide/04-template-kernels.md) - 模板内核、多操作内核、编译时代换 + +### 类型系统 +- [类型系统](user_guide/05-type-system.md) - 标量、向量、指针、TensorView、Tile 类型 + +### 控制流 +- [控制流](user_guide/06-control-flow.md) - 向量作用域、循环、条件语句 + +### 操作参考 +- [前端操作](user_guide/07-frontend-operations.md) - 前端操作、类型查询、指针构造 +- [同步和DMA操作](user_guide/08-sync-dma-operations.md) - 同步和DMA操作 +- [向量内存操作](user_guide/09-vector-memory-operations.md) - 向量加载和存储操作 +- [谓词操作](user_guide/10-predicate-operations.md) - 谓词操作 +- [向量算术操作](user_guide/11-vector-arithmetic-operations.md) - 向量算术操作 +- [Cube 矩阵乘法操作](user_guide/12-cube-operations.md) - Cube 数据搬运与矩阵乘法操作 + +### 示例和错误处理 +- [示例](user_guide/13-examples.md) - 各种 Vector 和 Cube 内核示例 +- [常见错误](user_guide/14-common-errors.md) - 常见错误和解决方案 + +### 附录 +- [兼容性说明](user_guide/15-compatibility-notes.md) - 与实验实现的差异 +- [后续步骤](user_guide/16-next-steps.md) - 相关资源链接 + +## 相关文档 +- [v1-surface.md](v1-surface.md) - TileLang DSL v1 合约 +- [v1-lowering.md](v1-lowering.md) - TileLang DSL v1 降低合约 +- [matcher-and-advanced-surface-migration.md](matcher-and-advanced-surface-migration.md) - 迁移说明 +- [unsupported-features.md](unsupported-features.md) - 不支持的功能 + +--- + +**原始文档边界说明**: +- `tilelang-dsl/docs/` 是新的 `tilelang_dsl` 前端本地文档的真实来源 +- 仓库级文档可以链接到这里,但不应重新定义此包实现的 v1 边界 +- `python/pto/dialects/pto.py` 不是 TileLang DSL v1 的真实来源 diff --git a/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md new file mode 100644 index 000000000..56a249c97 --- /dev/null +++ b/tilelang-dsl/docs/matcher-and-advanced-surface-migration.md @@ -0,0 +1,381 @@ +# TileLang DSL Matcher And Advanced-Surface Migration + +## Scope + +This document explains how to move from the original v1 core contract +(`add-tilelang-dsl-core-foundation` + +`add-tilelang-dsl-authoring-vpto-lowering`) to the matcher and +advanced-surface capability implemented by +`extend-tilelang-dsl-matcher-and-advanced-surface`, and how to adopt the +template-slot authoring model added by +`extend-tilelang-dsl-template-op-slots`. + +It focuses on: +- matcher-driven kernel selection +- migration from explicit real `pto.*` calls to template-slot authoring +- implicit vecscope inference +- raw pointer / low-level DMA authoring +- advanced vector-family coverage that is implemented today +- the remaining deferred boundary + +## Current Tier Snapshot + +This migration note lives at the boundary between the basic starter path and +the broader expert surface. The public-surface groups discussed across the +guide, this migration note, and the support matrix currently map to tiers as +follows: + +| Surface Family | Tier | Migration Meaning | +|----------------|------|-------------------| +| `TensorView` | `basic` | Keep as the default GM-facing operand model. | +| `Tile` | `basic` | Keep as the default UB-facing compute tile model. | +| `dma_load` / `dma_store` | `basic` | Keep as the preferred high-level GM <-> UB path. | +| Base vector ops such as `make_mask`, `vlds`, `vsts`, `vadd`, `vmuls` | `basic` | Keep as the default compute skeleton before dropping to expert surfaces. | +| Raw pointer family such as `ptr(...)`, `castptr`, `addptr` | `advanced` | Use when moving from the starter path to expert pointer-form authoring. | +| Low-level DMA family such as `copy_*` and `set_loop*_stride_*` / `set_loop_size_*` | `advanced` | Use only when the high-level DMA surface is not sufficient. | +| Tile helper family such as `tile.slice(...)`, `tile.reshape(...)`, `tile.as_ptr()`, `tile_from_ptr(...)`, `tile_with_strides(...)`, `tile_config(...)` | `advanced` | Treat as partial or evolving surface rather than part of the basic starter path. | + +For the exact tier source of truth, see +`tilelang-dsl/python/tilelang_dsl/support_matrix.py`. + +## What Changed + +The original v1 core profile assumed: +- one monomorphic `dtypes` signature +- no matcher registry or selection API +- explicit `pto.strict_vecscope` for vector code +- no raw-pointer or low-level DMA authoring surface +- no advanced vector-family lowering beyond the fixed elementwise set + +The current package now adds: +- `KernelRegistry` +- `pto.select_kernel(...)` +- multi-signature `dtypes` +- multi-op descriptors via `op=` / `ops=[...]` +- `AnyFloat`, `AnyInt`, `AnyType`, `AnyMask` +- `TypeVar(...)` +- `constraints=[...]` +- `priority=` +- descriptor-bound `selected_op` for multi-op matches +- `templates={...}` +- `pto.tpl("slot", ...)` +- implicit vecscope inference in `advanced=True` kernels +- `ptr(...)` / `PointerType` +- `castptr`, `addptr` +- low-level DMA config/copy surface +- compare/select, predicate movement, carry, and rearrangement families + +## Matcher Migration + +### Before + +The original v1 contract only supported one concrete signature: + +```python +@pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)]) +def kernel(inp: pto.TensorView, out: pto.Tile): + return None +``` + +### After + +You can now register multiple polymorphic descriptors and let the matcher pick +the concrete specialization: + +```python +@pto.vkernel( + op="eltwise", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat), + (pto.AnyInt, pto.AnyInt), + ], + constraints=[lambda enabled=True: enabled], + priority=10, +) +def kernel(inp: pto.TensorView, out: pto.Tile): + return None + +selected = pto.select_kernel( + "a5", + "eltwise", + (pto.f32, pto.f32), + context_attrs={"enabled": True}, +) +``` + +Matcher rules in the implemented package: +- matching is deterministic +- selection order is `target -> op -> dtypes -> constraints -> priority` +- highest-priority ties raise an explicit error +- `TypeVar` only binds within one signature +- `op=` and `ops=[...]` are mutually exclusive +- `ops=[...]` only widens the descriptor's matcher set; callers still query + `pto.select_kernel(...)` with one concrete op +- when a multi-op descriptor matches, the returned descriptor is already bound + to one concrete `selected_op` + +Matcher diagnostics are also available through the opt-in report path: + +```python +report = pto.select_kernel( + "a5", + "eltwise", + (pto.f32, pto.f32), + context_attrs={"enabled": False}, + return_metadata=True, + include_mlir=False, +) +``` + +In report mode: + +- `report.final_status` summarizes the overall outcome +- `report.candidates` keeps one record per `target/op`-matched descriptor +- constraint failures expose `failed_constraint_index`, + `failed_constraint_name`, and `failed_constraint_location` +- `include_mlir=True` additionally collects `mlir_text` or `mlir_error` for + candidates that pass constraint evaluation + +For clearer diagnostics, prefer writing multiple small constraint entries over a +single compound Python predicate. Report mode can identify which constraint +callable failed, but it does not decompose `cond0 and cond1` inside one +callable. + +For explicit single-op kernels that already map 1:1 to one real PTO op, you +do not need to migrate anything. Keep `op="..."` and keep authoring explicit +real `pto.*` calls in the kernel body. + +For shared-family kernels, the matcher migration usually comes first: +- change one descriptor from `op="..."` to `ops=[...]` +- continue selecting with concrete query ops +- rely on `selected_op` only as internal compile-time context for later + template-slot expansion + +Materialization boundary for multi-op descriptors: +- a descriptor registered with `ops=[...]` cannot directly `mlir_text()`, + `mlir_module()`, `verify()`, or `emit(path)` before selection +- call `pto.select_kernel(...)` first so the returned descriptor carries one + concrete `selected_op` + +## Vecscope Migration + +### Before + +Vector code needed an explicit `pto.strict_vecscope` boundary: + +```python +with pto.strict_vecscope(tile, tile, 0, 256, 64) as (src, dst, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) +``` + +### After + +In `advanced=True` kernels, the frontend now infers `pto.vecscope` for +contiguous vector-active regions: + +```python +@pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) +def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src[0, 0:]) + pto.vsts(vec, dst[0, 0:], mask) +``` + +Inference boundaries in the implemented package: +- scalar statements cut inference +- `if` / `for` structure is respected +- sync and DMA statements cut inference +- explicit `pto.strict_vecscope` remains a hard boundary + +Use `pto.strict_vecscope` when you need a deterministic region ABI or do not +want inference to merge adjacent vector chains. + +## Template-Slot Migration + +Template slots are the migration path for kernels whose control-flow, +load/store pattern, masks, and surrounding vector scaffolding stay the same +while one or a few real `pto.*` ops differ by concrete matcher op. + +### When To Keep Explicit Real `pto.*` Calls + +Keep the original style when: +- the kernel only serves one concrete op +- different ops need structurally different loops, masks, DMA scheduling, or + control flow +- the body is clearer when the real op is written directly +- there is no duplication pressure worth introducing `ops=[...]` and + `templates={...}` + +Example: + +```python +@pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) +def add_kernel(lhs: pto.TensorView, rhs: pto.TensorView, out: pto.Tile): + with pto.strict_vecscope(out, lhs, 0, 256, 64) as (_, _, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs_v = pto.vlds(lhs, lane) + rhs_v = pto.vlds(rhs, lane) + out_v = pto.vadd(lhs_v, rhs_v, mask) + pto.vsts(out_v, out, lane, mask) +``` + +### When To Migrate To Template Slots + +Migrate when: +- several concrete ops share the same loop skeleton +- only the core vector op or a small number of real `pto.*` calls differ +- you want one descriptor and one kernel body to cover a whole op family +- you still want deterministic compile-time expansion, not runtime dispatch + +Recommended pattern: + +```python +@pto.vkernel( + ops=["tadd", "tsub", "tmul", "tdiv"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + }, +) +def arithmetic_kernel(lhs: pto.TensorView, rhs: pto.TensorView, out: pto.Tile): + with pto.strict_vecscope(out, lhs, 0, 256, 64) as (_, _, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs_v = pto.vlds(lhs, lane) + rhs_v = pto.vlds(rhs, lane) + out_v = pto.tpl("core", lhs_v, rhs_v, mask) + pto.vsts(out_v, out, lane, mask) + +selected = pto.select_kernel( + "a5", + "tmul", + (pto.f32, pto.f32, pto.f32), +) +``` + +In this model: +- `ops=[...]` defines which concrete ops the descriptor may match +- `pto.select_kernel(...)` still receives one concrete op such as `"tmul"` +- the selected descriptor carries `selected_op="tmul"` +- frontend expansion rewrites `pto.tpl("core", ...)` to the real call for + that selected concrete op, such as `pto.vmul(...)` + +The example in +`tilelang-dsl/examples/v1_template_slot_multiop_demo.py` shows this shared +kernel-body migration pattern end to end. + +### Migration Checklist + +When converting an existing family of explicit kernels to template slots: +1. Confirm the kernels only differ in a few real `pto.*` calls. +2. Keep one shared body and move the op differences into + `templates={...}` slot mappings. +3. Replace the differing real calls with `pto.tpl("slot", ...)`. +4. Switch the descriptor from `op="..."` to `ops=[...]`. +5. Ensure all materialization goes through `pto.select_kernel(...)` so the + descriptor is bound to one concrete `selected_op`. + +### Boundaries And Non-Goals + +Template-slot migration is intentionally narrow: +- `pto.tpl("slot", ...)` is a compile-time placeholder, not a runtime helper +- the first argument must be a string literal slot name +- template mappings live in descriptor metadata, not in kernel-body Python + dictionaries +- callable-based dispatch such as `table["core"](...)` or `resolver(...)` + remains outside the DSL contract +- unresolved multi-op descriptors must not materialize before + `pto.select_kernel(...)` binds one concrete `selected_op` + +Template slots are not the right abstraction when: +- the kernels differ in control-flow structure, not just in a few ops +- one op variant needs extra DMA, sync, or pointer logic that the others do + not share +- you need arbitrary Python-level dispatch or dynamic selection inside the + kernel body + +## Pointer And DMA Migration + +### New Pointer Surface + +The package now exposes: +- `pto.ptr(dtype, memory_space)` +- pointer-typed parameters such as `pto.ptr(pto.f32, pto.MemorySpace.UB)` +- `pto.castptr(...)` +- `pto.addptr(...)` + +Example: + +```python +@pto.vkernel(op="copy", dtypes=[(pto.f32, pto.i64)], advanced=True) +def kernel(dst: pto.ptr(pto.f32, pto.MemorySpace.UB), addr: pto.i64): + src = pto.castptr(addr, pto.ptr(pto.f32, pto.MemorySpace.UB)) + next_src = pto.addptr(src, 64) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, next_src, 0, mask) +``` + +### New Low-Level DMA Surface + +The package now lowers: +- `set_loop2_stride_outtoub` +- `set_loop1_stride_outtoub` +- `set_loop_size_outtoub` +- `set_loop2_stride_ubtoout` +- `set_loop1_stride_ubtoout` +- `set_loop_size_ubtoout` +- `copy_gm_to_ubuf` +- `copy_ubuf_to_gm` +- `copy_ubuf_to_ubuf` + +High-level `dma_load` / `dma_store` remain the preferred default. Use the +low-level surface only when you need manual DMA programming. + +## Advanced Vector Families + +The currently implemented advanced-family groups are: +- compare/select: + `vcmp`, `vcmps`, `vsel`, `vselr`, `vselrv2` +- predicate movement: + `pnot`, `psel`, `ppack`, `punpack` +- carry family: + `vaddc`, `vsubc`, `vaddcs`, `vsubcs` +- rearrangement: + `vintlv`, `vdintlv`, `vintlvv2`, `vdintlvv2` + +These lower directly to authoring-form VPTO and are covered by +`tilelang-dsl/tests/test_tilelang_dsl_v1.py`. + +## Still Deferred + +The following boundary remains intentionally deferred: +- reduction family authoring + +Reason: +- the current repo does not expose a public authoring-form VPTO reduction op + that TileLang DSL can target directly +- existing reduction logic lives in other lowering paths such as OpLib / EmitC + and cannot be treated as the public TileLang DSL authoring contract + +Current package behavior: +- reduction-family surface remains an explicit frontend reject +- no extra helper IR is introduced to fake reduction support + +## Recommended Reading Order + +For the current package contract, read in this order: +1. `tilelang-dsl/docs/v1-surface.md` +2. `tilelang-dsl/docs/v1-lowering.md` +3. `tilelang-dsl/docs/matcher-and-advanced-surface-migration.md` +4. `docs/tilelang-dsl-guide.md` diff --git a/tilelang-dsl/docs/unsupported-features.md b/tilelang-dsl/docs/unsupported-features.md new file mode 100644 index 000000000..d47cba7d8 --- /dev/null +++ b/tilelang-dsl/docs/unsupported-features.md @@ -0,0 +1,228 @@ +# TileLang DSL Unsupported And Partial Features + +## Scope + +This document records the gap between the broad language surface described in +`tilelang-dsl-guide.md` and what the current standalone `tilelang_dsl` package +actually implements under: + +- `tilelang-dsl/python/tilelang_dsl/` +- `tilelang-dsl/tests/` + +Use this file as a quick "what is still missing" index. For the implemented +contract, treat these as the source-of-truth companion documents: + +- `v1-surface.md` +- `v1-lowering.md` +- `matcher-and-advanced-surface-migration.md` + +## Status Labels + +- `Unsupported`: the public surface is documented but not exported or not + accepted by the frontend at all. +- `Partial`: the concept exists, but only a narrower subset works in the + current implementation. + +## Unsupported Features + +### Missing Public Type Constructors And Aliases + +The guide documents a richer type-construction surface that is not exported by +the current package: + +- `pto.tile(...)` +- `SyncOpType` + +Today, the public package exports annotation markers (`TensorView`, `Tile`), +scalar dtypes, `ptr(...)`, `PadMode`, `BLayout`, `SLayout`, `PadValue`, +`TileConfig`, matcher APIs, and a small set of enums. The list above covers the +remaining missing public constructors and aliases from the guide. + +### Missing Tile/Tensor Utility Methods + +The following guide surfaces are not implemented as public APIs: + +- `tile.slice(...)` +- `tile.reshape(...)` +- `pto.tile_from_ptr(...)` +- `pto.tile_with_strides(...)` +- `pto.tile_config(...)` + +### Missing Vector Load/Store Families + +The current package supports the core v0.3 load/store subset: + +- `pto.vlds(...)` +- `pto.vsts(...)` +- `pto.vldsx2(...)` +- `pto.vstsx2(...)` +- `pto.load_scalar(...)` +- `pto.store_scalar(...)` + +The following documented load/store families are still unsupported: + +- `pto.vsld(...)` +- `pto.vstu(...)` + +### Missing Direct Predicate Constructor/Compare APIs + +The implementation expects users to go through `pto.make_mask(...)` rather than +call the underlying mask ops directly. These guide-documented APIs are not part +of the supported authoring surface: + +- `pto.pset_b8(...)`, `pto.pset_b16(...)`, `pto.pset_b32(...)` +- `pto.pge_b8(...)`, `pto.pge_b16(...)`, `pto.pge_b32(...)` +- `pto.plt_b8(...)`, `pto.plt_b16(...)`, `pto.plt_b32(...)` + +### Missing Extended Vector Arithmetic Families + +The previously missing `11-vector-arithmetic-operations.md` gap list is now +implemented in the current package surface (including fused ops, broadcast/index +generation, reduction-flavored ops, and rearrangement/sort groups). + +### Deferred Surface + +`pto.vreduce(...)` is still explicitly deferred and remains rejected even in +`advanced=True` kernels. + +## Partial Features + +### Scalar Constants And Literal Typing + +The guide describes automatic `float -> pto.f32` literal typing. + +Literal support currently includes: + +- `bool` +- `int` +- `str` +- `None` + +### TensorView Attribute Model + +`TensorView` currently supports only a narrow attribute subset: + +- `shape` +- `strides` +- `element_type` +- `valid_shape` + +The following documented attributes are not implemented: + +- `offset` + +In practice, `TensorView` is now modeled as a fixed 5D GM view in the current +profile, but the DMA-oriented slicing/lowering path remains narrower than the +full guide: + +- `shape` / `valid_shape` exposure follows the 5D descriptor +- `strides` lower through hidden stride parameters carried alongside TensorView shape +- fewer written slice axes are right-aligned onto the trailing physical axes +- DMA-oriented slicing/lowering still only accepts rank-2 TensorView slices + +### Tile Attribute Model + +`Tile` currently exposes the documented metadata/query surface used by the user guide: + +- `shape` +- `element_type` +- `memory_space` +- `valid_shape` +- `config` +- `rank` + +Current constraints still apply: + +- only statically specialized rank-1/rank-2 UB tiles are supported +- `TileConfig` is queryable metadata, but lowering still renders the fixed baseline + layout contract unless later backend work teaches richer layout semantics + +### Tile Config Semantics + +`TileConfig` can be attached during specialization, but lowering does not yet +honor the rich layout/padding semantics described in the guide. The rendered +tile type is effectively fixed to a hard-coded baseline: + +- `blayout=row_major` +- `slayout=none_box` +- `fractal=512` +- `pad=0` + +So this is currently metadata storage rather than full behavioral support. + +### TensorView Slicing + +The guide presents general Python slicing with dynamic starts and strides. The +current stable DMA-oriented implementation is still a narrower 2D profile: + +- slice `stop` must be explicit on all dimensions +- slice `start` may be a compile-time constant or runtime index expression +- slice `step` must be a static positive integer +- dimension 0 may use `step > 1` +- dimension 1 must keep `step == 1` (current DMA restriction) + +Dynamic bounds are supported within those constraints. + + +### Tile Indexing Sugar + +Tile indexing sugar is partially implemented on the stable authoring path. + +Currently supported: + +- rank-1: `tile[start:]` +- rank-2: `tile[row, col:]` +- rank-2 column-major: `tile[row_start:, col_index]` +- for `pto.vlds(...)`, `pto.vsts(...)`, `pto.vldsx2(...)`, and `pto.vstsx2(...)` + +Not currently supported from the guide's broader indexing model: + +- single-element syntax such as `tile[row, col]` and `tile[pos]` +- explicit slice `stop` +- stepped tile vector slices +- the remaining wider indexed op family gap (`vsld`) + +### Control-Flow Result Merging + +The frontend does analyze loop-carried values and merged `if` results, but +lowering still has a hard limit: + +- at most one loop-carried binding per loop +- at most one merged `if`/`else` binding per conditional + +So the language feature exists conceptually, but multi-value merge cases are +not fully lowered yet. + +### Tile Profile Breadth + +The guide discusses Tile memory spaces in more general terms, but bare Tile +specialization still only accepts: + +- rank-1 or rank-2 Tiles +- static physical shape +- `MemorySpace.UB` + +So GM Tiles and more general profiles are not supported yet. + +## Currently Implemented Core Surface + +For quick orientation, the current package head is strongest in these areas: + +- matcher-driven kernel selection +- `templates={...}` and `pto.tpl(...)` +- `ptr(...)`, `pto.castptr(...)`, `pto.addptr(...)` +- low-level DMA config/copy ops +- runtime block queries (`pto.get_block_idx`, `pto.get_block_num`, ...) +- `pto.make_mask(...)` +- `pto.vlds(...)`, `pto.vsts(...)`, `pto.vldsx2(...)`, `pto.vstsx2(...)` +- `pto.load_scalar(...)` and `pto.store_scalar(...)` +- base unary/binary/vector-scalar vector ops +- advanced compare/select/carry/rearrangement families + +If you need the exact supported boundary for implementation work, prefer the +source files and tests over the broader guide: + +- `tilelang-dsl/python/tilelang_dsl/support_matrix.py` +- `tilelang-dsl/python/tilelang_dsl/semantic.py` +- `tilelang-dsl/python/tilelang_dsl/lowering.py` +- `tilelang-dsl/tests/test_tilelang_dsl_v1.py` diff --git a/tilelang-dsl/docs/user_guide/01-introduction.md b/tilelang-dsl/docs/user_guide/01-introduction.md new file mode 100644 index 000000000..26012f781 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/01-introduction.md @@ -0,0 +1,47 @@ +# TileLang Python DSL Guide + +The TileLang Python DSL provides a high-level, Pythonic interface for authoring vector compute kernels targeting the Ascend NPU hardware. This guide is intended for library developers and performance engineers who need to write efficient, hardware-aware kernels using the PTO micro instruction set. + +The DSL is designed to generate MLIR function libraries rather than direct binary executables. These MLIR libraries are intended to be consumed by other compilation frameworks that transform high-level tile semantics into low-level vector operations. This enables library developers to focus on hardware-aware kernel authoring while relying on upstream compilers for tile-level optimizations and code generation. + +## Language Tier + +The DSL surface is organized into multiple maturity tiers, reflecting the stability and intended use of different language features. As the design evolves, the basic authoring path is being explicitly separated from more advanced surfaces. Refer to the following table when reading this guide: + +| Surface Family | Tier | Usage Guidance | +|----------------|------|----------------| +| `TensorView` | `basic` | Default GM-facing data model for starter kernels. | +| `Tile` | `basic` | Default UB-facing compute tile for starter kernels. | +| Base vector ops (`make_mask`, `vlds`, `vsts`, `vadd`, `vmuls`, etc.) | `basic` | Default compute skeleton for starter kernels. | +| `strict_vecscope` | `advanced` | Explicit vector-scope management for expert authoring. | +| Raw pointer family (`ptr(...)`, `castptr`, `addptr`) | `advanced` | For expert authoring and migration; not required for Quick Start. | +| DMA family (`copy_*`, `set_loop*_stride_*`, `set_loop_size_*`, pad-fill control) | `advanced` | Direct DMA engine control for expert authoring, including GM→UB padding behavior. | +| Tile pointer helper (`tile.as_ptr()`) | `advanced` | Expert-only helper when advanced authoring needs explicit typed pointers. | + +For the authoritative tier classification, consult `tilelang-dsl/python/tilelang_dsl/support_matrix.py`. For known implementation gaps, refer to `tilelang-dsl/docs/unsupported-features.md`. + +### Basic vs Advanced Authoring Modes + +The TileLang DSL provides two distinct authoring modes: + +**Basic Mode (default)** +- Uses **Tile element/slice semantics** for buffer access +- Direct tile indexing syntax: `tile[start:]`, `tile[row, col:]`, `tile[row:, col]` (Tile indexing sugar only supports open-ended vector slices; explicit `stop` and `step` forms are not accepted for `Tile` indexing) +- Vector operations use element-indexing syntax: `pto.vlds(tile[row, col:])`, `pto.vsts(vec, tile[start:], mask)` +- No pointer arithmetic or explicit offset calculations +- Suitable for most kernel authoring with high-level abstractions + +**Advanced Mode (`advanced=True` in `@pto.vkernel`)** +- Uses **raw pointer semantics** for explicit memory management +- Direct pointer operations correspond to `pto.ptr` types in MLIR +- Explicit pointer arithmetic: `ptr(...)`, `castptr`, `addptr` +- Manual DMA engine control with low-level copy operations and explicit GM→UB padding behavior +- Requires explicit buffer management and pointer arithmetic +- Intended for expert users and performance-critical optimizations + +**Key Differences** +- **Basic mode**: Uses tile element-indexing syntax (`tile[row, col:]`, `tile[start:]`) for vector operations +- **Advanced mode**: Uses pointer byte-offset syntax (`pto.vlds(buf: ptr, offset)`) for vector operations +- Tile slices in basic mode correspond to MLIR `memref` types +- Raw pointers in advanced mode correspond to MLIR `pto.ptr` types +- No automatic conversion between tile and pointer semantics - choose the appropriate syntax for your authoring mode diff --git a/tilelang-dsl/docs/user_guide/02-quick-start.md b/tilelang-dsl/docs/user_guide/02-quick-start.md new file mode 100644 index 000000000..26b0ba58b --- /dev/null +++ b/tilelang-dsl/docs/user_guide/02-quick-start.md @@ -0,0 +1,78 @@ +## Quick Start + +**Note on mask pattern enums**: For brevity, examples in this guide use `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.PAT_ALL`). You can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +TileLang DSL provides the following core constructs for kernel authoring: + +- `TensorView` – Access global memory (GM) tensors +- `Tile` – Local computation buffers in unified buffer (UB) +- Base vector operations (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations + +A typical kernel follows the GM → UB → vector compute → GM pattern: + +```python +import tilelang_dsl as pto + +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32)]) +def tile_scale( + input_tensor: pto.TensorView, + output_tensor: pto.TensorView, + work_tile: pto.Tile, + scale_factor: pto.f32, +): + dim0 = 4 + dim1 = 16 + + # Stage one GM tile into UB. + # GM -> UB data movement (implementation detail) + + # Run vector compute over the UB tile using tile indexing sugar. + for i in range(0, dim0): + mask = pto.make_mask(pto.f32, PAT.ALL) + vec = pto.vlds(work_tile[i, 0:]) + scaled = pto.vmuls(vec, scale_factor, mask) + pto.vsts(scaled, work_tile[i, 0:], mask) + + # Write the UB result back to GM. + # UB -> GM data movement (implementation detail) +``` + +The example illustrates the key components of a TileLang kernel: + +1. **`TensorView` parameters** – Access global memory tensors +2. **`Tile` parameters** – Local computation buffers in unified buffer (UB) +3. **Base vector operations** (`make_mask`, `vlds`, `vmuls`, `vadd`, `vsts`) – Perform vector computations + +Here is a second example with two inputs and one output: + +```python +@pto.vkernel( + target="a5", + op="elementwise_add", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)], +) +def elementwise_add( + lhs_gm: pto.TensorView, + rhs_gm: pto.TensorView, + out_gm: pto.TensorView, + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, +): + dim0 = 4 + dim1 = 16 + + # GM -> UB data movement (implementation detail) + + for lane in range(0, 256, 64): + mask = pto.make_mask(pto.f32, PAT.ALL) + lhs_vec = pto.vlds(lhs_tile, lane) + rhs_vec = pto.vlds(rhs_tile, lane) + summed = pto.vadd(lhs_vec, rhs_vec, mask) + pto.vsts(summed, dst_tile, lane, mask) + + # UB -> GM data movement (implementation detail) +``` + +Both examples follow the same fundamental pattern: load data from global memory into local tiles, perform vector operations, and store results back. The compiler automatically infers vector-scope boundaries for the base vector operations. The `Tile` parameters are specialized to concrete shapes during compilation. Later sections cover advanced features such as matchers, template slots, raw pointer operations, and explicit scope management with `strict_vecscope`. + diff --git a/tilelang-dsl/docs/user_guide/03-kernel-declaration.md b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md new file mode 100644 index 000000000..73c9e1800 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/03-kernel-declaration.md @@ -0,0 +1,528 @@ +## Core Concepts + +### Kernel Declaration + +TileLang DSL exposes two kernel decorators: + +- `@pto.vkernel` for the Vector (AIV) execution model +- `@pto.ckernel` for the Cube (AIC) execution model + +#### Basic Syntax + +```python +@pto.vkernel( + target="a5", # Target architecture + op="pto.matmul ins(a, b) -> outs(c)", # PTO op + operand schema + dtypes=[(pto.f16, pto.f16, pto.f32)], # Type signatures + constraints=[ # Additional constraints + lambda a, b: a.shape[1] == b.shape[0], + lambda batch=1: batch >= 1, + ], + priority=100 # Priority for selection +) +def matmul_fallback(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # kernel implementation +``` + +#### Decorator Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | +| `op` | `str` | No* | PTO operation matcher. Preferred form is schema mode: `"pto.op_name ins(in0, in1, ...) -> outs(out0, out1, ...)"`. Legacy bare-op form (`"pto.op_name"`) is still accepted for compatibility. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. Use this when one descriptor should match multiple concrete ops (schema mode is currently only supported in `op`). | +| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands (inputs and outputs) in order. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static template-slot mappings. Each slot maps concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | +| `constraints` | `List[Callable[..., bool]]` | No | Additional selection-time predicates. Constraint arguments bind by name to kernel parameter proxy objects or `context_attrs` keys. Default: empty list. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Higher values have higher priority. Default: `0`. | +| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | +| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is available in both modes and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | + +#### Operation Schema in `op` (ins/outs) + +`op` supports a schema string that declares how kernel parameter names map to PTO op operands: + +```python +op="pto.tadds ins(src, scalar) -> outs(dst)" +``` + +Schema form: + +```text + ins(, , ...) -> outs(, , ...) +``` + +Rules: + +1. `ins(...)` and `outs(...)` are both required in schema mode. +2. Names in `ins` and `outs` must be valid, unique Python identifiers. +3. The decorated function parameter list must exactly match `ins + outs` by both count and name. +4. MLIR function argument ordering is defined by schema order (`ins` first, then `outs`). +5. Constraint binding keeps using parameter names; schema mode makes these names explicit and stable. +6. Schema mode applies to `op=...` (single matcher op). `ops=[...]` remains bare-op matching. + +Example: + +```python +@pto.vkernel( + target="a5", + op="pto.tadds ins(src, scalar) -> outs(dst)", + dtypes=[(pto.f32, pto.f32, pto.f32)], +) +def template_tadds(src: pto.Tile, scalar: pto.f32, dst: pto.Tile): + return None +``` + +If names or order do not match, descriptor construction fails early with a schema mismatch error. + + +#### Type Matching Rules + +The `dtypes` parameter supports flexible type matching: + +1. **Concrete Types**: Exact type matches using DSL scalar types: + - `pto.f16`, `pto.f32`, `pto.bf16` + - `pto.i8`, `pto.si8`, `pto.ui8` + - `pto.i16`, `pto.si16`, `pto.ui16` + - `pto.i32`, `pto.si32`, `pto.ui32` + - `pto.i64`, `pto.si64`, `pto.ui64` + - `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` + + Builtin vector operands still use their element dtype in `dtypes=[...]`. + For example, a parameter annotated as `ex_vec: pto.vector(pto.i16, (4,))` + contributes `pto.i16` to the signature tuple, while the vector shape + contract stays in the parameter annotation. + +2. **Type Wildcards**: Generic type patterns: + - `pto.AnyFloat`: Matches any floating-point type (`f16`, `bf16`, `f32`) + - `pto.AnyInt`: Matches any integer type (`i*`, `si*`, `ui*`) + - `pto.AnyType`: Matches any scalar type + - `pto.AnyMask`: Matches any mask type (`mask_b8`, `mask_b16`, `mask_b32`) + +3. **Type Variables**: Named type variables that enforce consistency within a signature: + ```python + T = pto.TypeVar('T') # Define a type variable + + @pto.vkernel( + target="a5", + op="elementwise", + dtypes=[(T, T, T)], # All three operands must have the same type + constraints=[] + ) + def elementwise_same_type(x: pto.Tile, y: pto.Tile, out: pto.Tile) -> None: + # x, y, and out must have identical element types + pass + ``` + +4. **Mixed Signatures**: Multiple type signatures for the same operation: + ```python + @pto.vkernel( + target="a5", + op="add", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), # Float addition + (pto.AnyInt, pto.AnyInt, pto.AnyInt) # Integer addition + ] + ) + def generic_add(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Supports both float and integer types + pass + ``` + +#### Constraint System + +Constraints are compile-time predicates that refine kernel selection. In the current implementation, each entry in `constraints=[...]` is a Python callable returning `True` or `False`. + +##### Predefined Constraints + +| Constraint | Description | +|------------|-------------| +| `k_dim_aligned_64` | K dimension is aligned to 64 elements (for matmul kernels). | +| `continuous_memory` | Operands reside in contiguous memory regions. | +| `requires_ub_memory` | Operation requires Unified Buffer memory (vs. Global Memory). | +| `tensor_rank(rank)` | Operand tensor has specified rank (e.g., `tensor_rank(2)` for 2D tensors). | +| `broadcastable` | Operands are broadcastable according to NumPy-style broadcasting rules. | +| `static_shape` | All tensor dimensions are known at compile time (no dynamic shapes). | + +##### Logical Constraint Combinators + +| Combinator | Description | Example | +|------------|-------------|---------| +| `AnyOf(c1, c2, ...)` | At least one of the constraints must be satisfied. | `AnyOf(k_dim_aligned_64, continuous_memory)` | +| `AllOf(c1, c2, ...)` | All constraints must be satisfied. | `AllOf(tensor_rank(2), static_shape)` | +| `Not(c)` | The constraint must not be satisfied. | `Not(requires_ub_memory)` | + +##### Custom Constraints + +Users can define custom constraints using predicate functions: + +```python +# Define a custom constraint that consumes one context attr by name. +def large_batch(min_batch: int): + return lambda batch=0: batch >= min_batch + +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[large_batch(1024)] +) +def large_batch_matmul(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized for large batch sizes + pass +``` + +Constraint callables bind by parameter name. + +- Kernel parameter names such as `src`, `dst`, `a`, `b` receive lightweight proxy objects, so constraints can use direct expressions like `src.shape[0] <= dst.shape[0]`. +- Extra `context_attrs` passed to `pto.select_kernel(...)` bind by key name, for example `batch`, `enabled`, or `expected_rows`. + +##### Parameter Proxy Objects + +When a constraint argument name matches a kernel parameter name, the callable receives a lightweight proxy object rather than raw Python data. + +- For `TensorView` parameters, the proxy exposes `rank`, `shape`, `strides`, `dtype`, and `memory_space`. +- For `Tile` parameters, the proxy exposes `rank`, `shape`, `valid_shape`, `dtype`, `memory_space`, and `config`. +- `shape`, `strides`, and `valid_shape` support index access such as `src.shape[0]` or `dst.valid_shape[1]`. +- Missing or not-yet-known metadata evaluates as "unknown", so comparisons conservatively pass rather than failing early. + +Example: + +```python +def tload_preconditions(src, dst): + logical_rows = src.shape[0] * src.shape[1] * src.shape[2] * src.shape[3] + logical_cols = src.shape[4] + return ( + src.rank == 5 + and src.strides[4] == 1 + and dst.valid_shape[0] <= logical_rows + and dst.valid_shape[1] <= logical_cols + and logical_rows <= dst.shape[0] + and logical_cols <= dst.shape[1] + ) + +@pto.vkernel( + target="a5", + op="pto.tload", + dtypes=[(pto.f32, pto.f32)], + constraints=[tload_preconditions], +) +def template_tload(src: pto.TensorView, dst: pto.Tile): + return None +``` + +This is the recommended constraint style for current TileLang DSL head. + +##### Builtin Vector Parameters + +When a kernel needs to match a builtin MLIR vector operand, annotate that +parameter with `pto.vector(element_dtype, shape)`. + +```python +@pto.vkernel( + target="a5", + op="pto.tmrgsort ins(src0, src1, tmp) -> outs(dst, ex_vec)", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.i16)], +) +def template( + src0: pto.Tile, + src1: pto.Tile, + tmp: pto.Tile, + dst: pto.Tile, + ex_vec: pto.vector(pto.i16, (4,)), +): + return None +``` + +Rules: + +- Use `pto.vector(...)` for builtin vector operands, not Python `list`. +- `shape` is a Python tuple. A 1-D vector of length 4 is written `(4,)`. +- `dtypes=[...]` still records only the element dtype for that operand (`pto.i16` + in the example above). +- `pto.vector(...)` is distinct from `pto.vreg(...)`: the former models builtin + `vector<...>`, the latter models fixed-width VPTO vector registers. + +#### Kernel Selection Mechanism + +When a PTO operation needs implementation, the system performs the following matching process: + +1. **Target Filtering**: Select kernels with matching `target` architecture. +2. **Operation Filtering**: Select kernels whose matcher metadata covers the concrete query op: + - `op="foo"` requires exact match + - `op="foo ins(...) -> outs(...)"` still matches by op name `foo`; `ins/outs` additionally defines parameter naming/order contract for descriptor validation and materialization + - `ops=[...]` requires the concrete query op to appear in that list +3. **Type Matching**: For each kernel's `dtypes` list, check if any signature matches the operation's operand types: + - Concrete types must match exactly. + - Wildcard types match according to their category. + - Type variables must be consistent within the signature. +4. **Constraint Validation**: For each matching kernel, evaluate all `constraints`. If any constraint fails, the kernel is rejected. +5. **Priority Selection**: From the remaining kernels, select the one with the highest `priority` value. +6. **Fallback**: If no kernel matches, compilation fails with an error. + +For multi-op descriptors selected through `ops=[...]`, `pto.select_kernel(...)` +also binds the concrete query op before materialization. This bound +`selected_op` is what template-slot expansion uses later. + +The package also exposes explicit selection utilities: + +```python +registry = pto.KernelRegistry() +registry.register(my_kernel) + +selected = pto.select_kernel( + "a5", + "matmul", + (pto.f16, pto.f16, pto.f32), + context_attrs={"k_aligned": True}, + registry=registry, +) +``` + +`pto.select_kernel(...)` also supports an opt-in diagnostics path for matcher debugging: + +```python +report = pto.select_kernel( + "a5", + "matmul", + (pto.f16, pto.f16, pto.f32), + context_attrs={"k_aligned": False}, + return_metadata=True, + include_mlir=False, +) +``` + +When `return_metadata=True`, the result is a `KernelSelectionReport` instead of one +selected descriptor. + +- `report.selected` carries the winner when one candidate is selected. +- `report.final_status` is one of `selected`, `no_candidate`, or `priority_tie`. +- `report.final_error` summarizes the final selection outcome. +- `report.candidates` contains one `KernelSelectionCandidateMetadata` per + `target/op`-matched descriptor, including `dtype_mismatch`, + `constraint_failed`, `constraint_error`, `priority_shadowed`, `selected`, and + `priority_tie` states. + +Constraint diagnostics in report mode include: + +- `failed_constraint_index` +- `failed_constraint_name` +- `failed_constraint_location` as `file:line` + +For best diagnostics, prefer splitting compound predicates into multiple +constraint entries instead of writing one large `cond0 and cond1 and cond2` +callable. Report mode can precisely identify which constraint entry failed, but +it does not introspect which sub-expression inside one Python boolean +expression returned `False`. + +When `include_mlir=True`, report mode also attempts `mlir_text()` for candidates +that pass constraint evaluation. + +- On success, the candidate carries `mlir_text`. +- On materialization failure such as missing `specialize()` bindings, the + candidate carries `mlir_error`. +- Use `include_mlir=False` to skip this extra materialization attempt. + +#### Examples + +##### Matmul with Multiple Implementations + +```python +# High-performance kernel for aligned K dimension +def k_aligned_64(k=0): + return k % 64 == 0 + +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[k_aligned_64], + priority=200 +) +def matmul_aligned_k(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Optimized implementation for aligned K + pass + +# General-purpose fallback +@pto.vkernel( + target="a5", + op="pto.matmul ins(a, b) -> outs(c)", + dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], + constraints=[], + priority=100 +) +def matmul_general(a: pto.Tile, b: pto.Tile, c: pto.Tile) -> None: + # Generic implementation + pass +``` + +##### Elementwise Operation with Type Polymorphism + +```python +def same_shape(a, b, out): + return a.shape[0] == out.shape[0] and b.shape[0] == out.shape[0] + +@pto.vkernel( + target="a5", + op="pto.add ins(a, b) -> outs(out)", + dtypes=[ + (pto.AnyFloat, pto.AnyFloat, pto.AnyFloat), + (pto.AnyInt, pto.AnyInt, pto.AnyInt) + ], + constraints=[same_shape] +) +def polymorphic_add(a: pto.Tile, b: pto.Tile, out: pto.Tile) -> None: + # Single implementation handles both float and integer types + dtype = a.element_type + all_mask = pto.make_mask(dtype, PAT.ALL) + # ... implementation using generic vector operations + pass +``` + +##### Constrained Convolution Kernel + +```python +def prefer_static_nhwc(src, weight): + return src.rank == 4 and weight.rank == 4 + +@pto.vkernel( + target="a5", + op="pto.conv2d ins(input, filter) -> outs(output)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + constraints=[prefer_static_nhwc], + priority=150 +) +def conv2d_nhwc_f16_f32(input: pto.Tile, filter: pto.Tile, output: pto.Tile) -> None: + # Optimized for NHWC layout with static shapes + pass +``` + +--- + +### Cube Kernel Declaration + +Cube kernels target the AIC (Cube) hardware unit for matrix multiplication operations. Unlike Vector kernels, Cube kernels operate on raw `pto.ptr` pointers and do not use `vecscope` execution scopes. + +#### Basic Syntax + +```python +@pto.ckernel( + target="a5", + op="pto.mad", # concrete matcher op + dtypes=[(pto.f16, pto.f16, pto.f32)], # selection dtype signature + name="my_gemm", # optional registry/debug name +) +def gemm(inp: pto.TensorView): + # Cube kernel body — linear cube authoring IR + ... +``` + +#### Parameter Type Conventions + +Cube kernel parameters represent different roles in the data flow: + +| Parameter Type | Role | Description | +|---------------|------|-------------| +| `PartitionTensorView` | GM input/output | Tiled view of a logical tensor in GM, partitioned by the caller | +| `TensorView` | GM input/output | Full logical tensor view in GM (for non-partitioned use) | +| `Tile` (specific addr space) | Pre-allocated hardware buffer | Tile already allocated in LEFT/RIGHT/ACC/MAT/BIAS address space | +| `int` | Dimension | Scalar dimension parameter (M, K, N, etc.) | +| `pto.f16` / `pto.f32` etc. | Scalar | Scalar parameters (threshold, alpha, etc.) | + +GM payload is modeled through `TensorView` and `PartitionTensorView`. `Tile` +values represent staged hardware buffers allocated in concrete hardware address +spaces such as `MAT`, `LEFT`, `RIGHT`, `ACC`, and `BIAS` via `pto.Tile`. + +#### Decorator Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | No | Target hardware architecture. Cube DSL v1 supports `"a5"`. Default: `"a5"`. | +| `op` | `str` | 与 `ops` 二选一 | Single concrete matcher op. Bare-op strings such as `"pto.mad"` are supported. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | 与 `op` 二选一 | List of concrete matcher ops for shared-body selection and template-slot dispatch. **Mutually exclusive with `op`**. | +| `dtypes` | `List[Tuple[Type, ...]]` | Recommended | List of selection dtype signatures. For cube kernels, these signatures describe the concrete query op rather than necessarily mirroring the Python parameter list. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static template-slot mappings. Each slot maps concrete op names to real `pto.*` calls. Required when the kernel body uses `pto.tpl(...)`. | +| `name` | `str` | No | Descriptor name used for registration, debugging, and emitted symbol naming. Defaults to the decorated function name. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Default: `0`. | + +#### Key Differences from `@pto.vkernel` + +| Feature | `@pto.vkernel` (Vector) | `@pto.ckernel` (Cube) | +|---------|--------------------------|------------------------| +| Hardware unit | AIV (Vector) | AIC (Cube) | +| Execution scope | `pto.vecscope` / `pto.strict_vecscope` | **No scope** — function body is linear IR | +| GM data input | `TensorView` / `Tile` | `TensorView` / `PartitionTensorView` | +| Operand abstraction | Tile + vector registers + masks | `pto.ptr` raw pointers | +| Core operations | Vector ALU, load/store | Data movement (cube_load/store) + matmul (mad) | +| Address spaces | GM, UB (VEC) | GM, MAT, LEFT, RIGHT, ACC, BIAS, UB | +| Generated IR attr | `#pto.kernel_kind` | `#pto.kernel_kind` | + +#### Programming Model + +Cube kernels follow a GM → L1 → L0 → compute → L0 → GM data flow: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm", +) +def gemm(a_tv: pto.PartitionTensorView, # [M, K] in GM + b_tv: pto.PartitionTensorView, # [K, N] in GM + c_tv: pto.PartitionTensorView): # [M, N] in GM, output + # 1. Get GM pointers from PartitionTensorViews + a_ptr = a_tv.as_ptr() # -> pto.ptr + b_ptr = b_tv.as_ptr() # -> pto.ptr + c_ptr = c_tv.as_ptr() # -> pto.ptr + + # 2. Allocate L1 (MAT) tile buffers (returns Tile, then get ptr) + l1_a = pto.Tile([16, 32], pto.f16, pto.MemorySpace.MAT) + l1_b = pto.Tile([32, 16], pto.f16, pto.MemorySpace.MAT) + + # 3. Allocate L0 tile buffers + l0a = pto.Tile([16, 32], pto.f16, pto.MemorySpace.LEFT) + l0b = pto.Tile([32, 16], pto.f16, pto.MemorySpace.RIGHT) + l0c = pto.Tile([16, 16], pto.f32, pto.MemorySpace.ACC) + + # 4. GM → L1 data movement + pto.cube_load(a_ptr, l1_a.as_ptr(), 16, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b.as_ptr(), 16, nburst=(1, 0, 0)) + + # 5. L1 → L0 data movement + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), 16, 32) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), 32, 16) + + # 6. Matrix multiplication + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), 16, 16, 32) + + # 7. L0C → GM writeback + pto.acc_store_gm( + l0c.as_ptr(), c_ptr, 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND + ) +``` + +This example shows a **full-pipeline** kernel that handles data movement and compute. Alternatively, a **pure-compute** kernel can take pre-allocated tiles directly: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="matmul_compute", +) +def matmul_compute(a_left: pto.Tile, # Pre-allocated LEFT tile (L0A) + b_right: pto.Tile, # Pre-allocated RIGHT tile (L0B) + c_acc: pto.Tile): # Pre-allocated ACC tile (L0C) + pto.mad_acc(a_left.as_ptr(), b_right.as_ptr(), c_acc.as_ptr(), 16, 16, 32) +``` + +#### Hardware Isolation + +- `@pto.ckernel` functions generate `#pto.kernel_kind` IR attribute. +- `@pto.vkernel` functions generate `#pto.kernel_kind` IR attribute. +- The IR verifier prevents Cube and Vector operations from appearing in the same function. +- The DSL semantic analyzer additionally checks that Cube kernel bodies do not contain Vector-specific operations (`vlds`, `vadd`, etc.) or `vecscope` scopes. +- Both kernel types can coexist in the same `.py` file; each compiles independently with conditional compilation macros (`__DAV_CUBE__` / `__DAV_VEC__`). + +For the complete Cube operation reference and `pto.Tile` constructor details, see [Cube Matrix Multiply Operations](12-cube-operations.md). diff --git a/tilelang-dsl/docs/user_guide/04-template-kernels.md b/tilelang-dsl/docs/user_guide/04-template-kernels.md new file mode 100644 index 000000000..9fcda0fd0 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/04-template-kernels.md @@ -0,0 +1,333 @@ +### Template-based Kernel Authoring + +For operations that share similar computation patterns but differ in their core vector operations, the DSL supports template-based kernel authoring. This allows a single kernel implementation to serve multiple related operations through parameterized templates. + +#### Multi-operation Kernels with `ops` Parameter + +Instead of specifying a single `op` parameter, you can provide an `ops` list to match multiple operations: + +```python +@pto.vkernel( + target="a5", + ops=["tadd", "tsub", "tmul", "tdiv"], # List of operations + dtypes=[(T, T, T)], # Type signature using type variable + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + } +) +def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, elems_per_vreg): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) # Template dispatch + pto.vsts(out, dst[row, col:], mask) +``` + +`op` and `ops` are mutually exclusive, and exactly one of them must be +provided. `ops=[...]` only widens the matcher set; callers still use +`pto.select_kernel(target, concrete_op, operand_types, ...)` with a concrete +PTO op such as `"tadd"` or `"tmul"`. + +#### Template System + +The template system consists of three components: + +1. **`templates` parameter**: A dictionary mapping template names to operation-specific implementations +2. **`pto.tpl()` function**: A compile-time placeholder that resolves to the appropriate implementation for the currently selected concrete op +3. **`ops` parameter**: Replaces the singular `op` parameter for multi-operation kernels + +##### Template Definition + +Templates are defined in the `templates` parameter of `@pto.vkernel`. Each template is a dictionary mapping operation names to implementation strings: + +```python +templates={ + "template_name": { + "op1": "implementation_for_op1", + "op2": "implementation_for_op2", + # ... + }, + "another_template": { + "op1": "different_implementation_for_op1", + # ... + } +} +``` + +Template-slot metadata is static and validated when the descriptor is +registered: + +- slot names must be non-empty strings +- mapping keys must be concrete ops covered by the descriptor matcher set +- mapping values must be supported real `pto.*` op names + +The implementation strings are typically vector operation names such as +`"vadd"`, `"vsub"`, `"vmul"`, and `"vdiv"`, which are resolved during kernel +expansion. + +##### Template Usage with `pto.tpl()` + +The `pto.tpl()` operation enables template dispatch for multi-operation kernels, allowing code reuse across related operations through compile-time substitution. + +#### `pto.tpl(template_name: str, *args) -> Any` + +**Description**: Template dispatch operation for multi-operation kernels. Resolves to different implementations based on the current operation being expanded. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `template_name` | `str` | Name of the template to dispatch | +| `*args` | `Any` | Positional arguments passed unchanged to the resolved real implementation | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `Any` | Result of the template implementation | + +**Behavior**: +- Only valid inside kernels decorated with `@pto.vkernel` that have a `templates` parameter +- The first argument must be a string literal template-slot name +- During kernel expansion for a specific operation `op_name`, `pto.tpl("template_name", ...)` is replaced with the implementation specified in `templates["template_name"]["op_name"]` +- The replacement is a direct compile-time substitution; positional arguments are passed unchanged +- Template implementations are typically string names of vector operations (e.g., `"vadd"`, `"vsub"`) +- `pto.select_kernel(...)` must bind a concrete op before template expansion can happen +- Python dict lookup, callable values, lambdas, and other runtime dispatch patterns are not part of the supported kernel-body surface + +**Example**: +```python +@pto.vkernel( + ops=["tadd", "tsub"], + dtypes=[(T, T, T)], + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + } + } +) +def elementwise_kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # ... load vectors + result = pto.tpl("core", lhs, rhs, mask) # Expands to vadd for tadd, vsub for tsub + # ... store result +``` + +**Constraints**: +- Template names must be defined in the `templates` parameter of the `@pto.vkernel` decorator +- When a kernel body uses `pto.tpl("slot", ...)`, that slot must define an implementation for the currently selected concrete op +- Template implementations must be valid operation names in the DSL + +#### Decorator Parameters Update + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `target` | `str` | Yes | Target hardware architecture (e.g., `"a5"` for Ascend 950). | +| `op` | `str` | No* | Name of the PTO operation to match. **Mutually exclusive with `ops`**. | +| `ops` | `List[str]` | No* | List of PTO operation names to match. **Mutually exclusive with `op`**. | +| `dtypes` | `List[Tuple[Type, ...]]` | Yes | List of type signatures. Each tuple specifies the expected data types for the operation's operands. | +| `templates` | `Dict[str, Dict[str, str]]` | No | Static slot mappings from concrete matcher ops to real `pto.*` op names. Required when the kernel body uses `pto.tpl(...)`. | +| `constraints` | `List[Constraint]` | No | Additional constraints that must be satisfied for kernel selection. | +| `priority` | `int` | No | Selection priority when multiple kernels match. Default: `0`. | +| `name` | `str` | No | Kernel name (used for debugging and profiling). Defaults to the decorated function's name. | +| `advanced` | `bool` | No | Enable advanced-tier DSL surfaces (for example `strict_vecscope`, raw pointer family, and low-level DMA family). Implicit vecscope inference is mode-independent and runs only when no explicit `with pto.vecscope():` is present. Default: `False`. | + +**Note**: +- Either `op` or `ops` must be provided, but not both. +- `templates` is only needed when the kernel body uses `pto.tpl(...)`. +- `pto.select_kernel(...)` still queries with a concrete op even for `ops=[...]` descriptors. + +#### Advanced Template Patterns + +##### Multiple Templates per Kernel + +A kernel can define multiple templates for different aspects of the computation: + +```python +@pto.vkernel( + target="a5", + ops=["tadd_relu", "tsub_relu", "tadd_abs", "tsub_abs"], + dtypes=[(T, T, T)], + templates={ + "arithmetic": { + "tadd_relu": "vadd", + "tsub_relu": "vsub", + "tadd_abs": "vadd", + "tsub_abs": "vsub", + }, + "postprocess": { + "tadd_relu": "vrelu", + "tsub_relu": "vrelu", # Same activation for both + "tadd_abs": "vabs", + "tsub_abs": "vabs", + } + } +) +def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # ... load vectors + arith_result = pto.tpl("arithmetic", lhs, rhs, mask) + postprocessed = pto.tpl("postprocess", arith_result, mask) + # ... store result +``` + +##### Compile-time Substitution Model + +Template-slot expansion happens before semantic checking and lowering: + +- `pto.select_kernel(...)` first binds a concrete op such as `"tadd"` +- the frontend then resolves `pto.tpl("core", ...)` using `templates["core"]["tadd"]` +- the placeholder is rewritten to a real `pto.*` call before semantic analysis +- diagnostics for unknown slots, missing mappings, or unsupported resolved surfaces are raised before any VPTO IR is generated + +#### Type Variables in Template Kernels + +Template kernels often use type variables to enforce type consistency: + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + target="a5", + ops=["tadd", "tsub"], + dtypes=[(T, T, T)], # All three operands share type T + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + } + } +) +def typed_elementwise(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # Type variable T ensures all tiles have same element type + dtype = dst.element_type # This is type T + # ... implementation +``` + +#### Selection Mechanism for Template Kernels + +When a PTO operation matches a template kernel: +1. The system selects the descriptor based on `op` exact match or `ops` list inclusion. +2. `pto.select_kernel(...)` binds the concrete query op as the descriptor's `selected_op`. +3. During frontend expansion, `pto.tpl()` calls are resolved using that bound concrete op. +4. For operation `"op_name"`, template `"template_name"` resolves to `templates["template_name"]["op_name"]`. +5. The resolved string (e.g., `"vadd"`) is replaced with the corresponding real DSL operation before semantic analysis and lowering. + +#### Example: Unified Arithmetic Kernel + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + ops=["tadd", "tsub", "tmul", "tdiv", "tmax", "tmin"], + dtypes=[(T, T, T)], + advanced=True, + templates={ + "arithmetic": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + "tmax": "vmax", + "tmin": "vmin", + } + } +) +def unified_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + """Single implementation for six arithmetic operations.""" + dtype = dst.element_type + rows, cols = dst.valid_shape + elems_per_vreg = pto.elements_per_vreg(dtype) # Number of elements per vector register + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, elems_per_vreg): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("arithmetic", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) +``` + +#### Compile-time Specialization with `pto.constexpr` + +The `pto.constexpr` construct enables compile-time branching for kernel specialization, allowing different code paths to be selected based on static compile-time information. Unlike runtime conditionals that generate control flow, `pto.constexpr` branches are resolved during kernel descriptor materialization, with only the selected branch retained for lowering. + +**Syntax and Usage**: +```python +if pto.constexpr(condition): + # Branch taken if condition evaluates to True at compile time + ... +else: + # Branch taken if condition evaluates to False at compile time + ... +``` + +**Semantics**: +- The `condition` must be evaluable at compile time during kernel descriptor materialization. +- Only the selected branch is analyzed, semantically checked, and lowered to VPTO IR. +- The non-selected branch is discarded entirely and does not contribute to runtime control flow or value merging. +- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. + +**Comparison with Runtime Conditionals**: + +| Aspect | Runtime `if` | `pto.constexpr` | +|--------|--------------|-----------------| +| **Evaluation time** | Runtime | Compile-time (descriptor materialization) | +| **Control flow** | Generates `scf.if` with merge logic | No runtime control flow; branch eliminated | +| **Value merging** | Both branches must produce compatible values for merge | No value merging; only one branch exists after elimination | +| **Use case** | Dynamic decision making based on runtime values | Code generation specialization based on static parameters | + +**Typical Static Inputs**: +- Literal integers, booleans, and strings +- Data type symbols (`src.element_type`, `dst.element_type`) and comparisons derived from them +- Statically specialized `Tile.shape` and `Tile.valid_shape` values +- Frontend query helpers such as `pto.bytewidth(dtype)` and `pto.elements_per_vreg(dtype)` (which computes elements per vector register) + +**Constraints and Notes**: +- `TensorView.shape` and `TensorView.strides` may be represented by hidden kernel parameters rather than descriptor-time constants. They should not be assumed constexpr unless separately bound through specialization or other compile-time context. +- `pto.constexpr` is a frontend-only authoring construct; it does not correspond to any runtime VPTO instruction. + +**Guidelines**: +- Use `constraints=[...]` and `pto.select_kernel(...)` when specialization requires selecting an entirely different kernel descriptor. +- Use `pto.constexpr` when the kernel remains the same but internal regions require specialization based on compile-time parameters. + +**Example**: +```python +@pto.vkernel(target="a5", op="pto.trowsum") +def template_trowsum(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): + acc_dtype = tmp.element_type + dst_dtype = dst.element_type + acc_mask_1, _ = pto.make_mask(acc_dtype, 1) + dst_mask_1, _ = pto.make_mask(dst_dtype, 1) + + if pto.constexpr(acc_dtype != dst_dtype): + # Type conversion required + v_acc_casted = pto.vcvt(v_acc, dst_dtype, acc_mask_1) + pto.vsts(v_acc_casted, dst[row, 0:], dst_mask_1) + else: + # No conversion needed + pto.vsts(v_acc, dst[row, 0:], dst_mask_1) +``` + +### Value Model + +The DSL operates on symbolic values, not Python runtime values: +- **Constants**: Python literals that are typed to machine types +- **Operation results**: Values produced by DSL operations +- **Block arguments**: Values introduced by control flow structures + +### Memory Spaces + +The DSL supports different memory spaces: +- `MemorySpace.GM`: Global Memory +- `MemorySpace.UB`: Unified Buffer (local storage for vector computation) diff --git a/tilelang-dsl/docs/user_guide/05-type-system.md b/tilelang-dsl/docs/user_guide/05-type-system.md new file mode 100644 index 000000000..c40f12475 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/05-type-system.md @@ -0,0 +1,686 @@ + + +## Type System + +### Scalar Types + +| DSL Type | Description | Bit Width | +|----------|-------------|-----------| +| `pto.i1` | Boolean | 1 | +| `pto.i8` | 8-bit signless integer | 8 | +| `pto.si8` | 8-bit signed integer | 8 | +| `pto.ui8` | 8-bit unsigned integer | 8 | +| `pto.i16` | 16-bit signless integer | 16 | +| `pto.si16` | 16-bit signed integer | 16 | +| `pto.ui16` | 16-bit unsigned integer | 16 | +| `pto.i32` | 32-bit signless integer | 32 | +| `pto.si32` | 32-bit signed integer | 32 | +| `pto.ui32` | 32-bit unsigned integer | 32 | +| `pto.i64` | 64-bit signless integer | 64 | +| `pto.si64` | 64-bit signed integer | 64 | +| `pto.ui64` | 64-bit unsigned integer | 64 | +| `pto.f16` | Half precision float | 16 | +| `pto.bf16` | Brain float 16 | 16 | +| `pto.f32` | Single precision float | 32 | + +Python literals are automatically typed: +- `bool` → `pto.i1` +- `int` → Context-dependent (typically `pto.i32` or `pto.i64`) +- `float` → `pto.f32` + +For explicit typing, use type constructors: +```python +x = pto.i32(1024) # Explicit i32 constant +y: pto.i32 = 1024 # Type annotation +z = pto.ui16(7) # Explicit unsigned 16-bit constant +``` + +Static dtype bindings can also be called like constructors. This is useful when +the dtype comes from compile-time metadata such as `element_type`: + +```python +idx_dtype = tile.element_type +zero_idx = idx_dtype(0) +v_col = idx_dtype(col) +``` + +Integer sign semantics are part of the DSL type surface. `pto.si16`, +`pto.ui16`, and `pto.i16` are distinct scalar dtypes and lower to `si16`, +`ui16`, and `i16` respectively in VPTO IR. + +### Integer Literal Guidance + +For ordinary integer constants, prefer plain integer literals instead of +string forms. + +```python +count = pto.i32(1024) +delta = pto.i16(-12) +min_i32 = pto.i32(-2147483648) +unsigned_hi = pto.ui16(32768) +``` + +Integer string literals are reserved for explicit bit-pattern authoring. They +must use hex form. + +```python +# Use hex strings only when you intentionally want fixed-width bit-pattern +# interpretation at the target dtype width. +hi_bit = pto.i32("0x80000000") # -2147483648 +all_ones = pto.i16("0xFFFF") # -1 +unsigned_hi = pto.ui16("0x8000") # 32768 +``` + +Rules: +- Prefer plain integer literals such as `pto.i32(1024)` or `pto.i16(-12)` for normal integer authoring. +- Integer string literals must use hex bit-pattern form such as `"0xFFFF"`. +- Ordinary integer strings such as `"1024"` or `"-12"` are rejected; write them as integer literals instead. +- For signed and signless integer dtypes (`pto.i*`, `pto.si*`), hex strings use two's-complement interpretation at the target dtype width. +- For unsigned integer dtypes (`pto.ui*`), hex strings keep their unsigned value. +- Hex strings must fit within the target bit width. For example, `pto.i16("0x10000")` is rejected because the literal exceeds 16 bits. + +### Floating-Point Literal Forms + +`pto.f16(...)`, `pto.bf16(...)`, and `pto.f32(...)` accept multiple literal forms. + +```python +# Signed numeric literals +a = pto.f16(-1.5) +b = pto.bf16(+2.5) +c = pto.f32(-3.5) + +# Special floating-point values +pos_inf = pto.f32("inf") +neg_inf = pto.f32("-inf") +qnan = pto.f32("nan") + +# Bit-pattern form (hex string, interpreted by target dtype) +f16_neg_inf = pto.f16("0xFC00") +bf16_neg_inf = pto.bf16("0xFF80") +f32_neg_inf = pto.f32("0xFF800000") +``` + +Notes: +- Prefer dtype constructors for reduction seeds and boundary values (for example rowmax initialization). +- For float bit-pattern constants, pass a **string** hex literal to the matching dtype constructor. +- Avoid passing raw integer bit-patterns directly into vector broadcast/dup APIs when a floating vector is expected. +- `float(...)` function calls are not part of the TileLang DSL public call surface; use constructor forms above. + +### Vector Register Type + +Vector registers have fixed 256-byte width: + +```python +v_f32 = pto.vreg(pto.f32) # !pto.vreg<64xf32> +v_f16 = pto.vreg(pto.f16) # !pto.vreg<128xf16> +v_i8 = pto.vreg(pto.i8) # !pto.vreg<256xi8> +``` + +`pto.vreg(dtype)` only takes the element type. The frontend infers the element count automatically from the fixed 256-byte register width: + +- `pto.f32` → `!pto.vreg<64xf32>` +- `pto.f16` → `!pto.vreg<128xf16>` +- `pto.bf16` → `!pto.vreg<128xbf16>` +- `pto.i32` → `!pto.vreg<64xi32>` +- `pto.si32` → `!pto.vreg<64xsi32>` +- `pto.ui32` → `!pto.vreg<64xui32>` +- `pto.i16` → `!pto.vreg<128xi16>` +- `pto.si16` → `!pto.vreg<128xsi16>` +- `pto.ui16` → `!pto.vreg<128xui16>` +- `pto.i8` → `!pto.vreg<256xi8>` +- `pto.si8` → `!pto.vreg<256xsi8>` +- `pto.ui8` → `!pto.vreg<256xui8>` + +Constraint: `element_count × bitwidth(element_type) = 2048` + +Use `pto.elements_per_vreg(dtype)` when you need the inferred element count explicitly: + +```python +v_dtype = pto.vreg(pto.f32) +lanes0 = v_dtype.elements_per_vreg # 64 +lanes1 = pto.elements_per_vreg(pto.f32) # 64 +``` + +Current TileLang DSL v1 vector lowering supports the 8/16/32-bit integer +families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32` element types. + +### Builtin Vector Type + +TileLang DSL v1 also exposes builtin MLIR vector types through +`pto.vector(element_dtype, shape)`. + +```python +executed_ty = pto.vector(pto.i16, (4,)) # vector<4xi16> +``` + +This type is different from `pto.vreg(...)`: + +- `pto.vreg(dtype)` models a VPTO vector register with fixed 256-byte width. +- `pto.vector(dtype, shape)` models a builtin MLIR `vector<...>` type with an + explicit static shape. + +Use `pto.vector(...)` when a kernel parameter or intermediate value must match +an existing builtin vector operand in PTO IR, for example an auxiliary +`vector<4xi16>` operand carried by a tile op template. + +```python +@pto.vkernel( + target="a5", + op="pto.tmrgsort ins(src0, src1, tmp) -> outs(dst, ex_vec)", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.i16)], +) +def template( + src0: pto.Tile, + src1: pto.Tile, + tmp: pto.Tile, + dst: pto.Tile, + ex_vec: pto.vector(pto.i16, (4,)), +): + return None +``` + +Notes: + +- `shape` must be a Python tuple of integers. For a 1-D vector, write `(4,)`, + not `(4)`. The trailing comma is Python's single-element tuple syntax. +- The current public surface is intended for static builtin vector types. +- In descriptor `dtypes=[...]`, builtin vector operands are matched by their + element dtype (`pto.i16` in the example above). The vector shape contract is + carried by the parameter annotation `pto.vector(...)`. + +### Vector Type Reinterpretation (vbitcast) + +Vector registers support bitwise type reinterpretation via `pto.vbitcast`: + +```python +result = pto.vbitcast(vector, to_type) +``` + +Interface summary: +- `vector`: a vector register value of type `!pto.vreg` +- `to_type`: target element dtype such as `pto.i32`, `pto.ui32`, `pto.f16`, `pto.bf16`, `pto.f32` +- return: a new vector register `!pto.vreg` whose element count is inferred from the fixed 256-byte vreg width + +Constraints: +- `vector` must be a vreg value; scalar values, pointers, `Tile`, and `TensorView` are rejected +- `to_type` must be a DSL-supported vreg element dtype +- `vbitcast` preserves the total register storage size, so only reinterpretations with the same total bit count are allowed +- the operation has no mask, rounding, saturation, or lane-placement parameters + +Lane count is recomputed from `to_type`: +- `!pto.vreg<64xf32> + pto.i32 -> !pto.vreg<64xi32>` +- `!pto.vreg<64xf32> + pto.f16 -> !pto.vreg<128xf16>` +- `!pto.vreg<128xbf16> + pto.ui16 -> !pto.vreg<128xui16>` + +```python +# Float to integer bitwise reinterpretation +fvec = pto.vlds(ub_ptr, lane) # !pto.vreg<64xf32> +ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> + +# Signed to unsigned integer reinterpretation +signed_vec = pto.vlds(ptr, lane) # !pto.vreg<64xsi32> +unsigned_vec = pto.vbitcast(signed_vec, pto.ui32) # !pto.vreg<64xui32> + +# Element size change (32-bit to 16-bit) +f32_vec = pto.vlds(ptr, lane) # !pto.vreg<64xf32> +f16_vec = pto.vbitcast(f32_vec, pto.f16) # !pto.vreg<128xf16> +``` + +Pythonic syntax sugar via `astype()` method: + +```python +ivec = fvec.astype(pto.i32) # Float to integer +unsigned_vec = signed_vec.astype(pto.ui32) # Signed to unsigned +f16_vec = f32_vec.astype(pto.f16) # 32-bit to 16-bit +``` + +`astype()` on a vector register is syntax sugar for `pto.vbitcast(...)`. In other words, it is a bit reinterpretation API, not a numeric conversion API. + +**Note**: `vbitcast` preserves the exact bit pattern (type punning), unlike `vcvt` which performs value conversion with rounding/saturation. Use `vcvt` when you want numeric conversion semantics; use `vbitcast` when you want the bits to stay unchanged. + +### Typed Masks + +Masks are typed by their bit granularity: + +| DSL Type | VPTO Type | Description | +|----------|-----------|-------------| +| `pto.mask_b8` | `!pto.mask` | 8-bit granularity mask | +| `pto.mask_b16` | `!pto.mask` | 16-bit granularity mask | +| `pto.mask_b32` | `!pto.mask` | 32-bit granularity mask | + +```python +mask_ty = pto.mask_b32 +mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) +``` + +Typed masks also support explicit type reinterpretation via `pto.pbitcast`: + +```python +mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) +mask_b32 = pto.pbitcast(mask_b16, pto.mask_b32) +``` + +`pto.pbitcast(...)` is the predicate analogue of `pto.vbitcast(...)`: +- it changes the static mask granularity seen by later DSL/VPTO consumers +- it preserves the underlying predicate bit image +- it does not perform pack/unpack or interleave/deinterleave by itself + +Mask operations must match the vector element family: +- `f32`, `i32`, `si32`, and `ui32` vectors use `mask_b32` +- `f16`, `bf16`, `i16`, `si16`, and `ui16` vectors use `mask_b16` +- `i8`, `si8`, and `ui8` vectors use `mask_b8` + +```python +# Correct: f32 vector with b32 mask +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_f32 = pto.vlds(ptr, offset) +out = pto.vabs(vec_f32, mask32) + +# Error: mismatched mask granularity +mask16 = pto.make_mask(pto.f16, PAT.ALL) +out = pto.vabs(vec_f32, mask16) # Type error! +``` + +### Pointer Types [Advanced Tier] + +Pointers combine element type and memory space: + +```python +from pto import MemorySpace + +ptr_gm = pto.ptr(pto.f32, MemorySpace.GM) # GM pointer to f32 +ptr_ub = pto.ptr(pto.f16, MemorySpace.UB) # UB pointer to f16 +``` + +The `MemorySpace` enum provides type-safe memory space specification: + +| Enum Value | Description | +|------------|-------------| +| `MemorySpace.GM` | Global Memory (off-chip HBM/DDR) | +| `MemorySpace.MAT` | Cube L1 / cbuf staging buffer | +| `MemorySpace.LEFT` | Cube L0A left-operand buffer | +| `MemorySpace.RIGHT` | Cube L0B right-operand buffer | +| `MemorySpace.ACC` | Cube L0C accumulator buffer | +| `MemorySpace.BIAS` | Cube bias table buffer | +| `MemorySpace.UB` | Unified Buffer (on-chip SRAM, 256KB) | + +This replaces ad-hoc string literals with compile-time checked enums and is +shared by both the Vector and Cube DSL surfaces. + +### Public Buffer Types + +TileLang uses three public buffer-facing type names in kernel signatures: + +| Public Type | Description | +|-------------|-------------| +| `pto.TensorView` | GM-facing tensor view descriptor used for DMA-oriented data access | +| `pto.PartitionTensorView` | Logical GM partition (slice) descriptor, corresponding to `!pto.partition_tensor_view<...>` | +| `pto.Tile` | Tile buffer value for hardware-resident staged compute/storage buffers | + +### TensorView Types + +TensorView types represent multi-dimensional (up to 5D) views into tensors residing in Global Memory (GM). They are used as kernel parameters for describing GM data and support slicing operations to create logical partitions for DMA load/store operations. + +#### TensorView Type Definition + +TensorView types are parameterized by shape (a tuple of up to 5 dimensions) and element type: + +```python +# Kernel parameter using TensorView +@pto.vkernel(target="a5", op="custom", dtypes=[(pto.AnyFloat, pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tensor: pto.TensorView, # GM tensor view + output_tensor: pto.TensorView, # GM tensor view + tile_buf: pto.Tile # UB tile +): + # Access tensor view properties + shape = input_tensor.shape # tuple of dimensions (dynamic or static, up to 5D) + dtype = input_tensor.element_type # e.g., pto.f32 + strides = input_tensor.strides # stride in elements +``` + +Important notes: +- TensorView is a read-only descriptor for GM data, though DMA store operations can write through it. +- Shape can be static (compile-time constants) or dynamic (determined at runtime). +- Strides are expressed in elements, not bytes. +- Memory space is always GM (Global Memory). +- Maximum rank is 5. PTO ISA right-aligns lower-rank shapes to 5D. +- When higher dimensions are 1, a 5D TensorView can be abbreviated to lower-rank forms. For example, shape `(1, 1, 64, 32, 16)` can be written as `(64, 32, 16)`, and shape `(1, 1, 1, 32, 16)` can be written as `(32, 16)`. + +#### TensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Tensor dimensions (supports up to 5 dimensions, right-aligned to 5D in PTO ISA) | +| `element_type` | `Type` | Element data type (for example `pto.f32`, `pto.f16`) | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from base pointer (internal) | + +#### Padding Mode Enum + +Padding mode controls how out-of-bounds accesses are handled during DMA load/store operations: + +| Enum Value | Description | +|------------|-------------| +| `PadMode.PadNull` | No padding. Out-of-bounds access is invalid | +| `PadMode.PadFirstElem` | Pad using the first element of the source | +| `PadMode.PadValue` | Pad using a specified value and requires `pad_value` | + +#### Slicing Syntax + +TensorView supports Python slicing syntax to create logical partitions: + +```python +# Create a partition from a tensor view +partition = tensor_view[dim0_start:dim0_end, dim1_start:dim1_end] + +# Example: extract a 16x16 tile from a larger tensor +tile_view = large_tensor[0:16, 0:16] + +# Dynamic offsets and sizes +dim0_start = tensor_view.shape[0] // 2 +dynamic_partition = tensor_view[dim0_start:tensor_view.shape[0], 4:20] + +# Static positive step on dimension 0 +stepped_partition = tensor_view[0:32:2, 0:16] + +# Right-aligned shorthand on a 5D descriptor +partition_3d = tensor_view[d2_start:d2_end, d3_start:d3_end, d4_start:d4_end] + +# Full 5D spelling remains available when needed +partition_5d = tensor_view[ + d0_start:d0_end, + d1_start:d1_end, + d2_start:d2_end, + d3_start:d3_end, + d4_start:d4_end, +] +``` + +Constraints: +- Slicing returns a new `pto.PartitionTensorView` representing the logical partition. +- The partition must be within the original tensor bounds. +- When fewer than 5 slice axes are written, they are right-aligned to the trailing physical axes of the 5D descriptor. +- `stop` must be explicit on all dimensions. +- `start` may be static or dynamic. +- `step` must be a static positive integer. +- Dimension 0 may use `step > 1`. +- Dimension 1 must keep `step == 1` in the current DMA-oriented implementation. + +### PartitionTensorView Types + +`pto.PartitionTensorView` models a logical partition of GM tensor data and maps to +`!pto.partition_tensor_view` in PTO IR. +Like `TensorView`, it is a descriptor type and does not own storage. + +#### PartitionTensorView Type Definition + +```python +@pto.vkernel(target="a5", op="custom_partition", dtypes=[(pto.f32, pto.f32)]) +def kernel(inp: pto.TensorView, out: pto.TensorView): + part: pto.PartitionTensorView = inp[0:16, 0:16] + p_rows, p_cols = part.shape + s_row, s_col = part.strides + return None +``` + +Important notes: +- A `PartitionTensorView` carries partition `shape` and `strides` metadata in element units. +- Element dtype is inherited from the source tensor view. +- Memory space remains GM. +- Rank handling follows the same right-aligned 5D contract as `TensorView`. +- `PartitionTensorView` can be used where DMA-oriented TensorView-like descriptors are accepted. +- Prefer direct indexing or tuple unpacking for `shape`/`strides` metadata values in current DSL v1 lowering. + +#### PartitionTensorView Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Partition dimensions | +| `element_type` | `Type` | Element data type inherited from source tensor view | +| `strides` | `tuple[int, ...]` | Stride in elements for each dimension | +| `offset` | `pto.i64` | Byte offset from the base tensor pointer (internal) | + +### Tile Types + +Tile types represent data blocks in memory with layout and configuration information, corresponding to `!pto.tile_buf` in the VPTO IR. Tiles are commonly used as kernel parameters for tiled computations. + +#### Tile Type Definition + +`pto.Tile` is the public tile type used for hardware buffer allocation in specific +address spaces. Tiles are constructed directly via the `pto.Tile` constructor: + +```python +pto.Tile( + shape: tuple[int, ...], # Buffer shape (required) + dtype: Type, # Element type (required) + memory_space: MemorySpace, # Address space (required) + valid_shape: tuple[int, ...] | None = None, # Valid region, defaults to shape + blayout: BLayout | None = None, # B layout, auto-detected from address space + slayout: SLayout | None = None, # S layout, auto-detected from address space + fractal_size: int | None = None, # Fractal size, auto-detected from address space + pad_value: PadValue = PadValue.Null, # Pad policy + compact_mode: CompactMode = CompactMode.Null, # Compact mode + addr: int | None = None, # Pre-assigned address (level3 only) +) -> Tile +``` + +Layout defaults are selected automatically based on the address space: + +| Address Space | blayout default | slayout default | fractal_size default | +|--------------|----------------|----------------|---------------------| +| `MAT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | +| `LEFT` | `ColMajor` | `RowMajor` | `TileConfig.fractalABSize` (512) | +| `RIGHT` | `RowMajor` | `ColMajor` | `TileConfig.fractalABSize` (512) | +| `ACC` | `ColMajor` | `RowMajor` | `TileConfig.fractalCSize` (1024) | +| `BIAS` | `RowMajor` | `NoneBox` | `TileConfig.fractalABSize` (512) | +| `UB` / `VEC` | `RowMajor` | `NoneBox` | `TileConfig.fractalABSize` (512) | + +Related enum types: + +| Enum | Values | +|------|--------| +| `BLayout` | `ColMajor` (0), `RowMajor` (1) | +| `SLayout` | `NoneBox` (0), `RowMajor` (1), `ColMajor` (2) | +| `PadValue` | `Null` (0), `Zero` (1), `Max` (2), `Min` (3) | +| `CompactMode` | `Null` (0), `Normal` (1), `RowPlusOne` (2) | + +Usage: + +```python +# Allocate tiles in @vkernel or @ckernel +tile_ub = pto.Tile([256, 128], pto.f32, MemorySpace.UB) +tile_left = pto.Tile([16, 64], pto.f16, MemorySpace.LEFT) +tile_acc = pto.Tile([16, 16], pto.f32, MemorySpace.ACC, valid_shape=(12, 12)) +``` + +Important notes on shape and valid shape: +- `shape` must be a compile-time constant. Tile dimensions are fixed at compilation time and cannot change at runtime. +- `valid_shape` can be either static or dynamic and must be less than or equal to `shape` in each dimension. +- When `valid_shape` is not specified, it defaults to the full `shape`. + +#### Tile Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Full tile dimensions. These are compile-time constants | +| `element_type` | `Type` | Element data type (for example `pto.f32`) | +| `memory_space` | `MemorySpace` | Memory space such as UB, MAT, LEFT, RIGHT, ACC, or BIAS | +| `valid_shape` | `tuple[int, ...]` | Actual data dimensions within the tile. Must be less than or equal to `shape` in each dimension | +| `config` | `TileConfig` | Layout and padding configuration | + +#### Tile Pad Values + +`TileConfig.pad_value` is modeled after the C++ `PadValue : uint64_t` design. + +Standard pad values use small integer encodings: + +| DSL Value | Encoded Value | Meaning | +|-----------|---------------|---------| +| `pto.PadValue.NULL` | `0` | No concrete fill value | +| `pto.PadValue.ZERO` | `1` | Zero fill | +| `pto.PadValue.MAX` | `2` | Maximum finite / integer max for the tile element dtype | +| `pto.PadValue.MIN` | `3` | Minimum finite / integer min for the tile element dtype | + +Custom pad values use the `CustomBase = 0x100000000` convention and are authored with `pto.PadValue.custom_f32(...)`: + +```python +pad0 = pto.PadValue.ZERO +pad1 = pto.PadValue.custom_f32(-1.0) +pad2 = pto.PadValue.custom_f32("0xBF800000") # float32 bit pattern for -1.0f +``` + +Notes: +- `PadValue.encoded` exposes the host-side uint64 payload. `PadValue.value` is intentionally unavailable to avoid confusion with `.eval(...)` scalar materialization. +- `PadValue.text` exposes the standard textual spelling for built-ins such as `null` and `zero`. +- Custom pad values currently model an `f32` payload. In DSL v1, materializing a custom pad into a scalar is only supported for floating tile element dtypes. +- `PadValue.NULL` does not denote a usable scalar fill constant. Calling `tile.pad_value.eval()` or `tile.config.pad_value.eval()` when the enum is `NULL` is a frontend error. +- **DMA padding**: When performing GM→UB DMA transfers with padding enabled (via `enable_ub_pad=True` in `pto.copy_gm_to_ubuf`), the pad value must be configured explicitly using `pto.set_mov_pad_val`. Tile `PadValue` descriptors are not automatically translated to hardware register configurations in TileLang DSL v1. See [Pad Fill Semantics](08-sync-dma-operations.md#pad-fill-semantics) for usage details. + +Host-side code can materialize a scalar with an explicit dtype: + +```python +pad_max_f32 = pto.PadValue.MAX.eval(pto.f32) +pad_min_i16 = pto.PadValue.MIN.eval(pto.i16) +``` + +#### Tile Shape Concepts + +- `shape` is the static physical allocation size of the tile buffer. +- `valid_shape` is the logical data region and may be static or dynamic. +- `valid_shape[i] <= shape[i]` must hold for each dimension. +- Fixed-size tiles with smaller valid regions are useful for padding and partial-tile cases. + +#### Basic Access Operations + +```python +# Get tile properties +shape = tile.shape # (256, 128) +elem_type = tile.element_type # pto.f32 +mem_space = tile.memory_space # MemorySpace.UB +valid_shape = tile.valid_shape # (240, 120) or same as shape + +# Get configuration properties +config = tile.config +b_layout = config.b_layout # pto.BLayout.ROW_MAJOR +s_layout = config.s_layout # pto.SLayout.NONE_BOX +s_fractal = config.s_fractal_size # pto.i32(512) +pad_desc = tile.config.pad_value # PadValue enum bound to the tile element dtype +pad_desc2 = tile.pad_value # direct sugar for the same PadValue enum + +# Dynamic properties +rank = tile.rank # 2 +``` + +`tile.config.pad_value` and `tile.pad_value` are enum-typed inside kernel code. Use `.eval()` to materialize the configured pad descriptor against the tile element dtype: + +- `tile.pad_value.eval()` with `PadValue.ZERO` becomes `0` / `0.0` +- `tile.pad_value.eval()` with `PadValue.MAX` becomes dtype-aware max +- `tile.pad_value.eval()` with `PadValue.MIN` becomes dtype-aware min +- `tile.pad_value.eval()` with `PadValue.custom_f32(...)` becomes the authored floating scalar +- `tile.pad_value.eval()` with `PadValue.NULL` raises a frontend error + +For dtype-dependent fill seeds, prefer `tile.pad_value.eval()` over handwritten +`if dtype == ...` ladders. + +For standalone `PadValue` symbols that are not bound to a tile, pass the target dtype explicitly: + +```python +pad_scalar = pto.PadValue.MAX.eval(pto.f32) +``` + +```python +@pto.vkernel(op="fill_pad_value", dtypes=[(pto.AnyType,)]) +def fill_pad_value(dst: pto.Tile): + pad_scalar = dst.pad_value.eval() + pad_vec = pto.vbr(pad_scalar) + # ... +``` + +Typical materialized values: + +- `PadValue.ZERO` -> `0` / `0.0` +- `PadValue.MAX` -> dtype-aware max, for example `4294967295` for `pto.ui32` +- `PadValue.MIN` -> dtype-aware min, for example `-2147483648` for `pto.i32` and `0` for `pto.ui32` + +This is usually simpler than spelling every dtype case manually with +`pto.constexpr(dst.element_type == ...)`. + +Example: reading pad value from a `Tile` + +```python +@pto.vkernel(op="fill_pad_demo", dtypes=[(pto.f16,)]) +def kernel(dst: pto.Tile): + mask, _ = pto.make_mask(pto.f16, 8) + + # Read the Tile-bound PadValue enum. + pad0 = dst.pad_value + + # Equivalent form through TileConfig metadata. + pad1 = dst.config.pad_value + + if pto.constexpr(pad0 != pto.PadValue.NULL): + scalar0 = pad0.eval() + scalar1 = pad1.eval() + vec0 = pto.vdup(scalar0, mask) + vec1 = pto.vdup(scalar1, mask) + pto.vsts(vec0, dst[0, 0:], mask) + pto.vsts(vec1, dst[1, 0:], mask) +``` + +If `dst` is specialized with `config=pto.TileConfig.from_mapping({"pad_value": pto.PadValue.ZERO})`, +both `pad0` and `pad1` are `PadValue.ZERO`, and `pad0.eval()` / `pad1.eval()` materialize to the scalar `0.0` for an `f16` tile. + +#### Conversion Operations + +Basic mode syntax uses tile element-indexing directly in vector operations: + +```python +# 2D tile indexing +vec = pto.vlds(tile[row, col:]) +pto.vsts(vec, tile[row, col:], mask) + +# 1D tile indexing +vec = pto.vlds(tile[start:]) +pto.vsts(vec, tile[start:], mask) +``` + +Advanced mode syntax converts tiles to typed pointers for byte-offset operations: + +```python +# Convert tile to pointer +ptr = tile.as_ptr() # Returns pto.ptr(pto.f32, MemorySpace.UB) + +# Use pointer with byte offset +vec = pto.vlds(ptr, offset) +pto.vsts(vec, ptr, offset, mask) +``` + +#### Kernel Parameter Usage + +```python +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def tiled_kernel( + input_tile: pto.Tile, + output_tile: pto.Tile, + scale: pto.f32 +): + all_mask = pto.make_mask(pto.f32, PAT.ALL) + for i in range(0, 256, 64): + vec = pto.vlds(input_tile[i, 0:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, output_tile[i, 0:], all_mask) +``` + +### Alignment Type + +The `pto.align` type is used for alignment carrier operations and maps to `!pto.align`. diff --git a/tilelang-dsl/docs/user_guide/06-control-flow.md b/tilelang-dsl/docs/user_guide/06-control-flow.md new file mode 100644 index 000000000..41b623d1d --- /dev/null +++ b/tilelang-dsl/docs/user_guide/06-control-flow.md @@ -0,0 +1,181 @@ +## Control Flow + +### Vector Scopes + +The TileLang DSL supports implicit vector scope inference, allowing developers to write vector operations directly without explicit `pto.vecscope()` blocks. The compiler automatically groups consecutive, data-dependent vector operations into implicit vector scopes during lowering. + +#### Implicit Scope Inference + +**Note:** `pto.vecscope()` is supported. Automatic scope inference runs only when the kernel does **not** contain explicit `with pto.vecscope():` blocks. + +When you write vector operations like `pto.vlds`, `pto.vadd`, `pto.vsts` directly in your code, the compiler's **Scope Inference Pass** analyzes the control flow graph and automatically creates vector scopes: + +```python +# No explicit vecscope needed - compiler infers scope boundaries +vec = pto.vlds(outer_ptr, offset) +result = pto.vadd(vec, vec, all_mask) +pto.vsts(result, dst_ptr, offset, all_mask) +``` + +The compiler automatically groups these three operations into a single implicit vector scope because they form a data-dependent chain (when no explicit `pto.vecscope()` appears in the kernel). + +**Scope boundary rules:** +1. **Control flow boundaries**: Branches (`if`/`else`), loops (`for`/`while`), and function calls create implicit scope boundaries +2. **Scalar operations**: Non-vector operations (e.g., scalar arithmetic, pointer arithmetic) create boundaries +3. **Explicit scope blocks**: User-defined `vecscope` and `strict_vecscope` blocks create hard boundaries + +#### Explicit Scope Boundaries with `strict_vecscope` [Advanced Tier] + +##### `pto.strict_vecscope(*captures: AnyType) -> ContextManager[Tuple[AnyType, ...]]` + +**Description**: Creates an explicit vector scope boundary with explicit value captures. Values used inside the scope must be passed as arguments; implicit capture from outer scope is rejected. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `*captures` | `AnyType` | Variable number of values to be captured and passed into the scope | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `context_manager` | `ContextManager[Tuple[AnyType, ...]]` | Context manager that yields a tuple of captured values when entered | + +**Constraints**: +- The scope body cannot implicitly capture values from the surrounding scope; all used values must be passed as `captures`. +- Creates a hard boundary that prevents the compiler from merging vector operations across the scope boundary. +- Useful for performance optimization, debugging, resource management, and hardware compatibility. + +For precise control over scope boundaries, use explicit `strict_vecscope` blocks. These create hard boundaries that prevent the compiler from merging operations across the block boundary: + +```python +with pto.strict_vecscope(src_ptr, dst_ptr, start, end) as (s, d, lb, ub): + # Operations inside this block are isolated from outside + # Compiler will not merge operations across this boundary + for i in range(lb, ub, 64): + vec = pto.vlds(s, i) + pto.vsts(vec, d, i, all_mask) +``` + +**Use cases for strict_vecscope:** +- Performance optimization: Isolate critical vector computation regions +- Debugging: Create explicit boundaries to isolate vector operations +- Resource management: Control vector register allocation boundaries +- Compatibility: Ensure deterministic scope placement for hardware constraints + +#### Explicit Scope Blocks with `vecscope` + +`pto.vecscope` provides an explicit vector-scope boundary without strict capture ABI constraints: + +```python +with pto.vecscope(): + vec = pto.vlds(src, 0) + vec = pto.vadd(vec, vec, mask) + pto.vsts(vec, dst, 0, mask) +``` + +**Rules**: +- `pto.vecscope()` takes no positional/keyword arguments. +- `pto.vecscope()` does not support `as (...)` bindings. +- When any explicit `pto.vecscope()` is present in a kernel body, automatic vecscope inference is disabled for that kernel. + +### Inline Procedures (`@pto.inline_proc`) + +TileLang DSL supports reusable top-level procedures decorated with `@pto.inline_proc`. +`inline_proc` follows function-call semantics in frontend IR and is force-inlined +later by the VPTO backend mainline in `ptoas`. + +```python +@pto.inline_proc +def store_row(dst: pto.Tile, src: pto.Tile, row: pto.i32): + vec = pto.vlds(src[row, 0:]) + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + pto.vsts(vec, dst[row, 0:], mask) + return None + +@pto.vkernel(op="pto.row_copy", dtypes=[(pto.f32, pto.f32, pto.i32)]) +def row_copy(dst: pto.Tile, src: pto.Tile, row: pto.i32): + store_row(dst, src, row) + return None +``` + +Important semantics: + +- `pto.(...)` and bare helper calls are different mechanisms. +- Calls written as `pto.vadd(...)`, `pto.vdiv(...)`, `pto.vlds(...)`, etc. target + built-in TileLang/VPTO surfaces directly. +- Calls written as bare Python names such as `store_row(...)` target a + user-defined `@pto.inline_proc` helper when the callee name resolves to a + registered top-level inline procedure in the current module. +- `inline_proc` helpers do not live in the `pto` namespace; using the same + basename as a `pto.` op is allowed because the frontend distinguishes + `pto.xxx(...)` from bare `xxx(...)` calls. +- Frontend preserves helper `func.func` and `func.call` in `mlir_text()` output. +- VPTO backend mainline force-inlines helper calls before downstream lowering. +- Helper definitions support default parameter values. +- Helper calls support positional arguments and keyword arguments. +- Helper calls can appear in statement and expression positions. +- Helper definitions can use trailing `return ` to return values. +- Implicit capture is rejected except module-level globals whose current bound value is `bool`/`int`/`float`/`str`; pass other required values as explicit arguments. +- Recursive/mutually-recursive helper call graphs are rejected. +- `*args`, `**kwargs`, and keyword-only parameters are unsupported in current version. + +Shared helpers can live in a separate Python file in the template directory and +be imported directly by templates: + +```python +# shared_rows.py +import tilelang_dsl as pto + +@pto.inline_proc +def touch_row(dst: pto.Tile, row: pto.i32): + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + vec = pto.vlds(dst[row, 0:]) + pto.vsts(vec, dst[row, 0:], mask) + return None + +# trow_template.py +import tilelang_dsl as pto +from shared_rows import touch_row + +@pto.vkernel(op="pto.row_touch", dtypes=[(pto.f32, pto.i32)]) +def row_touch(dst: pto.Tile, row: pto.i32): + touch_row(dst, row) + return None +``` + +Only directly imported `@pto.inline_proc` helpers are part of this shared-helper +surface. Ordinary Python functions remain unsupported in DSL bodies, and +qualified calls such as `shared_rows.touch_row(...)` are not part of this +version. If multiple imported helpers expose the same bare name, the frontend +rejects the template instead of choosing one by import order. + +### Loops + +Counted loops use Python's `range` syntax: + +```python +for i in range(lb, ub, step): + # Loop body + mask, rem = pto.make_mask(pto.f32, remaining) + # ... +``` + +Loop-carried state is automatically handled through variable updates within the loop. + +### Conditionals + +`if` statements support value merging: + +```python +flag: pto.i1 = some_condition +step: pto.i32 = 0 + +if flag: + step = pto.i32(64) +else: + step = pto.i32(128) + +# 'step' here is the merged result from both branches +``` + +Variables defined in only one branch are local to that branch. diff --git a/tilelang-dsl/docs/user_guide/07-frontend-operations.md b/tilelang-dsl/docs/user_guide/07-frontend-operations.md new file mode 100644 index 000000000..621a8c78f --- /dev/null +++ b/tilelang-dsl/docs/user_guide/07-frontend-operations.md @@ -0,0 +1,352 @@ + +### Frontend-only Authoring Operations + +Operations in this family affect descriptor construction and code generation +shape. They are consumed by the frontend and do not correspond to runtime VPTO +instructions by themselves. + +#### `pto.constexpr(value: bool) -> bool` + +**Description**: Compile-time conditional construct for kernel specialization. Marks a boolean expression for evaluation during descriptor materialization, enabling branch elimination based on static compile-time information. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `bool` | Boolean expression that must be evaluable at compile time. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `bool` | A frontend-only compile-time boolean used to guard `if` statements. | + +**Behavior**: +- Evaluated during kernel descriptor materialization before semantic analysis and lowering. +- When used in `if pto.constexpr(...):` statements, only the selected branch is retained; the other branch is discarded entirely. +- If the condition cannot be proven static, descriptor materialization fails with a frontend diagnostic. +- Does not generate runtime control flow or value merging logic. + +**Examples**: +```python +# Specialize based on element size +dtype = dst.element_type +elem_bytes = pto.bytewidth(dtype) + +if pto.constexpr(elem_bytes == 2): + # Specialized path for 16-bit types (f16/bf16) + ... +else: + # Fallback path for other types + ... +``` + +```python +# Specialize based on tile shape +rows, cols = dst.shape + +if pto.constexpr(rows == 1 and cols == 16): + # Fast path for specific tile configuration + ... +``` + +**Constraints**: +- `pto.constexpr` is a frontend-only authoring construct with no runtime representation. +- The condition must be statically evaluable from descriptor-time information (data types, tile shapes, literals, etc.). +- For kernel-level specialization, prefer `constraints=[...]` and `pto.select_kernel(...)`. +- See [Compile-time Specialization with `pto.constexpr`](04-template-kernels.md#compile-time-specialization-with-ptoconstexpr) for detailed usage guidelines. + +### Type Query Operations + +Operations for querying type properties. + +#### `pto.bytewidth(dtype: Type) -> pto.i32` + +**Description**: Returns the size in bytes of a single element of the given data type. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `size` | `pto.i32` | Element size in bytes | + +**Example**: +```python +f32_size = pto.bytewidth(pto.f32) # Returns 4 +f16_size = pto.bytewidth(pto.f16) # Returns 2 +i8_size = pto.bytewidth(pto.i8) # Returns 1 +ui64_size = pto.bytewidth(pto.ui64) # Returns 8 +``` + +**Common Use Case**: Calculate byte offsets for memory access: +```python +element_type = pto.f32 +byte_offset = index * pto.bytewidth(element_type) +``` + +#### `pto.elements_per_vreg(dtype: Type) -> pto.i32` + +**Description**: Returns the number of elements per vector register for a given element type, based on the hardware vector register size (256 bytes). This function computes `256 // bytewidth(dtype)`, which represents the maximum number of elements of the given type that can fit in a single vector register. Useful for determining vector width and loop stride calculations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type (e.g., `pto.f32`, `pto.f16`, `pto.i8`, `pto.si16`, `pto.ui32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `elems` | `pto.i32` | Number of elements per vector register for the given element type | + +**Example**: +```python +f32_elems_per_vreg = pto.elements_per_vreg(pto.f32) # Returns 64 (256 / 4) +f16_elems_per_vreg = pto.elements_per_vreg(pto.f16) # Returns 128 (256 / 2) +i8_elems_per_vreg = pto.elements_per_vreg(pto.i8) # Returns 256 (256 / 1) +si16_elems_per_vreg = pto.elements_per_vreg(pto.si16) # Returns 128 (256 / 2) +``` + +**Common Use Case**: Loop stride calculation for vector operations: +```python +dtype = pto.f32 +elems_per_vreg = pto.elements_per_vreg(dtype) # Returns 64 for f32 +for col in range(0, cols, elems_per_vreg): + # Load/store vectors of 'elems_per_vreg' elements + pass +``` + +**Relationship with `pto.bytewidth`**: +```python +# The relationship between bytewidth and elements per vector register: +elems = 256 // pto.bytewidth(dtype) +# This is equivalent to: +elems = pto.elements_per_vreg(dtype) +``` + +### Runtime Block Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar +code. They are pure scalar producers: + +- they do not move data +- they do not allocate buffers +- they do not by themselves create `vecscope` boundaries + +Their main purpose is workload partitioning. A common pattern is: + +1. query the current block or subblock id +2. compute a per-instance starting offset +3. use that offset to derive GM/UB pointers or TensorView slices +4. run the local tile or vector loop for that partition + +#### `pto.get_block_idx() -> pto.i64` + +**Description**: Returns the current block ID for the running kernel instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `block` | `pto.i64` | Current block index in the range `[0, pto.get_block_num())` | + +**Behavior**: +- The returned value is launch-instance-local and may differ across concurrently running blocks. +- The value is stable for the lifetime of one kernel instance. +- The op is scalar-only and can be used before pointer arithmetic, TensorView partitioning, DMA setup, or loop construction. + +#### `pto.get_subblock_idx() -> pto.i64` + +**Description**: Returns the current subblock ID visible to the running kernel instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `subblock` | `pto.i64` | Current subblock index in the range `[0, pto.get_subblock_num())` | + +**Behavior**: +- Used when one block is further subdivided by the launch/runtime model. +- Like `pto.get_block_idx()`, this is a pure scalar query with no side effects. + +#### `pto.get_block_num() -> pto.i64` + +**Description**: Returns the total number of blocks visible to the current kernel launch. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `block_num` | `pto.i64` | Total block count for the current launch domain | + +**Behavior**: +- Typically paired with `pto.get_block_idx()` to compute per-block ranges. +- The result is a runtime value and should not be assumed to be a compile-time constant. + +#### `pto.get_subblock_num() -> pto.i64` + +**Description**: Returns the total number of subblocks visible to the current execution instance. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `subblock_num` | `pto.i64` | Total subblock count in the current runtime execution domain | + +**Behavior**: +- Typically paired with `pto.get_subblock_idx()` for finer-grained partitioning inside one block. + +**Example**: +```python +block = pto.get_block_idx() +block_num = pto.get_block_num() +subblock = pto.get_subblock_idx() +subblock_num = pto.get_subblock_num() +``` + +**Typical Use Case**: Compute a per-block base pointer. +```python +block = pto.get_block_idx() +block_len = 2048 +base_elem = block * block_len +block_src = pto.addptr(src_gm, base_elem) +block_dst = pto.addptr(dst_gm, base_elem) +``` + +**Constraints**: +- These ops return runtime scalar values, not compile-time specialization constants. +- They are intended for scalar address/control computation, not as vector operands. +- When mixing them with pointer arithmetic, remember that `pto.addptr(...)` uses element offsets, not byte offsets. + +### Scalar Pointer Helpers [Advanced Tier] + +These ops perform scalar element access on typed PTO pointers. Unlike +`pto.vlds(...)` / `pto.vsts(...)`, they operate on exactly one element and do +not create or consume vector registers or masks. + +They are useful when a kernel needs a small amount of scalar state next to +vector code, for example: + +- reading a scalar coefficient or loop-carried value from UB +- writing a scalar flag or reduction result +- patching a small header/metadata area without vector load-store semantics + +#### `pto.load_scalar(ptr: PtrType, offset: Index) -> ScalarType` +#### `pto.load_scalar(dtype: Type, ptr: PtrType, offset: Index) -> ScalarType` + +**Description**: Loads one scalar element from a typed PTO pointer at the given element offset. + +**Parameters (`load_scalar(ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed pointer created by `pto.ptr(...)`, `pto.castptr(...)`, `Tile.as_ptr()`, or `TensorView.as_ptr()` | +| `offset` | `Index` | Element displacement from `ptr` | + +**Parameters (`load_scalar(dtype, ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Optional explicit result dtype; must match the pointer element type | +| `ptr` | `PtrType` | Typed pointer source | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `value` | `ScalarType` | One scalar element loaded from `ptr[offset]` | + +**Behavior**: +- Access is element-based, not byte-based. +- The loaded value has the same scalar dtype as the pointer element type. +- This is a scalar memory helper; it does not participate in vector distribution families such as `dist`. +- It may target any memory space represented by the pointer type; the memory-space legality follows the pointer producer. + +#### `pto.store_scalar(ptr: PtrType, offset: Index, value: ScalarType) -> None` +#### `pto.store_scalar(value: ScalarType, ptr: PtrType, offset: Index) -> None` + +**Description**: Stores one scalar element to a typed PTO pointer at the given element offset. + +**Parameters (`store_scalar(ptr, offset, value)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | +| `value` | `ScalarType` | Scalar value to write | + +**Parameters (`store_scalar(value, ptr, offset)`)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar value to write | +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: None (side-effect operation) + +**Behavior**: +- Stores exactly one scalar element to `ptr[offset]`. +- Does not consume a predicate mask. +- Does not imply vector-store ordering semantics such as `dist` or unaligned store state. + +**Example**: +```python +value = pto.load_scalar(src_ptr, 0) +pto.store_scalar(dst_ptr, 0, value) +``` + +**Typical Use Case**: Read-modify-write scalar metadata next to vector code. +```python +flag = pto.load_scalar(status_ptr, 0) +# scalar compute on `flag` +pto.store_scalar(status_ptr, 0, flag) +``` + +**Constraints**: +- `ptr` must be a typed `pto.ptr(...)` value. +- `offset` is element-based and must be index-typed after frontend normalization. + Plain integer literals such as `0` are accepted and lowered as index constants. +- The scalar dtype must match the pointer element dtype. +- These ops are advanced pointer-surface operations; prefer Tile/TensorView authoring surfaces when scalar pointer manipulation is not required. + +### Pointer Construction [Advanced Tier] + +Operations for creating and manipulating typed pointers. + +#### `pto.castptr(offset: pto.i64, ptr_type: Type) -> PtrType` + +**Description**: Creates a typed pointer from an integer address, a memref-backed address value, or another typed pointer in the same memory space. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `offset` | `pto.i64` / address-like value | Integer address, memref-backed address value, or existing pointer | +| `ptr_type` | `Type` | Target pointer type (e.g., `pto.ptr(pto.f32, MemorySpace.GM)`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `ptr` | `PtrType` | Typed pointer value | + +**Example**: +```python +ub_ptr = pto.castptr(0, pto.ptr(pto.f32, MemorySpace.UB)) +``` + +`TensorView.as_ptr()` and `Tile.as_ptr()` remain the preferred high-level APIs. They lower directly to address-extraction intrinsics (`pto.tensor_view_addr` / `pto.tile_buf_addr`) with pointer result types, while tile slice / buffer-view authoring paths continue to materialize memref results from the same intrinsics. + +#### `pto.addptr(ptr: PtrType, offset: pto.i64) -> PtrType` + +**Description**: Adds an element offset to an existing pointer. The offset is counted in elements, not bytes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Source pointer | +| `offset` | `pto.i64` | Element offset to add (counted in elements, not bytes) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `new_ptr` | `PtrType` | Pointer with element offset applied | + +**Example**: +```python +# Advance pointer by 1024 f32 elements (not bytes) +next_ptr = pto.addptr(ub_ptr, 1024) +``` + diff --git a/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md new file mode 100644 index 000000000..883e5104a --- /dev/null +++ b/tilelang-dsl/docs/user_guide/08-sync-dma-operations.md @@ -0,0 +1,622 @@ +### Synchronization & Buffer Control + +Operations for pipeline synchronization and buffer management. + +#### Enum Types for Synchronization + +The following enum types provide type-safe parameter specification for synchronization operations: + +- **`BarrierType`**: Memory barrier types for `pto.mem_bar` + - `VV_ALL`, `VST_VLD`, `VLD_VST`, `VST_VST`: vector→vector barriers + - `VS_ALL`, `VST_LD`, `VLD_ST`, `VST_ST`: vector→scalar barriers + - `SV_ALL`, `ST_VLD`, `LD_VST`, `ST_VST`: scalar→vector barriers + +- **`Pipe`**: Hardware pipeline identifiers + - `MTE2`: Memory Transfer Engine 2 pipeline + - `V`: Vector pipeline + - `MTE3`: Memory Transfer Engine 3 pipeline + - `ALL`: All pipelines (for barrier operations) + +- **`Event`**: Event identifiers for synchronization + - `ID0`, `ID1`, `ID2`, `ID3`, ..., `ID31`: Event IDs 0-31 (A5 supports 32 event IDs, 0-15 for subblock 0, 16-31 for subblock 1) + +#### `pto.set_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Sets a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.set_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.wait_flag(pipe_from: PIPE, pipe_to: PIPE, event: EVENT) -> None` + +**Description**: Waits for a synchronization flag between hardware pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `PIPE` | Source pipeline (e.g., `PIPE.MTE2`) | +| `pipe_to` | `PIPE` | Destination pipeline (e.g., `PIPE.V`) | +| `event` | `EVENT` | Event identifier (e.g., `EVENT.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE, EVENT + +pto.wait_flag(PIPE.MTE2, PIPE.V, EVENT.ID0) +``` + +#### `pto.pipe_barrier(pipes: PIPE) -> None` + +**Description**: Executes a barrier across specified pipelines. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipes` | `PIPE` | Pipeline specification (e.g., `PIPE.ALL`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import PIPE + +pto.pipe_barrier(PIPE.ALL) +``` + +#### `pto.get_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` + +**Description**: Acquire buffer slot for inter-pipeline double-buffering coordination. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | +| `buf_id` | `pto.i64` | Buffer identifier | +| `mode` | `pto.i64` | Acquisition mode | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Pipe + +# Acquire buffer for MTE2 pipeline +pto.get_buf(Pipe.MTE2, 0, 0) +``` + +#### `pto.rls_buf(pipe: Pipe, buf_id: pto.i64, mode: pto.i64) -> None` + +**Description**: Release buffer slot to allow other pipeline to proceed. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier (e.g., `Pipe.MTE2`, `Pipe.V`, `Pipe.MTE3`) | +| `buf_id` | `pto.i64` | Buffer identifier | +| `mode` | `pto.i64` | Release mode | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Pipe + +# Release buffer for MTE2 pipeline +pto.rls_buf(Pipe.MTE2, 0, 0) +``` + +#### `pto.mem_bar(barrier_type: BarrierType) -> None` + +**Description**: Memory barrier for pipeline synchronization within vector scope. Required when UB addresses alias between vector load/store operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `barrier_type` | `BarrierType` | Barrier type controlling prior/subsequent instruction ordering. Supported values are `BarrierType.VV_ALL`, `BarrierType.VST_VLD`, `BarrierType.VLD_VST`, `BarrierType.VST_VST`, `BarrierType.VS_ALL`, `BarrierType.VST_LD`, `BarrierType.VLD_ST`, `BarrierType.VST_ST`, `BarrierType.SV_ALL`, `BarrierType.ST_VLD`, `BarrierType.LD_VST`, and `BarrierType.ST_VST`. | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import BarrierType + +# Ensure stores are visible before loads to same UB region +pto.mem_bar(BarrierType.VST_VLD) +``` + +#### `pto.set_cross_core(core_id: pto.i64, event_id: Event) -> None` + +**Description**: Signal event to another core (cross-core synchronization). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Target/source core identifier (platform-specific mapping) | +| `event_id` | `Event` | Cross-core event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Signal event ID0 to core 0 +pto.set_cross_core(0, Event.ID0) +``` + +#### `pto.set_intra_block(block_id: pto.i64, event_id: Event) -> None` + +**Description**: Signal event within a block (A5). Specifies trigger pipe. 1:1 per subblock. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block/pipeline identifier specifying trigger pipe | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Signal event ID0 on block/pipeline 0 +pto.set_intra_block(0, Event.ID0) +``` + +#### `pto.set_intra_core(config: pto.i32) -> None` + +**Description**: Configures intra-core synchronization settings. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `config` | `pto.i32` | Configuration value for intra-core synchronization | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_intra_core(3) +``` + +#### `pto.wait_flag_dev(core_id: pto.i64, event_id: Event) -> None` + +**Description**: Wait for event from another core. SU-level blocking — entire core stalls. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `core_id` | `pto.i64` | Core identifier | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Wait for event ID0 from core 0 +pto.wait_flag_dev(0, Event.ID0) +``` + +#### `pto.wait_intra_core(block_id: pto.i64, event_id: Event) -> None` + +**Description**: Wait for event within block (A5). Specifies which pipeline should wait — only that pipe stalls, SU and other pipes continue. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_id` | `pto.i64` | Block/pipeline identifier specifying which pipeline should wait | +| `event_id` | `Event` | Event identifier (e.g., `Event.ID0`) | + +**Returns**: None (side-effect operation) + +**Example**: +```python +from pto import Event + +# Wait for event ID0 on block/pipeline 0 +pto.wait_intra_core(0, Event.ID0) +``` + +### DMA Programming [Advanced Tier] + +This section covers Direct Memory Access (DMA) operations for transferring data between Global Memory (GM) and Unified Buffer (UB). DMA operations are performance-critical and require careful configuration of stride parameters and transfer sizes. + +**Key Concepts:** +- **DMA Configuration**: Set stride parameters and loop sizes using `set_loop*_stride_*` and `set_loop_size_*` operations. +- **DMA Execution**: Perform transfers using `copy_gm_to_ubuf`, `copy_ubuf_to_gm`, and `copy_ubuf_to_ubuf` operations. +- **GM→UB Padding**: Optionally fill out-of-bounds regions with a specified value when copying from GM to UB. See [Pad Fill Semantics](#pad-fill-semantics) for details. + +**Usage Flow:** +1. Configure DMA parameters (strides, loop sizes) +2. Execute the DMA transfer operation +3. Optionally enable padding for GM→UB transfers + +**Note**: All DMA operations in this section are part of the **Advanced Tier** and require explicit buffer management and pointer arithmetic. For basic tile-based authoring, refer to the [Basic Authoring Mode](01-introduction.md#basic-vs-advanced-authoring-modes) documentation. + +#### Manual Configuration Example + +```python +# DMA configuration example (requires careful parameter tuning) +pto.set_loop2_stride_outtoub(src_stride=32, dst_stride=128) # Outer loop strides +pto.set_loop1_stride_outtoub(src_stride=1, dst_stride=32) # Inner loop strides +pto.set_loop_size_outtoub(loop1=16, loop2=16) # Transfer size +pto.copy_gm_to_ubuf(src=gm_ptr, dst=ub_ptr, n_burst=16, len_burst=128, gm_stride=128, ub_stride=128) + +``` + +#### Pad Fill Semantics + +When copying data from Global Memory (GM) to Unified Buffer (UB), you can enable padding to fill out-of-bounds regions with a specified value. This is useful when the source data dimensions don't perfectly match the destination tile allocation, or when you need to handle boundary conditions in tiled computations. + +##### How Padding Works + +1. **Configure the hardware pad register**: Call `pto.set_mov_pad_val` to set the pad value in the hardware register. This must be done before any `pto.copy_gm_to_ubuf` operation with padding enabled. + +2. **Enable padding in the DMA operation**: Set `enable_ub_pad=True` in the `pto.copy_gm_to_ubuf` call to activate the padded transfer path. The pad value from the hardware register will be used for filling out-of-bounds regions. + +3. **Hardware mapping**: The `pto.set_mov_pad_val` operation corresponds directly to the low-level VPTO instruction that configures the hardware pad register. There is no automatic translation from tile `PadValue` descriptors—you must explicitly set the pad register before padded DMA transfers. + +##### Example Workflow + +Configure the hardware pad register using `pto.set_mov_pad_val`, then perform the DMA transfer with padding enabled: + +```python +# First, configure the hardware pad register with a scalar value +# For zero fill, use an appropriate scalar type based on your data +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data + +# Then perform the DMA transfer with padding enabled +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, # Enable padded transfer +) +``` + +##### Accessing Pad Values in Kernel Code + +Tile `PadValue` descriptors can be used within kernel code for computation purposes (e.g., initializing vectors with a specific fill value). However, note that **these descriptors are not automatically used for DMA padding**—you must still call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for GM→UB transfers. + +To access a pad value from a tile descriptor in kernel code: + +```python +# Get the pad descriptor from the destination tile +pad_desc = dst.pad_value + +# Check if a valid pad value is configured +if pto.constexpr(pad_desc != pto.PadValue.NULL): + # Materialize the scalar value + pad_scalar = pad_desc.eval() + + # Use the scalar value (e.g., for vector duplication) + mask = pto.make_mask(pto.f32, PAT.ALL) + pad_vector = pto.vdup(pad_scalar, mask) +``` + +##### Important Notes + +- The `PadValue.NULL` descriptor indicates no pad value is configured. Attempting to call `.eval()` on `PadValue.NULL` will raise a frontend error. +- Custom pad values currently support only 32-bit float payloads (`PadValue.custom_f32(...)`). +- Padding only affects GM→UB transfers (`pto.copy_gm_to_ubuf`). UB→GM and UB→UB transfers do not support padding. +- The padded region is determined by the difference between the tile's `valid_shape` and its full `shape`. Ensure your tile is configured with appropriate dimensions. +- Tile `PadValue` descriptors are not automatically used for DMA padding. You must call `pto.set_mov_pad_val` explicitly to configure the hardware pad register for padded GM→UB transfers. + +##### `pto.set_mov_pad_val` Operation [Advanced Tier] + +The `pto.set_mov_pad_val` operation configures the hardware pad register used for GM→UB transfers when padding is enabled. This operation must be called explicitly before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`, as the TileLang DSL v1 does not automatically translate tile `PadValue` descriptors to hardware register configurations. + +**Operation Signature**: +```python +pto.set_mov_pad_val(pad_value: ScalarType) -> None +``` + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pad_value` | `ScalarType` | Scalar value used for padding. Supported types: any 8/16/32-bit integer scalar (`pto.i8`, `pto.si8`, `pto.ui8`, `pto.i16`, `pto.si16`, `pto.ui16`, `pto.i32`, `pto.si32`, `pto.ui32`) plus `pto.f16`, `pto.bf16`, and `pto.f32`. The value's bit pattern is encoded into the hardware pad register. Integer inputs are automatically normalized to the corresponding signless hardware operand width during lowering, so no manual cast is required before calling `pto.set_mov_pad_val`. For standard pad values, use `PadValue.eval(...)` to obtain the appropriate scalar: `0` or `0.0` for `PadValue.ZERO`, dtype-aware maximum for `PadValue.MAX`, dtype-aware minimum for `PadValue.MIN`. | + +**Returns**: None (side-effect operation) + +**Example**: + +Using a scalar value directly: +```python +# Configure the hardware pad register for zero fill using an integer scalar +pto.set_mov_pad_val(pto.i32(0)) # Zero fill for integer types + +# Or using a float scalar for floating-point padding +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float types + +# Perform DMA transfer with padding enabled +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, +) +``` + +Using a tile's pad value descriptor: +```python +# Get the pad value from a tile configuration +pad_desc = tile.pad_value # PadValue enum +if pto.constexpr(pad_desc != pto.PadValue.NULL): + pad_scalar = pad_desc.eval() # Materializes to a scalar value + pto.set_mov_pad_val(pad_scalar) + + # Perform padded DMA transfer + pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, + ) +``` + +Using a standalone `PadValue` with an explicit dtype: +```python +pad_scalar = pto.PadValue.MAX.eval(pto.f32) +pto.set_mov_pad_val(pad_scalar) +``` + +For integer tile dtypes such as `pto.ui16` or `pto.si32`, `pad_desc.eval()` can be passed directly to `pto.set_mov_pad_val`. TileLang DSL v1 will automatically insert the required same-width bitcast to the signless hardware operand type during lowering. + +**Important**: You are responsible for ensuring the pad register is properly configured before any `pto.copy_gm_to_ubuf` operation with `enable_ub_pad=True`. The pad register configuration persists until changed by another `pto.set_mov_pad_val` call. + +**Future Improvement**: Future versions of TileLang DSL may provide an implicit approach that automatically translates `PadValue` descriptors from tile configurations to hardware register configurations, similar to DMA syntax sugar features. + +#### `pto.set_loop2_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_outtoub(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for GM → UB transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_outtoub(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for GM → UB transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop1` | `pto.i64` | Inner loop trip count | +| `loop2` | `pto.i64` | Outer loop trip count | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size_outtoub(loop1=1, loop2=1) +``` + +#### `pto.set_loop2_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop2). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop1_stride_ubtoout(src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for UB → GM transfers (loop1). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop_size_ubtoout(loop1: pto.i64, loop2: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for UB → GM transfers. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop1` | `pto.i64` | Inner loop trip count | +| `loop2` | `pto.i64` | Outer loop trip count | + +**Returns**: None (side-effect operation) + +#### `pto.set_loop(loop_id: pto.i32, src_stride: pto.i64, dst_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA stride parameters for a generic loop. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | +| `src_stride` | `pto.i64` | Source-side stride | +| `dst_stride` | `pto.i64` | Destination-side stride | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop(1, src_stride=32, dst_stride=64) +``` + +#### `pto.set_loop_size(loop_id: pto.i32, size: pto.i64) -> None` [Advanced Tier] + +**Description**: Configures DMA transfer size for a generic loop. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `loop_id` | `pto.i32` | Loop identifier (e.g., 1 for inner loop, 2 for outer loop) | +| `size` | `pto.i64` | Loop trip count | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.set_loop_size(1, 16) +``` + +#### DMA Execution Operations + +**Note**: These operations execute DMA transfers but require manual configuration of DMA parameters (loop strides, loop sizes) using the `set_loop*_stride_*` and `set_loop_size_*` operations described above. + +The following operations provide direct control over DMA transfers but require manual stride and size configuration. + +#### `pto.copy_gm_to_ubuf(src: GMPtr, dst: UBPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, left_padding_count: pto.i64 = 0, right_padding_count: pto.i64 = 0, enable_ub_pad: pto.i1 = False, l2_cache_ctl: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data from Global Memory (GM) to Unified Buffer (UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `GMPtr` | Source GM pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | +| `n_burst` | `pto.i64` | Number of bursts | +| `len_burst` | `pto.i64` | Bytes copied by each burst | +| `left_padding_count` | `pto.i64` | Left padding count, defaults to `0` | +| `right_padding_count` | `pto.i64` | Right padding count, defaults to `0` | +| `enable_ub_pad` | `pto.i1` | Convenience alias for `data_select_bit`, defaults to `False` | +| `l2_cache_ctl` | `pto.i64` | L2 cache control operand, defaults to `0` | +| `gm_stride` | `pto.i64` | GM-side stride in bytes | +| `ub_stride` | `pto.i64` | UB-side stride in bytes | + +**Returns**: None (side-effect operation) + +**Notes**: +- **Keyword arguments**: The keyword form shown above is the recommended public API surface. Use named arguments for clarity. +- **Padding control**: Set `enable_ub_pad=True` to enable padded GM→UB transfers. The pad value must be configured separately using `pto.set_mov_pad_val` before the DMA operation (see [Pad Fill Semantics](#pad-fill-semantics) for details). +- **Pad value source**: When padding is enabled, the fill scalar comes from the hardware pad register configured by `pto.set_mov_pad_val`. You must call this operation explicitly before the DMA transfer. +- **ABI compatibility**: The lowering preserves the underlying PTO operand order while providing a more ergonomic keyword interface. + +**Example**: +```python +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=128, + gm_stride=128, + ub_stride=128, + enable_ub_pad=False, +) +``` + +**Padding Example**: +```python +# First configure the hardware pad register with a scalar value +pto.set_mov_pad_val(pto.f32(0.0)) # Zero fill for float32 data + +# Then perform padded DMA transfer +pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=32, + len_burst=200, + gm_stride=200, + ub_stride=256, + enable_ub_pad=True, +) +``` + +#### `pto.copy_ubuf_to_ubuf(src: UBPtr, dst: UBPtr, src_offset: pto.i64, src_stride0: pto.i64, src_stride1: pto.i64, dst_offset: pto.i64, dst_stride0: pto.i64, dst_stride1: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data within Unified Buffer (UB → UB). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `UBPtr` | Destination UB pointer | +| `src_offset` | `pto.i64` | Source offset | +| `src_stride0` | `pto.i64` | Source stride dimension 0 | +| `src_stride1` | `pto.i64` | Source stride dimension 1 | +| `dst_offset` | `pto.i64` | Destination offset | +| `dst_stride0` | `pto.i64` | Destination stride dimension 0 | +| `dst_stride1` | `pto.i64` | Destination stride dimension 1 | + +**Returns**: None (side-effect operation) + +#### `pto.copy_ubuf_to_gm(src: UBPtr, dst: GMPtr, sid: pto.i64 = 0, n_burst: pto.i64, len_burst: pto.i64, reserved: pto.i64 = 0, gm_stride: pto.i64, ub_stride: pto.i64) -> None` [Advanced Tier] + +**Description**: Copies data from Unified Buffer (UB) to Global Memory (GM). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `UBPtr` | Source UB pointer | +| `dst` | `GMPtr` | Destination GM pointer | +| `sid` | `pto.i64` | DMA stream/control operand, defaults to `0` | +| `n_burst` | `pto.i64` | Number of bursts | +| `len_burst` | `pto.i64` | Bytes copied by each burst | +| `reserved` | `pto.i64` | Reserved operand, defaults to `0` | +| `gm_stride` | `pto.i64` | GM-side stride in bytes | +| `ub_stride` | `pto.i64` | UB-side stride in bytes | + +**Returns**: None (side-effect operation) + +**Notes**: +- In TileLang DSL, the keyword form above is the recommended public surface. +- `gm_stride`/`ub_stride` are ergonomic aliases for the low-level `burst_dst_stride`/`burst_src_stride` operands. +- The lowering still maps to the underlying low-level PTO operand ABI in positional order. + +**Example**: +```python +pto.copy_ubuf_to_gm( + src=ub_ptr, + dst=gm_ptr, + n_burst=32, + len_burst=128, + gm_stride=128, + ub_stride=128, +) +``` diff --git a/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md new file mode 100644 index 000000000..f7a20fd76 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/09-vector-memory-operations.md @@ -0,0 +1,1058 @@ +### Enum Types for Vector Memory Operations + +The current DSL exposes type-safe Enum operands for the dual load/store +distribution families: + +- **`VLoadDist`** for `pto.vlds` + - `VLoadDist.NORM`: ordinary load + - `VLoadDist.UNPK_B8`, `VLoadDist.UNPK_B16`, `VLoadDist.UNPK_B32`: unpacking loads + - `VLoadDist.BRC_B8`, `VLoadDist.BRC_B16`, `VLoadDist.BRC_B32`: broadcast loads + - `VLoadDist.US_B8`, `VLoadDist.US_B16`, `VLoadDist.DS_B8`, `VLoadDist.DS_B16`: strided/narrow load families + +- **`VStoreDist`** for `pto.vsts` + - `VStoreDist.NORM_B8`, `VStoreDist.NORM_B16`, `VStoreDist.NORM_B32`: ordinary stores + - `VStoreDist.ONE_POINT_B8`, `VStoreDist.ONE_POINT_B16`, `VStoreDist.ONE_POINT_B32`: one-point stores + - `VStoreDist.PK_B16`, `VStoreDist.PK_B32`, `VStoreDist.PK_B64`: packed stores + - `VStoreDist.PK4_B32`, `VStoreDist.MRG4CHN_B8`, `VStoreDist.MRG2CHN_B8`, `VStoreDist.MRG2CHN_B16`: merged packed stores + +- **`DeinterleaveDist`** for `pto.vldsx2` + - `DeinterleaveDist.DINTLV`: alternating-element deinterleave + - `DeinterleaveDist.BDINTLV`: block deinterleave + - compatibility aliases: `DeinterleaveDist.B8`, `DeinterleaveDist.B16`, + `DeinterleaveDist.B32`, `DeinterleaveDist.BD` + +- **`InterleaveDist`** for `pto.vstsx2` + - `InterleaveDist.INTLV`: interleave two vectors into one destination stream + - compatibility aliases: `InterleaveDist.B8`, `InterleaveDist.B16`, + `InterleaveDist.B32` + +- **`PostUpdateMode`** for `pto.vstur` + - `PostUpdateMode.NO_POST_UPDATE`: preserve the current hardware AR state + - `PostUpdateMode.POST_UPDATE`: advance the hardware AR state after the store + +The canonical VPTO v0.3 spellings are the enum values: + +- `VLoadDist.UNPK_B16.value == "UNPK_B16"` +- `VStoreDist.PK_B32.value == "PK_B32"` +- `DeinterleaveDist.DINTLV.value == "DINTLV"` +- `DeinterleaveDist.BDINTLV.value == "BDINTLV"` +- `InterleaveDist.INTLV.value == "INTLV"` +- `PostUpdateMode.NO_POST_UPDATE.value == "NO_POST_UPDATE"` +- `PostUpdateMode.POST_UPDATE.value == "POST_UPDATE"` + +`pto.vstur` mode is intentionally Enum-only in the DSL. Unlike the legacy +distribution-token compatibility retained for some older load/store families, +raw strings such as `"POST_UPDATE"` are not accepted for `PostUpdateMode`. + +For migration convenience, the implementation still accepts legacy raw strings +such as `"DINTLV_B32"` and `"INTLV_B32"`, but new DSL code should prefer the +Enum operands. + +- **`StrideMode`**: Stride modes for `pto.vsld` + - `S3_B16`: Stride 3, block size 16 + - `S4_B64`: Stride 4, block size 64 + - `S8_B32`: Stride 8, block size 32 + - `S2_B64`: Stride 2, block size 64 + +### Address Generation Syntax Sugar + +To simplify address calculation and reduce manual byte offset computation errors, TileLang DSL provides syntactic sugar for vector load/store operations using element-based indexing. This syntax automatically computes the byte offset based on tile shape, element type, and layout. + +#### Indexing Syntax + +The syntax supports two indexing modes for different operations: + +1. **Vector-range indexing** (for vector load/store operations): + - **Row-major layout (default)**: `tile[row_index, col_start:]` + - `row_index`: Row index (0-based) + - `col_start:`: Starting column index followed by colon, indicating a vector-width contiguous region starting from this column + - The colon (`:`) indicates an implicit vector-width range determined by hardware vector size (256 bytes) and element type + + - **Column-major layout**: `tile[row_start:, col_index]` + - `row_start:`: Starting row index followed by colon, indicating a vector-width contiguous region starting from this row + - `col_index`: Column index (0-based) + - Used for column-major tiles (`BLayout.COL_MAJOR`) where elements are stored column-wise + + - **1D tile indexing**: `tile[start:]` (or equivalently `tile[0, start:]` for row-major or `tile[start:, 0]` for column-major) + - `start:`: Starting element index followed by colon + + Tile indexing sugar only accepts an open-ended vector slice. Python slice + forms with an explicit `stop` or `step` are not supported for `Tile` + indexing. For example, `tile[row, col:col_end]`, `tile[row, col::2]`, + `tile[row_start:row_end, col]`, and `tile[start:stop:step]` are invalid. + +2. **Single-element indexing** (for scalar load operations like `pto.vsld`): + - **Row-major layout (default)**: `tile[row_index, col_index]` + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + + - **Column-major layout**: `tile[row_index, col_index]` (same syntax) + - `row_index`: Row index (0-based) + - `col_index`: Column index (0-based) + - Same syntax as row-major; the layout determines how the offset is computed + + - **1D tile indexing**: `tile[pos]` + - `pos`: Element index (0-based) + - Loads a single element at the specified position and broadcasts it to all vector lanes + +#### Vector Width Calculation + +The number of elements loaded/stored in a single vector operation is determined by: + +``` +vector_lanes = 256 // element_size_bytes(element_type) +``` + +**Convenience API**: Use `pto.elements_per_vreg(dtype)` to compute the number of elements per vector register for a given element type (e.g., `pto.elements_per_vreg(pto.f32)` returns 64, `pto.elements_per_vreg(pto.f16)` returns 128). See [Type Query Operations](07-frontend-operations.md#type-query-operations) for full documentation. + +Where `element_size_bytes` is: +- 1 byte for `i8`, `si8`, `ui8` +- 2 bytes for `i16`, `si16`, `ui16`, `f16`, `bf16` +- 4 bytes for `i32`, `si32`, `ui32`, `f32` +- 8 bytes for `i64`, `si64`, `ui64` + +#### Offset Computation + +The byte offset is automatically computed based on tile layout: + +- **Row-major layout** (`BLayout.ROW_MAJOR`): + ``` + offset = (row_index * stride_row + col_start) * element_size_bytes + ``` + where `stride_row` is the row stride in elements (typically `tile.shape[1]` for contiguous tiles). + +- **Column-major layout** (`BLayout.COL_MAJOR`): + - For syntax `tile[row_start:, col_index]`: + ``` + offset = (col_index * stride_col + row_start) * element_size_bytes + ``` + - For backward compatibility with traditional offset calculation: + ``` + offset = (col_start * stride_col + row_index) * element_size_bytes + ``` + where `stride_col` is the column stride in elements (typically `tile.shape[0]` for contiguous tiles), `row_start` is the starting row index, and `col_index` is the column index. + +**Note**: +- For single-element indexing (`tile[row, col]` or `tile[pos]`), the same offset formulas apply with `col_start` replaced by `col_index` (or `start` replaced by `pos` for 1D tiles). +- For column-major vector-range indexing (`tile[row_start:, col_index]`), the offset formula uses `row_start` as the starting position along the contiguous dimension. +- The compiler automatically handles the appropriate substitution based on the indexing syntax and tile layout. + +#### Constraints + +1. **Boundary checks**: The requested region must be within tile bounds: + - **For vector-range indexing** (`:` syntax): + - **Row-major layout** (`tile[row_index, col_start:]`): + - `row_index < tile.shape[0]` and `col_start + vector_lanes <= tile.shape[1]` + - **Column-major layout** (`tile[row_start:, col_index]`): + - `row_start + vector_lanes <= tile.shape[0]` and `col_index < tile.shape[1]` + - **1D tile indexing**: `tile[start:]` + - `start + vector_lanes <= tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + - **For single-element indexing** (no `:` syntax): + - 2D: `row_index < tile.shape[0]` and `col_index < tile.shape[1]` (same for both layouts) + - 1D: `pos < tile.shape[0]` (or `tile.shape[1]` for 1D tiles) + +2. **Alignment**: The computed offset must satisfy hardware alignment requirements for the operation. + +3. **Full vectors only**: The `:` syntax always loads/stores a full vector width. For partial vectors, use the traditional byte offset approach with explicit mask handling. + +4. **Single-element operations**: The single-element indexing syntax (`tile[row, col]` or `tile[pos]`) is only supported for scalar load operations like `pto.vsld`. For other operations, use vector-range indexing with `:` syntax. + +5. **No explicit slice bounds/stride for `Tile` indexing**: `Tile` vector-range + indexing only supports the open-ended forms `tile[start:]`, + `tile[row, col:]`, and `tile[row_start:, col_index]` (for column-major + layout). `stop` and `step` syntax are not accepted in user-guide Tile + indexing. + +#### Supported Operations + +The indexing syntax is supported for all vector load and store operations with the following syntax mapping: + +- **Vector-range indexing** (`tile[row, col:]` or `tile[start:]`): + - Load operations: `vlds`, `vldas`, `vldus`, `vldsx2` + - Store operations: `vsts`, `vsta`, `psts`, `vsst`, `vstsx2` + +- **Single-element indexing** (`tile[row, col]` or `tile[pos]`): + - Load operations: `vsld` (scalar load with broadcast) + +#### Examples + +The following examples use row-major layout syntax. For column-major tiles, use `tile[row_start:, col_index]` syntax instead of `tile[row_index, col_start:]`. + +```python +# 2D tile indexing (row-major layout) +vec = pto.vlds(tile[i, j:]) # Load vector from row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[i, j:], mask) # Store vector with mask + +# 1D tile indexing +vec = pto.vlds(tile[k:]) # Load vector from elements k to k+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store vector with mask + +# Dual load with deinterleave +low, high = pto.vldsx2(tile[i, j:], "DINTLV") + +# Aligned load with indexing +vec = pto.vldas(tile[i, j:], align) + +# Scalar load (broadcast) +vec = pto.vsld(tile[i, j]) # Load scalar at tile[i,j] and broadcast to vector +``` + +#### Comparison with Manual Offset Calculation + +**Traditional approach (error-prone):** +```python +# Manual byte offset calculation for f32 tile +rows, cols = tile.shape +row_offset = i * cols * 4 # Hard-coded 4 bytes for f32 +col_offset = j * 4 +offset = row_offset + col_offset +vec = pto.vlds(tile, offset) +``` + +**New syntax (type-safe):** +```python +# Automatic offset calculation +vec = pto.vlds(tile[i, j:]) # Compiler computes correct offset for any element type +``` + +The syntax sugar eliminates manual byte calculations, reduces errors, and makes code generic across different element types (e.g., the same kernel works for both `f16` and `f32` without modification). + +### Vector Load Operations + +Operations for loading data from memory into vector registers. + +#### `pto.vlds(buf: ptr, offset: Index, dist: pto.VLoadDist | None = None) -> VRegType` [Advanced Tier] +#### `pto.vlds(tile[row, col:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] +#### `pto.vlds(tile[start:], dist: pto.VLoadDist | None = None) -> VRegType` [Basic Tier] + +**Description**: Stateless vector load from buffer. Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `pto.VLoadDist \| None` | Optional load distribution enum such as `pto.VLoadDist.NORM` or `pto.VLoadDist.UNPK_B16` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the requested vector region must be within tile bounds and satisfy alignment requirements +- `dist` is optional. When omitted, the load uses the backend default layout for the vector family. +- `dist` must be a `pto.VLoadDist` enum value. + +**Examples**: +```python +# Traditional byte-offset syntax +vec = pto.vlds(ub_ptr, lane * 256) +vec_unpacked = pto.vlds(ub_ptr, lane * 128, dist=pto.VLoadDist.UNPK_B16) + +# New element-indexing syntax +vec = pto.vlds(tile[i, j:]) # Load from row i, columns j to j+vector_lanes-1 +vec = pto.vlds(tile[k:]) # Load from 1D tile, elements k to k+vector_lanes-1 + +# Generic kernel that works for both f16 and f32 +@pto.vkernel(target="a5", op="scale", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_scale(src: pto.Tile, dst: pto.Tile, scale: pto.f32): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): # vector_lanes computed from element type + # No manual byte calculation needed! + vec = pto.vlds(src[i, j:]) + scaled = pto.vmuls(vec, scale, all_mask) + pto.vsts(scaled, dst[i, j:], all_mask) +``` + +#### `pto.vldas(buf: ptr) -> pto.align` [Advanced Tier] +#### `pto.vldas(tile[row, col:]) -> pto.align` [Basic Tier] +#### `pto.vldas(tile[start:]) -> pto.align` [Basic Tier] + +**Description**: Prime alignment buffer for subsequent unaligned load. Returns alignment state for use with `pto.vldus`. Supports both pointer syntax and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align` | `pto.align` | Alignment state for use with `pto.vldus` | + +**Examples**: +```python +# Pointer syntax +align = pto.vldas(ub_ptr) + +# Element-indexing syntax +align = pto.vldas(tile[i, j:]) +align = pto.vldas(tile[k:]) +``` + +#### `pto.vldus(buf: ptr, align: pto.align) -> (VRegType, pto.align, ptr)` [Advanced Tier] +#### `pto.vldus(tile[row, col:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] +#### `pto.vldus(tile[start:], align: pto.align) -> (VRegType, pto.align, ptr)` [Basic Tier] + +**Description**: Unaligned load using primed align state. Requires alignment state from `pto.vldas` or previous `pto.vldus`. Updates alignment state and base pointer for subsequent loads. Supports both pointer syntax and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | +| _or_ | | | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `align` | `pto.align` | Alignment state from `pto.vldas` or previous `pto.vldus` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Assembled vector value | +| `align_out` | `pto.align` | Updated alignment state for next load | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- A matching `pto.vldas` must appear before the first dependent `pto.vldus` stream in the same vector loop +- Both alignment state and base address advance across the stream +- If DSL authoring uses explicit byte/element offsets, the frontend first rewrites them into pointer/index expressions before lowering to this VPTO form. + +**Examples**: +```python +# Pointer syntax - requires alignment state priming +align = pto.vldas(ub_ptr) +vec, align_out, base_out = pto.vldus(ub_ptr, align) + +# Element-indexing syntax +align = pto.vldas(tile[i, j:]) +vec, align_out, base_out = pto.vldus(tile[i, j:], align) + +# Multiple unaligned loads in a stream +align = pto.vldas(tile[k:]) +for n in range(4): + vec, align, base = pto.vldus(tile[k:], align) # alignment state updates +``` + + +#### `pto.vldsx2(buf: ptr, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` [Advanced Tier] +#### `pto.vldsx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] +#### `pto.vldsx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` [Basic Tier] + +**Description**: Dual vector load with deinterleave (AoS → SoA conversion). Loads interleaved data from a single buffer and deinterleaves into two vectors. Supports both byte-offset and element-indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | +| _or_ | | | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `DeinterleaveDist` | Deinterleave distribution enum. Prefer `DeinterleaveDist.DINTLV` or `DeinterleaveDist.BDINTLV`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | + +**Constraints**: +- Source buffer must be in UB memory space +- Offset must satisfy alignment requirements for the selected distribution mode +- The requested vector region must be within tile bounds (for element-indexing syntax) +- Distribution mode must match element type (e.g., `"DINTLV"` for 32-bit elements) + +**Examples**: +```python +# Byte-offset syntax +low, high = pto.vldsx2(ub_ptr, offset, pto.DeinterleaveDist.DINTLV) + +# Element-indexing syntax +low, high = pto.vldsx2(tile[i, j:], pto.DeinterleaveDist.DINTLV) +low, high = pto.vldsx2(tile[k:], pto.DeinterleaveDist.DINTLV) + +# Example: Load interleaved XY pairs into separate X/Y vectors +x_vec, y_vec = pto.vldsx2(xy_tile[i, j:], pto.DeinterleaveDist.DINTLV) +``` + +#### `pto.vsld(buf: ptr, offset: Index, stride: StrideMode) -> VRegType` [Advanced Tier] +#### `pto.vsld(tile[row, col], stride: StrideMode) -> VRegType` [Basic Tier] +#### `pto.vsld(tile[pos], stride: StrideMode) -> VRegType` [Basic Tier] + +**Description**: Strided load with fixed stride pattern. Loads elements from memory with regular stride pattern. The offset parameter encodes displacement with selected stride mode. This is a deprecated compatibility family. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte displacement encoded with selected stride mode | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. Determines which sub-elements are read from each source block. | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | +| _or_ | | | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | +| `stride` | `StrideMode` | Stride mode token: `StrideMode.S3_B16`, `StrideMode.S4_B64`, `StrideMode.S8_B32`, `StrideMode.S2_B64`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector with strided pattern | + +**Constraints**: +- The selected stride token determines which sub-elements are read from each source block +- This operation family is deprecated; prefer other load patterns when possible + +**Examples**: +```python +from pto import StrideMode + +# Byte-offset syntax +vec = pto.vsld(ub_ptr, offset, StrideMode.S4_B64) + +# Element-indexing syntax +vec = pto.vsld(tile[i, j], StrideMode.S3_B16) +vec = pto.vsld(tile[k], StrideMode.S8_B32) +``` + +#### `pto.vgather2(buf: ptr, offsets: Index, active_lanes: Index) -> VRegType` [Advanced Tier] + +**Description**: Indexed gather from UB. Gathers elements from a single buffer using per-lane offsets, with participation bounded by active lanes count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of lanes that participate (bounds gather operation) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Constraints**: +- Only the first `active_lanes` offsets participate in the gather +- Index element width and interpretation must match selected gather form +- Each effective address must satisfy the gather form's alignment rules + +**Example**: +```python +vec = pto.vgather2(buf, offsets, active_lanes) +``` + +#### `pto.vgather2_bc(buf: ptr, offsets: Index, mask: MaskType) -> VRegType` [Advanced Tier] + +**Description**: Gather with broadcast, conditioned by mask. Gathers elements from a single buffer using per-lane offsets, with mask gating lane participation. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to source buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `mask` | `MaskType` | Mask gating which lanes participate | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Constraints**: +- Masked-off lanes do not participate in address coalescing and do not trigger address overflow exceptions +- Destination lanes for masked-off lanes are zero-filled +- This is a backward-compatible operation family + +**Example**: +```python +vec = pto.vgather2_bc(buf, offsets, mask) +``` + +#### `pto.vgatherb(buf: ptr, offsets: Index) -> VRegType` [Advanced Tier] + +**Description**: Byte‑granularity gather load. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer | +| `offsets` | `Index` | Byte offsets | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +**Example**: +```python +vec = pto.vgatherb(buf, offsets) +``` + +#### `pto.vsldb(buf: ptr, offset: Index, mask: MaskType) -> VRegType` [Advanced Tier] +#### `pto.vsldb(tile[row, col], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] +#### `pto.vsldb(tile[pos], offset: Index, mask: MaskType) -> VRegType` [Basic Tier] + +**Description**: Block-strided load for 2D tile access. Loads elements with block stride pattern controlled by packed offset word and mask. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Pointer to buffer in UB memory space | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | `Tile` with indexing | 2D tile with row and column indices (single element) | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | +| _or_ | | | +| `tile[pos]` | `Tile` with indexing | 1D tile with element index (single element) | +| `offset` | `Index` | Packed stride/control word (not plain byte displacement) | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector with block-strided pattern | + +**Constraints**: +- The offset encodes block stride and repeat pattern, not a plain byte displacement +- If a block is masked off, the corresponding destination block is zeroed +- Masked-off blocks must not raise address overflow exceptions + +**Example**: +```python +# Byte-offset syntax +vec = pto.vsldb(ub_ptr, control_word, mask) + +# Element-indexing syntax +vec = pto.vsldb(tile[i, j], control_word, mask) +vec = pto.vsldb(tile[k], control_word, mask) +``` + +### Vector Store Operations + +Operations for storing data from vector registers to memory. + +#### `pto.vsts(vec: VRegType, buf: ptr, offset: Index, mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Advanced Tier] +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType, dist: pto.VStoreDist | None = None) -> None` [Basic Tier] + +**Description**: Stateless vector store to buffer. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | +| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index | +| `mask` | `MaskType` | Predicate mask | +| `dist` | `pto.VStoreDist \| None` | Optional store distribution enum such as `pto.VStoreDist.NORM_B32` or `pto.VStoreDist.PK_B32` | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Buffer must be in UB memory space +- For byte-offset syntax: offset must be properly aligned based on element type +- For element-indexing syntax: the destination vector region must be within tile bounds and satisfy alignment requirements +- `dist` is optional. When omitted, the store uses the backend default layout for the vector family. +- Current TileLang DSL v1 accepts exactly one keyword attr on `pto.vsts`: `dist=...`. +- `dist` must be a `pto.VStoreDist` enum value. +- `mask` must match the effective store payload granularity, which may differ from the vector element family when `dist` repacks lanes. +- Common width-changing cases: + default / `NORM_B32` stores expect `mask_b32` for `f32`/`i32`-family vectors; + `PK_B32` also expects `mask_b32` and is used by narrow stores such as `f32 -> f16` `tcvt`; + `PK_B16` expects `mask_b16`. + +**Examples**: +```python +# Byte-offset syntax +pto.vsts(vec_f32, ub_ptr, lane * 256, mask32) + +# Element-indexing syntax +pto.vsts(vec, tile[i, j:], mask) # Store to row i, columns j to j+vector_lanes-1 +pto.vsts(vec, tile[k:], mask) # Store to 1D tile, elements k to k+vector_lanes-1 + +# VPTO-aligned packed store +vec_f16 = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, +) +pto.vsts(vec_f16, tile[i, j:], mask32, dist=pto.VStoreDist.PK_B32) + +# In a generic kernel +@pto.vkernel(target="a5", op="copy", dtypes=[(pto.AnyFloat, pto.AnyFloat)], priority=10) +def generic_store(src: pto.Tile, dst: pto.Tile): + rows, cols = src.shape + all_mask = pto.make_mask(src.element_type, PAT.ALL) + for i in range(0, rows): + for j in range(0, cols, vector_lanes): + vec = pto.vlds(src[i, j:]) + pto.vsts(vec, dst[i, j:], all_mask) # No manual offset calculation +``` + +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Predicate store (`pto.psts`) writes the packed payload represented by +`MaskType` to UB memory. This is the dynamic-offset form of the VPTO predicate-store +family (`psts` vs `psti`): the payload semantics are identical, and only the offset +delivery form differs. + +**Parameters (advanced byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate payload to store | +| `buf` | `ptr` | Pointer to destination UB buffer (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Runtime offset (`index`) | +| `dist` | `PredicateDist` | Predicate distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | + +**Returns**: None (side-effect operation) + +**DIST semantics (VPTO-aligned)**: +- `PredicateDist.NORM`: store packed predicate payload into a normal destination space of size `VL/8`. +- `PredicateDist.PK`: store packed predicate payload into a destination space of size `VL/16`, keeping one bit out of every two bits. + +**Notes**: +- `pto.psts` is intentionally documented as explicit `buf + offset` surface in DSL v1. +- Packed predicate payload layout is bit-level (`VL/8` or `VL/16`), so tile element-indexing is not part of the stable Basic Tier contract. +- The pointer + offset form maps directly to explicit `base[offset]`. +- Authoritative predicate-memory-family semantics are documented in `10-predicate-operations.md`. + +#### `pto.vsst(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vsst(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` +#### `pto.vsst(scalar: ScalarType, tile[start:], mask: MaskType) -> None` + +**Description**: Scalar to vector store (broadcast scalar to all lanes). Supports both traditional byte-offset syntax and new element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `ptr` | Pointer to destination buffer (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +#### `pto.vstsx2(low: VRegType, high: VRegType, buf: ptr, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` + +**Description**: Dual interleaved store (SoA → AoS conversion). Stores two vectors interleaved into a single buffer. Supports both byte-offset and element-indexing syntax. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Byte offset | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements in interleaved stream) | +| `high` | `VRegType` | Second vector (odd elements in interleaved stream) | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `dist` | `InterleaveDist` | Interleave distribution enum. Prefer `InterleaveDist.INTLV`. | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Destination buffer must be in UB memory space +- Offset must satisfy alignment requirements for the selected distribution mode +- The destination vector region must be within tile bounds (for element-indexing syntax) +- Distribution mode must match element type (e.g., `"INTLV"` for 32-bit elements) +- The two source vectors form an ordered pair; interleave semantics must be preserved + +**Examples**: +```python +# Byte-offset syntax +pto.vstsx2(x_vec, y_vec, ub_ptr, offset, pto.InterleaveDist.INTLV, mask) + +# Element-indexing syntax +pto.vstsx2(x_vec, y_vec, tile[i, j:], pto.InterleaveDist.INTLV, mask) +pto.vstsx2(x_vec, y_vec, tile[k:], pto.InterleaveDist.INTLV, mask) + +# Example: Store separate X/Y vectors as interleaved XY pairs +pto.vstsx2(x_vec, y_vec, xy_tile[i, j:], pto.InterleaveDist.INTLV, all_mask) +``` + +#### `pto.vsta(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.vsta(align: pto.align, tile[row, col:]) -> None` [Basic Tier] +#### `pto.vsta(align: pto.align, tile[start:]) -> None` [Basic Tier] + +**Description**: Flush alignment state to memory. Writes buffered tail bytes from alignment state to UB memory. Consumes the alignment state after flush. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space (Advanced mode only - requires explicit pointer) | +| `offset` | `Index` | Flush displacement | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- The flush address must match the post-updated address expected by the preceding unaligned-store stream +- After the flush, the corresponding store alignment state is consumed +- A final flush operation is required to commit buffered bytes after unaligned-store sequences +- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream + +**Example**: +```python +# Byte-offset syntax +pto.vsta(align, ub_ptr, offset) + +# Element-indexing syntax +pto.vsta(align, tile[i, j:]) +pto.vsta(align, tile[k:]) +``` + +#### `pto.vscatter(vec: VRegType, buf: ptr, offsets: Index, active_lanes: Index) -> None` [Advanced Tier] + +**Description**: Indexed scatter to UB. Stores vector elements to irregular locations using per-lane offsets, with participation bounded by active lanes count. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Source vector to scatter | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `active_lanes` | `Index` | Number of lanes that participate (bounds scatter operation) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- Only `b8`, `b16`, and `b32` element sizes are supported +- Current TileLang DSL / VPTO path requires `i32` index vectors +- Each computed address must be element-aligned +- If indices alias, only one write is guaranteed (winning lane is implementation-defined) +- Only the first `active_lanes` offsets participate in the scatter + +**Example**: +```python +pto.vscatter(vec, buf, offsets, active_lanes) +``` + +#### `pto.vsstb(scalar: ScalarType, buf: ptr, offset: Index, mask: MaskType) -> None` [Advanced Tier] +#### `pto.vsstb(scalar: ScalarType, tile[row, col:], mask: MaskType) -> None` [Basic Tier] +#### `pto.vsstb(scalar: ScalarType, tile[start:], mask: MaskType) -> None` [Basic Tier] + +**Description**: Scalar to vector store with broadcast (enhanced version of `vsst`). Supports both byte‑offset and element‑indexing syntax. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `buf` | `ptr` | Pointer to destination buffer | +| `offset` | `Index` | Byte offset | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Parameters (1D element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `ScalarType` | Scalar value | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: None (side-effect operation) + +**Example**: +```python +# Byte-offset syntax +pto.vsstb(pto.f32(0.0), ub_ptr, offset, mask) + +# Element-indexing syntax +pto.vsstb(pto.f32(1.0), tile[i, j:], mask) +``` + +#### `pto.vstar(align: pto.align, buf: ptr) -> None` [Advanced Tier] +#### `pto.vstar(align: pto.align, tile[row, col:]) -> None` [Basic Tier] +#### `pto.vstar(align: pto.align, tile[start:]) -> None` [Basic Tier] + +**Description**: Flush alignment state using the register-update form. Writes buffered tail bytes from alignment state to UB memory. The implicit update state must correspond to the same store stream that produced the alignment state. + +**Parameters (byte-offset syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | + +**Returns**: None (side-effect operation) + +**Constraints**: +- The implicit update state consumed by this flush must correspond to the same store stream that produced the alignment state +- A final flush operation is required to commit buffered bytes after unaligned-store sequences +- The `align` input should come from the latest `vstu`/`vstus`/`vstur` in the same stream + +**Example**: +```python +# Byte-offset syntax +pto.vstar(align, ub_ptr) + +# Element-indexing syntax +pto.vstar(align, tile[i, j:]) +pto.vstar(align, tile[k:]) +``` + +#### `pto.vstas(align: pto.align, buf: ptr, offset: Index) -> None` [Advanced Tier] +#### `pto.vstas(align: pto.align, tile[row, col:], offset: Index) -> None` [Basic Tier] +#### `pto.vstas(align: pto.align, tile[start:], offset: Index) -> None` [Basic Tier] + +**Description**: Scalar-register-offset form of alignment-state flush. Writes buffered tail bytes from alignment state to UB memory with explicit scalar offset. Uses same buffered-tail semantics as `pto.vsta`. + +**Parameters (pointer syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | +| `offset` | `Index` | Scalar-register style displacement | + +**Parameters (element-indexing syntax)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `pto.align` | Pending store-alignment state | +| `tile[row, col:]` | `Tile` with indexing | 2D tile with row index and starting column (vector-width range) | +| `offset` | `Index` | Scalar-register style displacement | +| _or_ | | | +| `align` | `pto.align` | Pending store-alignment state | +| `tile[start:]` | `Tile` with indexing | 1D tile with starting element index (vector-width range) | +| `offset` | `Index` | Scalar-register style displacement | + +**Returns**: None (side-effect operation) + +**Example**: +```python +# Byte-offset syntax +pto.vstas(align, ub_ptr, offset) + +# Element-indexing syntax +pto.vstas(align, tile[i, j:], offset) +pto.vstas(align, tile[k:], offset) +``` + +### Stateful Store Operations + +Operations for storing data with stateful semantics. + +#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Predicate unaligned store with align state update. Stores predicate mask with alignment state threading. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination buffer in UB memory space | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated alignment state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Part of stateful unaligned-store sequence with alignment state threading + +#### `pto.vstu(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, mode: Index) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Unaligned store with explicit threaded alignment/base state. Models a stateful unaligned-store sequence in SSA form. Requires a final `pto.vsta`/`pto.vstas`/`pto.vstar` to flush trailing buffered bytes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `base_in` | `ptr` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `mode` | `Index` | Mode selecting post-update behavior | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Models stateful unaligned-store sequence in SSA form +- Final flush operation required to commit buffered bytes + +**Example**: +```python +# Stateful unaligned store + final flush (vsta form) +align1, base1 = pto.vstu(align0, base0, vec0, ub_ptr, mode) +align2, base2 = pto.vstu(align1, base1, vec1, ub_ptr, mode) +pto.vsta(align2, ub_ptr, tail_offset) +``` + +#### `pto.vstus(align_in: pto.align, base_in: ptr, vec: VRegType, buf: ptr, offset: Index) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Scalar-offset unaligned store with threaded state. Same roles as `pto.vstu` but with explicit scalar displacement. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `base_in` | `ptr` | Current stream base pointer | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `offset` | `Index` | Scalar displacement | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | +| `base_out` | `ptr` | Post-update base pointer state | + +**Constraints**: +- Same final flush requirement and state-threading constraints as `pto.vstu` + +**Example**: +```python +# Scalar-offset threaded form + final flush (vstas form) +align1, base1 = pto.vstus(align0, base0, vec0, ub_ptr, offset0) +align2, base2 = pto.vstus(align1, base1, vec1, ub_ptr, offset1) +pto.vstas(align2, ub_ptr, flush_offset) +``` + +#### `pto.vstur(align_in: pto.align, vec: VRegType, buf: ptr, mode: PostUpdateMode = pto.PostUpdateMode.NO_POST_UPDATE) -> pto.align` [Advanced Tier] + +**Description**: Register-update unaligned store form. Updates only the residual alignment state without base pointer update. Requires matching flush operation to emit trailing bytes. The optional `mode` operand is a typed Enum and controls whether the hardware performs post-update on the implicit AR state. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Incoming store-alignment state | +| `vec` | `VRegType` | Vector to store | +| `buf` | `ptr` | Destination buffer in UB memory space | +| `mode` | `PostUpdateMode` | Optional post-update mode. Defaults to `pto.PostUpdateMode.NO_POST_UPDATE`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated buffered-tail state | + +**Constraints**: +- Updates only residual alignment state (no base pointer update) +- Matching flush operation still required to emit trailing bytes + +**Example**: +```python +# Residual-state form + final flush (vstar form) +align1 = pto.vstur(align0, vec0, ub_ptr) +align2 = pto.vstur(align1, vec1, ub_ptr) +pto.vstar(align2, ub_ptr) + +# Explicit post-update mode with typed Enum +align3 = pto.vstur(align2, vec2, ub_ptr, pto.PostUpdateMode.POST_UPDATE) +``` + +#### Align-State Store Closed Loop + +For unaligned store families, the state must form a closed loop: + +1. Start from an incoming `align` state. +2. Thread state through one or more `vstu` / `vstus` / `vstur` operations. +3. Terminate the stream with exactly one flush op: `vsta` or `vstas` or `vstar`. +4. Do not reuse a flushed `align` state in another stream. diff --git a/tilelang-dsl/docs/user_guide/10-predicate-operations.md b/tilelang-dsl/docs/user_guide/10-predicate-operations.md new file mode 100644 index 000000000..8cc92da2c --- /dev/null +++ b/tilelang-dsl/docs/user_guide/10-predicate-operations.md @@ -0,0 +1,637 @@ +### Predicate Operations + +Operations for creating and manipulating typed masks. + +**Recommended API**: For most use cases, prefer the unified `pto.make_mask()` function which automatically selects the appropriate mask granularity based on element type and supports both tail processing (remaining element count) and pattern-based mask generation. This eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` (tail processing) and `pset_b8`/`pset_b16`/`pset_b32` (pattern generation) operations. + +**Pattern alias**: For brevity in examples, the documentation uses `PAT` as an alias for `pto.MaskPattern` (e.g., `PAT.ALL` instead of `pto.MaskPattern.ALL`). In practice, you can create this alias with `from pto import MaskPattern as PAT` or `PAT = pto.MaskPattern`. + +**Predicate Part Enum**: `pto.ppack` and `pto.punpack` require the `PredicatePart` enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`; these lower to the VPTO canonical `PART` tokens `"LOWER"` and `"HIGHER"`. + +**Predicate Dist Enum**: The `PredicateDist` enum provides type-safe distribution mode selection for predicate memory families. Load families (`plds`, `pld`, `pldi`) use `NORM`, `US`, and `DS`. Store families (`psts`, `pst`, `psti`) use `NORM` and `PK`. + +**Pattern coverage**: The VPTO canonical predicate-generation families use `PAT_*` tokens such as `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, `PAT_VL*`, `PAT_M3`, and `PAT_M4`. The Python DSL surface may expose only a subset through `pto.MaskPattern`; check the enum for currently available values. + +#### `pto.pset_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Creates an 8-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | + +**Constraints**: +- Used with `i8` vector operations + +**Example**: +```python +mask8 = pto.pset_b8(PAT.ALL) +``` + +#### `pto.pset_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Creates a 16-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations + +**Example**: +```python +mask16 = pto.pset_b16(PAT.ALL) +``` + +#### `pto.pset_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Creates a 32-bit granularity mask from a pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Mask pattern enum (for example `pto.MaskPattern.ALL`, `pto.MaskPattern.ALLF`, or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations + +**Example**: +```python +mask32 = pto.pset_b32(PAT.ALL) +``` + +#### `pto.pge_b8(pattern: pto.MaskPattern) -> pto.mask_b8` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates an 8-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity tail mask | + +**Constraints**: +- Used with `i8` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask pattern lowered as `PAT_VL16` +tail_mask = pto.pge_b8(PAT.VL16) +``` + +#### `pto.pge_b16(pattern: pto.MaskPattern) -> pto.mask_b16` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 16-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity tail mask | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask for first 16 lanes +tail_mask = pto.pge_b16(PAT.VL16) +``` + +#### `pto.pge_b32(pattern: pto.MaskPattern) -> pto.mask_b32` + +**Description**: Generate tail mask — first N lanes active based on pattern. Creates a 32-bit granularity mask where the first N lanes are active according to the specified pattern. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `pto.MaskPattern` | Tail mask pattern enum lowered to a VPTO `PAT_*` token (for example `pto.MaskPattern.VL16` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity tail mask | + +**Constraints**: +- Used with `f32`/`i32` vector operations +- Pattern must be a valid tail mask pattern (typically `PAT_VL*` variants) + +**Example**: +```python +# Tail mask for first 32 lanes +tail_mask = pto.pge_b32(PAT.VL32) +``` + +#### `pto.plt_b8(scalar: pto.i32) -> (pto.mask_b8, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates an 8-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b8` | 8-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `i8` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b8(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.plt_b16(scalar: pto.i32) -> (pto.mask_b16, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 16-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b16` | 16-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `f16`/`bf16`/`i16` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b16(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.plt_b32(scalar: pto.i32) -> (pto.mask_b32, pto.i32)` + +**Description**: Generate predicate state together with updated scalar state (tail processing). Creates a 32-bit granularity mask and returns updated scalar value for state progression. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Input scalar value (typically remaining element count) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `pto.mask_b32` | 32-bit granularity mask | +| `scalar_out` | `pto.i32` | Updated scalar state | + +**Constraints**: +- Used with `f32`/`i32` vector operations for tail processing +- The scalar input is typically a remaining element count that decrements across successive calls + +**Example**: +```python +remaining: pto.i32 = 64 +mask, remaining = pto.plt_b32(remaining) # generates mask for next chunk, updates remaining count +``` + +#### `pto.make_mask(element_type: Type, value: pto.i32 | pto.MaskPattern) -> MaskType | (MaskType, pto.i32)` + +**Description**: Creates a mask with appropriate bitwidth (8, 16, or 32) based on element type, automatically inferring whether to perform tail processing or pattern-based mask generation based on the `value` parameter type. This convenience function eliminates the need to manually choose between `plt_b8`/`plt_b16`/`plt_b32` and `pset_b8`/`pset_b16`/`pset_b32` operations. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `element_type` | `Type` | Element type (e.g., `pto.f32`, `pto.f16`, `pto.i8`) | +| `value` | `pto.i32` \| `pto.MaskPattern` | Either:
- Remaining element count (as `pto.i32`) for tail processing
- Mask pattern enum value for fixed mask generation (for example `pto.MaskPattern.ALL` or `pto.MaskPattern.VL32`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Generated mask with appropriate granularity | +| `remaining` | `pto.i32` | Updated remaining element count (only returned when `value` is a `pto.i32` for tail processing) | + +**Constraints**: +- The `element_type` must be one of: `f32`, `f16`, `bf16`, or an 8/16/32-bit integer family member (`i*`, `si*`, `ui*`) +- The returned mask granularity matches the element type: 32-bit for `f32`/`i32`/`si32`/`ui32`, 16-bit for `f16`/`bf16`/`i16`/`si16`/`ui16`, and 8-bit for `i8`/`si8`/`ui8` +- The function infers the operation mode from the `value` parameter type at compile time: + - `pto.i32` value → tail processing mode (returns `(mask, updated_remaining)`) + - `pto.MaskPattern` enum value → pattern mode (returns `mask` only) + +**Implementation Note**: This function is a DSL macro that performs type-based dispatch at compile time: +- When `value` is a `pto.i32` expression: expands to corresponding `plt_b` instruction (`plt_b32`, `plt_b16`, or `plt_b8`) +- When `value` is a `pto.MaskPattern` enum value: expands to corresponding `pset_b` instruction (`pset_b32`, `pset_b16`, or `pset_b8`) + +**Example**: +```python +# Tail processing with f32 vectors: value is pto.i32 → expands to plt_b32 +mask_f32, remaining_f32 = pto.make_mask(pto.f32, remaining_elements) + +# Tail processing with f16 vectors: value is pto.i32 → expands to plt_b16 +mask_f16, remaining_f16 = pto.make_mask(pto.f16, remaining_elements) + +# Tail processing with i8 vectors: value is pto.i32 → expands to plt_b8 +mask_i8, remaining_i8 = pto.make_mask(pto.i8, remaining_elements) + +# Pattern-based mask with f32 vectors: value is MaskPattern enum → expands to pset_b32 +mask_all_f32 = pto.make_mask(pto.f32, PAT.ALL) + +# Pattern-based mask with f16 vectors: value is MaskPattern enum → expands to pset_b16 +mask_even_f16 = pto.make_mask(pto.f16, PAT.EVEN) + +# Pattern-based mask with i8 vectors: value is MaskPattern enum → expands to pset_b8 +mask_all_i8 = pto.make_mask(pto.i8, PAT.ALL) + +# Type annotations help clarify expected parameter types +remaining: pto.i32 = 1024 +mask1, updated = pto.make_mask(pto.f32, remaining) # tail processing +mask2 = pto.make_mask(pto.f32, PAT.ALL) # pattern mode +``` + +#### `pto.ppack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Narrowing pack of a predicate register. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | +| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `packed` | `MaskType` | Packed mask | + +**Example**: +```python +packed = pto.ppack(mask, pto.PredicatePart.LOWER) +``` + +#### `pto.punpack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Widening unpack of a predicate register. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | +| `part` | `PredicatePart` | Part selector enum. Use `PredicatePart.LOWER` or `PredicatePart.HIGHER`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Unpacked mask | + +**Example**: +```python +unpacked = pto.punpack(mask, pto.PredicatePart.HIGHER) +``` + +#### `pto.pbitcast(mask: MaskType, to_type: MaskType) -> MaskType` + +**Description**: Reinterprets a typed predicate mask as another typed mask granularity without changing the underlying predicate bit image. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask (`mask_b8`, `mask_b16`, or `mask_b32`) | +| `to_type` | `MaskType` | Target mask type marker such as `pto.mask_b16` or `pto.mask_b32` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Reinterpreted mask with the requested target granularity | + +**Constraints**: +- `mask` must already be a typed predicate value +- `to_type` must be one of the DSL mask type markers: `pto.mask_b8`, `pto.mask_b16`, `pto.mask_b32` +- this is a bit reinterpretation helper, not a logical predicate transform; it does not insert packing, unpacking, interleaving, or deinterleaving by itself +- use `pto.ppack`, `pto.punpack`, `pto.pdintlv_b8`, or `pto.pintlv_b16` when the predicate image itself must be rearranged + +**Example**: +```python +mask_b8 = pto.plds(mask_ptr, offset, pto.PredicateDist.US) +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) + +mask0_b16, mask1_b16 = pto.pintlv_b16(mask_b16, pto.pset_b16(PAT.ALL)) +mask0_b32 = pto.pbitcast(mask0_b16, pto.mask_b32) +``` + +#### `pto.pnot(mask: MaskType, gate: MaskType) -> MaskType` + +**Description**: Predicate negation under a same-granularity mask gate. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Input mask | +| `gate` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `negated` | `MaskType` | Negated mask | + +#### `pto.psel(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Selects between two masks using a third mask as selector. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Selection mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Selected mask | + +#### `pto.plds(buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with scalar-index style offset form. This is the default DSL surface for loading predicate masks from UB memory. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `offset` | `Index` | Scalar/index-style offset | +| `dist` | `PredicateDist` | Distribution mode (default: `PredicateDist.NORM`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.plds(buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.pld(buf: ptr, offset: Index, dist: PredicateDist) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with areg/index register style offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `offset` | `Index` | Areg/index-style offset | +| `dist` | `PredicateDist` | Distribution mode | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.pld(buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.pldi(buf: ptr, imm_offset: pto.i32, dist: PredicateDist) -> MaskType` [Advanced Tier] + +**Description**: Predicate load with immediate-offset encoding form. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `ptr` | Source pointer in UB memory space | +| `imm_offset` | `pto.i32` | Immediate-offset operand | +| `dist` | `PredicateDist` | Distribution mode | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +**Example**: +```python +mask = pto.pldi(buf, 0, pto.PredicateDist.NORM) +``` + +#### `pto.psts(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to UB memory using the VPTO dynamic-offset +`psts` form. This is the dynamic counterpart of `psti`: both encode the same +predicate payload semantics, while offset delivery differs (runtime `index` vs +constant immediate). + +**Parameters (Advanced Tier: explicit pointer surface)**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `offset` | `Index` | Runtime offset (`index`) | +| `dist` | `PredicateDist` | Distribution mode. Use `PredicateDist.NORM` or `PredicateDist.PK` (default: `PredicateDist.NORM`). | + +**DIST semantics (VPTO-aligned)**: +- `NORM`: stores packed predicate payload into destination space of size `VL/8`. +- `PK`: stores packed predicate payload into destination space of size `VL/16`, + keeping one bit out of every two bits. + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.psts(mask, buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.pst(mask: MaskType, buf: ptr, offset: Index, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to UB memory using areg/index offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `offset` | `Index` | Areg/index-style offset | +| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.pst(mask, buf, offset, pto.PredicateDist.NORM) +``` + +#### `pto.psti(mask: MaskType, buf: ptr, imm_offset: pto.i32, dist: PredicateDist = PredicateDist.NORM) -> None` [Advanced Tier] + +**Description**: Stores a predicate mask to UB memory using immediate-offset encoding. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | +| `imm_offset` | `pto.i32` | Immediate-offset operand | +| `dist` | `PredicateDist` | Distribution mode for predicate store. Use `PredicateDist.NORM` or `PredicateDist.PK`. Default is `PredicateDist.NORM`. | + +**Returns**: None (side-effect operation) + +**Example**: +```python +pto.psti(mask, buf, pto.i32(8), pto.PredicateDist.PK) +``` + +#### `pto.pstu(align_in: pto.align, mask: MaskType, buf: ptr) -> (pto.align, ptr)` [Advanced Tier] + +**Description**: Unaligned predicate store with align-state update. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `pto.align` | Input alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `ptr` | Pointer to destination UB buffer | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `pto.align` | Updated alignment state | +| `base_out` | `ptr` | Updated destination pointer | + +**Example**: +```python +align_out, base_out = pto.pstu(align_in, mask, buf) +``` + +#### `pto.pand(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise AND of two predicate masks under a gating mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise AND result | + +**Example**: +```python +result = pto.pand(mask1, mask2, gate) +``` + +#### `pto.por(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise OR of two predicate masks under a gating mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise OR result | + +**Example**: +```python +result = pto.por(mask1, mask2, gate) +``` + +#### `pto.pxor(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise XOR of two predicate masks under a gating mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First input mask | +| `src1` | `MaskType` | Second input mask | +| `mask` | `MaskType` | Gating mask with the same granularity | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Bitwise XOR result | + +**Example**: +```python +result = pto.pxor(mask1, mask2, gate) +``` + +#### `pto.pdintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` + +**Description**: Predicate deinterleave for 8-bit masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.mask_b8` | First input mask | +| `src1` | `pto.mask_b8` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `pto.mask_b8` | First result mask | +| `high` | `pto.mask_b8` | Second result mask | + +**Example**: +```python +low8, high8 = pto.pdintlv_b8(mask_a, mask_b) +``` + +#### `pto.pintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` + +**Description**: Predicate interleave for 16-bit masks. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `pto.mask_b16` | First input mask | +| `src1` | `pto.mask_b16` | Second input mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `pto.mask_b16` | First result mask | +| `high` | `pto.mask_b16` | Second result mask | + +**Example**: +```python +low16, high16 = pto.pintlv_b16(mask_a, mask_b) +``` + +**Note**: Prefer `pto.make_mask()` for automatic bitwidth selection and unified tail/pattern mask generation. diff --git a/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md new file mode 100644 index 000000000..ede8388df --- /dev/null +++ b/tilelang-dsl/docs/user_guide/11-vector-arithmetic-operations.md @@ -0,0 +1,1611 @@ +### Unary Vector Operations + +Element-wise unary operations on vector registers. + +#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Absolute value of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Absolute values | + +**Constraints**: +- Mask granularity must match vector element type (e.g., `f32` requires `mask_b32`) + +**Example**: +```python +abs_vec = pto.vabs(vec_f32, mask32) +``` + +#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Exponential of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential values | + +#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Natural logarithm of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Natural logarithm values | + +#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Square root of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Square root values | + +#### `pto.vrec(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reciprocal of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reciprocal values | + +#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: ReLU activation (max(0, x)) of vector elements. + +**Supported dtypes**: `si32`, `i32`, `f16`, `f32` + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated values | + +#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bitwise NOT of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise NOT values | + +#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reduction add of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduction result vector | + +**Type Rules**: +- For floating-point inputs and `i32/ui32`, the result vector type matches the input vector type. +- For `i8/ui8` inputs, `pto.vcadd` returns a widened `i16/ui16` vector. +- For `i16/ui16` inputs, `pto.vcadd` returns a widened `i32/ui32` vector. +- The result mask granularity follows the result vector element type. + +#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex maximum of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex maximum result | + +#### `pto.vbcnt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Bit count (population count) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bit count values | + +#### `pto.vneg(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Negation of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match vector element type) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Negated values | + +**Constraints**: +- Mask granularity must match vector element type + +**Example**: +```python +neg_vec = pto.vneg(vec_f32, mask32) +``` + +#### `pto.vcls(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Count leading sign bits of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Count of leading sign bits | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vcmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Complex minimum of vector elements (treating pairs as complex numbers). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (interpreted as complex pairs) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Complex minimum result | + +#### `pto.vrsqrt(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Reciprocal square root of vector elements (1/√x). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reciprocal square root values | + +**Constraints**: +- For floating-point vector types only + +#### `pto.vprelu(vec: VRegType, alpha: VRegType, mask: MaskType) -> VRegType` + +**Description**: Parametric ReLU activation of vector elements: `x if x >= 0 else alpha * x`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `alpha` | `VRegType` | Slope parameter for negative values | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Parametric ReLU activated values | + +#### `pto.vmov(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector move (data movement). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Copied vector | + +#### `pto.vsunpack(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Signed unpack of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Unpacked signed values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vzunpack(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Zero-extended unpack of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Unpacked zero-extended values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vusqz(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Unsigned squeeze (compression) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Compressed unsigned values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vsqz(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Signed squeeze (compression) of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Compressed signed values | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, part: pto.VcvtPartMode) -> VRegType` + +**Description**: Fused exponential difference `exp(vec - max_vec)` for numerically stable softmax lowering. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `max_vec` | `VRegType` | Per-lane max vector subtracted before exponentiation | +| `mask` | `MaskType` | Predicate mask. Use `b16` for `f16` inputs and `b32` for `f32` inputs. | +| `part` | `pto.VcvtPartMode` | Output part selector enum. Use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD`. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Exponential difference values; result element type is `f32` | + +**Constraints**: +- Supports `f16` and `f32` input vectors only +- `vec` and `max_vec` must use the same vector type +- `mask` granularity must match the input vector element width +- `part` should use `pto.VcvtPartMode.EVEN` or `pto.VcvtPartMode.ODD` +- Canonical strings `"EVEN"` / `"ODD"` are still accepted for compatibility + +### Binary Vector Operations + +Element-wise binary operations on vector registers. + +#### `pto.vadd(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise addition of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum of vectors | + +**Example**: +```python +sum_vec = pto.vadd(vec_a, vec_b, mask32) +``` + +#### `pto.vsub(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise subtraction of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference of vectors | + +#### `pto.vmul(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise multiplication of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Product of vectors | + +#### `pto.vdiv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise division of two vectors. + +- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16` and `f32`. +- `f16`/`f32` authoring code stays on the public `pto.vdiv` VPTO path. +- Integer `pto.vdiv` also uses the same public surface, but lowers through an internal soft-helper path. +- For `i8`/`ui8`, the integer lowering widens to 16-bit lanes, computes the soft division, then narrows back to 8-bit lanes. +- Internal helper names such as `_tl_soft_vdiv_*` are implementation details and are not part of the supported DSL call surface. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Quotient of vectors | + +#### `pto.vmod(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise modulo of two vectors. + +- Supported element types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`). +- Floating-point `vmod` is not part of the current TileLang DSL v1 public surface. +- `pto.vmod` is the only public vector modulo entry point in TileLang DSL v1. +- The current implementation lowers through an internal soft-helper path; helper names such as `_tl_soft_vmod_*` are intentionally hidden implementation details. +- For `i8`/`ui8`, the modulo path uses an explicit widen-to-16-bit, soft-compute, narrow-back-to-8-bit profile. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | Dividend vector | +| `vec2` | `VRegType` | Divisor vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Remainder vector | + +#### `pto.vmax(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise maximum | + +#### `pto.vmin(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Element-wise minimum | + +#### `pto.vand(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +#### `pto.vor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +#### `pto.vxor(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +#### `pto.vshl(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift left (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshr(vec: VRegType, shift: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise shift right (vector shift amounts). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `VRegType` | Shift amounts (per element) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vaddrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Addition with ReLU activation (max(0, vec1 + vec2)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated sum of vectors | + +#### `pto.vaddreluconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Convolution addition with ReLU activation (convolution-specific fused operation). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated convolution sum | + +**Constraints**: +- Optimized for convolution-specific patterns + +#### `pto.vsubrelu(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Subtraction with ReLU activation (max(0, vec1 - vec2)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | ReLU-activated difference of vectors | + +#### `pto.vaxpy(alpha: VRegType, x: VRegType, y: VRegType, mask: MaskType) -> VRegType` + +**Description**: BLAS AXPY operation (αx + y). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `alpha` | `VRegType` | Scaling factor | +| `x` | `VRegType` | Input vector x | +| `y` | `VRegType` | Input vector y | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result of αx + y | + +#### `pto.vmulconv(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Convolution multiplication (convolution-specific multiplication). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Convolution product | + +**Constraints**: +- Optimized for convolution-specific patterns + +#### `pto.vmull(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, VRegType)` + +**Description**: Widening multiply with split low/high results (extended arithmetic). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Low part of widened product (`r & 0xFFFFFFFF`) | +| `high` | `VRegType` | High part of widened product (`r >> 32`) | + +**Constraints**: +- Current A5 documented form is native `i32/u32` 32x32->64 widening multiply +- Result is split into two vector outputs instead of a single widened vector + +**Example**: +```python +low, high = pto.vmull(lhs_i32, rhs_i32, mask32) +``` + +#### `pto.vmula(vec1: VRegType, vec2: VRegType, vec3: VRegType, mask: MaskType) -> VRegType` + +**Description**: Fused multiply-add (vec1 * vec2 + vec3). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector (multiplier) | +| `vec2` | `VRegType` | Second input vector (multiplicand) | +| `vec3` | `VRegType` | Third input vector (addend) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result of vec1 * vec2 + vec3 | + +### Vector-Scalar Operations + +Operations between vectors and scalars. + +#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector multiplied by scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar multiplier | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Scaled vector | + +**Example**: +```python +scaled = pto.vmuls(vec_f32, pto.f32(2.0), mask32) +``` + +#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector plus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar addend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise maximum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Maximum values | + +#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise minimum of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Minimum values | + +#### `pto.vlrelu(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Leaky ReLU activation (max(αx, x)). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Alpha coefficient | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Leaky ReLU activated values | + +#### `pto.vshls(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` + +**Description**: Vector shift left by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `i16` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vshrs(vec: VRegType, shift: i16, mask: MaskType) -> VRegType` + +**Description**: Vector shift right by scalar (uniform shift). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift` | `i16` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted values | + +#### `pto.vands(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise AND of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise AND result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise OR of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise OR result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vxors(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise bitwise XOR of vector and scalar. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Bitwise XOR result | + +**Constraints**: +- Operates on integer vector types only + +#### `pto.vsubs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector minus scalar (broadcast). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar subtrahend | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | + +#### `pto.vbr(value: ScalarType) -> VRegType` + +**Description**: Broadcast scalar to all vector lanes. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar source | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Vector whose active lanes all carry `value` | + +**Constraints**: +- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. +- For integer types, only the low bits of the scalar source are consumed according to the bit width (8, 16, or 32 bits). + +**Example**: +```python +# Broadcast scalar constant to vector +zero_vec = pto.vbr(0.0) +one_vec = pto.vbr(1.0) + +# Reduction seed with explicit floating dtype +rowmax_seed_f32 = pto.vbr(pto.f32("-inf")) +rowmax_seed_f16 = pto.vbr(pto.f16("0xFC00")) +``` + +#### `pto.vdup(input: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vdup(input: VRegType, mask: MaskType, position: PositionMode = PositionMode.LOWEST) -> VRegType` + +**Description**: Duplicate a scalar value or one selected vector element into +the active lanes of a destination vector. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `input` | `ScalarType` or `VRegType` | Input scalar or source vector | +| `mask` | `MaskType` | Predicate mask controlling which lanes are written | +| `position` | `PositionMode` | Optional enum for the vector-input overload, selecting the source vector element to duplicate (default: `PositionMode.LOWEST`) | + +**Position Mode Enum**: The `PositionMode` enum provides type-safe source-lane +selection for `pto.vdup`. `LOWEST` selects the lowest-index element of the +source vector and `HIGHEST` selects the highest-index element. The enum is only +used by the vector-input overload. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Vector whose active lanes receive the duplicated value | + +**Constraints**: +- `mask` granularity must match the destination vector element type. For + example, `f32`/`i32`/`si32`/`ui32` vectors require `mask_b32`. +- When `input` is a scalar, the scalar value is duplicated to every active lane. +- When `input` is a vector, `position` selects a single source element and that + value is duplicated to every active lane. +- The scalar overload does not accept `position`. +- Inactive lanes follow VPTO predicate semantics and are not guaranteed to carry + meaningful values for subsequent masked-off use. +- Supported scalar types are the 8/16/32-bit integer families (`i*`, `si*`, `ui*`) plus `f16`, `bf16`, and `f32`. +- `position` is only meaningful for vector input. TileLang DSL currently exposes + `PositionMode.LOWEST` and `PositionMode.HIGHEST`, matching VPTO v0.3. + +**Example**: +```python +mask32 = pto.make_mask(pto.f32, pto.PAT.ALL) + +# Duplicate a scalar into all active lanes. +broadcast = pto.vdup(3.14, mask32) + +# Use dtype constructors for floating-point special values. +seed = pto.vdup(pto.f32("-inf"), mask32) +seed_f16 = pto.vdup(pto.f16("0xFC00"), pto.make_mask(pto.f16, pto.PAT.ALL)) + +# Assume `vec` is an existing `f32` vector register value. +vec = pto.vlds(src, 0) + +# Duplicate the lowest source lane to all active lanes. +dup_lowest = pto.vdup(vec, mask32) # position defaults to "LOWEST" + +# Duplicate the highest source lane to all active lanes. +dup_highest = pto.vdup(vec, mask32, pto.PositionMode.HIGHEST) +``` + +**Type Safety Note**: +- For floating-point seeds, prefer `pto.f16(...)` / `pto.bf16(...)` / `pto.f32(...)` constructors. +- Do not pass integer bit-pattern literals directly (for example `0xFF800000`) when a floating vector type is intended. + +### Carry & Select Operations + +Operations with carry propagation and selection. + +**Comparison Mode Enum**: The `CmpMode` enum provides type-safe comparison mode specification for `pto.vcmp` and `pto.vcmps` operations. It includes the following values: `EQ` (equal), `NE` (not equal), `LT` (less than), `LE` (less than or equal), `GT` (greater than), `GE` (greater than or equal). + +Implemented current-package carry/select surface also includes: +- `pto.vselr(vec0, vec1) -> VRegType` +- `pto.vselrv2(vec0, vec1) -> VRegType` +- `pto.vaddcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` +- `pto.vsubcs(vec0, vec1, carry_in, mask) -> (VRegType, MaskType)` + +#### `pto.vcmp(vec0: VRegType, vec1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Element-wise vector comparison with seed mask. Compares two vectors element-wise and generates a predicate mask based on the specified comparison mode. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec0` | `VRegType` | First input vector | +| `vec1` | `VRegType` | Second input vector | +| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | +| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ` (equal), `CmpMode.NE` (not equal), `CmpMode.LT` (less than), `CmpMode.LE` (less than or equal), `CmpMode.GT` (greater than), `CmpMode.GE` (greater than or equal) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Generated predicate mask based on element-wise comparison | + +**Constraints**: +- Only lanes enabled by `seed_mask` participate in the comparison +- The two input vectors must have the same element type and vector length +- The output mask granularity matches the input vector element type + +**Example**: +```python +# Compare two vectors for less-than relation +all_mask = pto.make_mask(pto.f32, PAT.ALL) +lt_mask = pto.vcmp(vec_a, vec_b, all_mask, CmpMode.LT) +``` + +#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Vector-scalar comparison with seed mask. Compares each element of a vector against a scalar value and generates a predicate mask based on the specified comparison mode. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar value to compare against (must match vector element type) | +| `seed_mask` | `MaskType` | Seed mask that determines which lanes participate in the comparison | +| `cmp_mode` | `CmpMode` | Comparison mode enum: `CmpMode.EQ`, `CmpMode.NE`, `CmpMode.LT`, `CmpMode.LE`, `CmpMode.GT`, `CmpMode.GE` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Generated predicate mask based on vector-scalar comparison | + +**Constraints**: +- Only lanes enabled by `seed_mask` participate in the comparison +- The scalar type must match the vector element type +- The output mask granularity matches the input vector element type + +**Example**: +```python +# Check which elements are greater than zero +all_mask = pto.make_mask(pto.f32, PAT.ALL) +positive_mask = pto.vcmps(values, pto.f32(0.0), all_mask, CmpMode.GT) +``` + +#### `pto.vaddc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` + +**Description**: Vector addition with carry output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sum vector | +| `carry_out` | `MaskType` | Output carry mask | + +#### `pto.vsubc(vec1: VRegType, vec2: VRegType, mask: MaskType) -> (VRegType, MaskType)` + +**Description**: Vector subtraction with borrow output. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Difference vector | +| `borrow_out` | `MaskType` | Output borrow mask | + +#### `pto.vsel(true_vec: VRegType, false_vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector select based on mask. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `true_vec` | `VRegType` | Vector selected when mask bit is 1 | +| `false_vec` | `VRegType` | Vector selected when mask bit is 0 | +| `mask` | `MaskType` | Selection mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Selected vector | + +**Example**: +```python +result = pto.vsel(scaled_vec, original_vec, mask32) +``` + +### Reduction Operations + +Reduction operations across vector lanes or channels. + +#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group addition reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced sum across groups | + +#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group maximum reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced maximum across groups | + +#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-group minimum reduction (reduction across VLanes). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced minimum across groups | + +#### `pto.vcpadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Cross-channel addition reduction (reduction across channels). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Reduced sum across channels | + +### Data Rearrangement + +Operations for rearranging data within vectors. + +Predicate rearrangement ops `pto.pdintlv_b8` and `pto.pintlv_b16` are documented in `10-predicate-operations.md` because they operate on predicate masks rather than vector registers. + +Implemented current-package rearrangement surface also includes: +- `pto.vintlvv2(vec0, vec1, part) -> VRegType` +- `pto.vdintlvv2(vec0, vec1, part) -> VRegType` + +#### `pto.vintlv(vec1: VRegType, vec2: VRegType) -> (VRegType, VRegType)` + +**Description**: Interleave two vectors and return the low/high results. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Low interleaved result | +| `high` | `VRegType` | High interleaved result | + +#### `pto.vdintlv(vec0: VRegType, vec1: VRegType) -> (VRegType, VRegType)` + +**Description**: Deinterleave a pair of vectors into low/high results. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec0` | `VRegType` | First input vector | +| `vec1` | `VRegType` | Second input vector | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec1` | `VRegType` | First deinterleaved vector | +| `vec2` | `VRegType` | Second deinterleaved vector | + +#### `pto.vpack(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector packing (combine elements from two vectors). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Packed vector | + +#### `pto.vperm(vec: VRegType, indices: VRegType, mask: MaskType) -> VRegType` + +**Description**: Vector permutation (reorder elements according to index vector). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `indices` | `VRegType` | Permutation indices | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Permuted vector | + +#### `pto.vshift(vec: VRegType, shift_amount: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Generic vector shift (shift all elements by same amount). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `shift_amount` | `ScalarType` | Shift amount (same for all elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Shifted vector | + +#### `pto.vslide(vec: VRegType, window_size: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Vector sliding window (create overlapping windows). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `window_size` | `ScalarType` | Size of sliding window | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sliding window result | + +#### `pto.vsort32(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: 32-element sorting of vector elements. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector (32 elements) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Sorted vector | + +**Constraints**: +- Input vector must have exactly 32 elements + +#### `pto.vmrgsort(vec1: VRegType, vec2: VRegType, mask: MaskType) -> VRegType` + +**Description**: Merge sort of two vectors. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec1` | `VRegType` | First input vector | +| `vec2` | `VRegType` | Second input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Merged and sorted vector | + +#### `pto.vtranspose(dest: ptr, src: ptr, config: pto.i64) -> None` [Advanced Tier] + +**Description**: UB-to-UB transpose operation. This op works on UB memory directly (not `vreg -> vreg`). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src` | `ptr` | Source pointer in UB memory space | +| `config` | `pto.i64` | ISA control/config operand that encodes transpose layout behavior | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `None` | `None` | Side-effect operation that writes transposed data to `dest` | + +**Constraints**: +- `dest` and `src` must be UB pointers +- Correctness depends on the `config` encoding and UB layout contract + +**Example**: +```python +pto.vtranspose(dst_ub_ptr, src_ub_ptr, config_word) +``` + +### Conversion & Special Operations + +Type conversion and specialized operations. + +#### `pto.vtrc(vec: VRegType, mask: MaskType, rnd: pto.VcvtRoundMode | None = None) -> VRegType` + +**Description**: Truncate/round float to integer-valued float (stays in float type). This is the TileLang DSL surface for the VPTO `pto.vtrc` operation. + +**Attribute Enums**: +- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` (note: `vtrc` does not support `O`) + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | +| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `round_mode`. Defaults to `R` if not specified. | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Truncated vector with integer-valued float elements | + +**Constraints**: +- Current TileLang DSL v1 accepts exactly two positional arguments: `pto.vtrc(vec, mask)`. Optional `rnd` attribute is exposed as keyword argument: `rnd=...`. +- The underlying VPTO op syntax is `pto.vtrc %input, %mask, "RND"`. +- Supported rounding modes are `R` (round to nearest), `A` (round away from zero), `F` (floor), `C` (ceil), `Z` (truncate toward zero). +- The enum form is preferred. For compatibility, canonical strings such as `"R"`, `"A"`, `"F"`, `"C"`, `"Z"` are also accepted. +- This op does not change the element type; input and output have the same vector type. +- Only floating-point element types are supported: `f16`, `bf16`, `f32`. + +#### `pto.vcvt(vec: VRegType, to_type: Type, mask: MaskType, rnd: pto.VcvtRoundMode | None = None, sat: pto.VcvtSatMode | None = None, part: pto.VcvtPartMode | None = None) -> VRegType` + +**Description**: Convert vector elements between supported float and integer +families. This is the TileLang DSL surface for the VPTO `pto.vcvt` conversion +family. + +**Attribute Enums**: +- `pto.VcvtRoundMode`: `R`, `A`, `F`, `C`, `Z`, `O` +- `pto.VcvtSatMode`: `SAT`, `NOSAT` +- `pto.VcvtPartMode`: `EVEN`, `ODD`, `P0`, `P1`, `P2`, `P3` + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `to_type` | `Type` | Target scalar dtype symbol for the result vector element type | +| `mask` | `MaskType` | Predicate mask selecting active source lanes. Its granularity must match the source vector family, not the destination family | +| `rnd` | `pto.VcvtRoundMode` \| `None` | Optional rounding-mode attribute lowered to VPTO `rnd` | +| `sat` | `pto.VcvtSatMode` \| `None` | Optional saturation attribute lowered to VPTO `sat` | +| `part` | `pto.VcvtPartMode` \| `None` | Optional width-changing lane-placement selector lowered to VPTO `part` | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Converted vector with the vreg shape implied by `to_type` | + +**Constraints**: +- Current TileLang DSL v1 accepts exactly three positional arguments: + `pto.vcvt(vec, to_type, mask)`. Optional attributes are exposed as keyword + arguments: `rnd=...`, `sat=...`, `part=...`. +- The underlying VPTO op family is the fuller + `pto.vcvt %input, %mask {rnd, sat, part}` surface, and the DSL keywords map + directly to those VPTO attributes. +- `mask` always follows the source vector family: + `f32`/`i32`/`si32`/`ui32` use `mask_b32`; + `f16`/`bf16`/`i16`/`si16`/`ui16` use `mask_b16`; + `i8`/`si8`/`ui8` use `mask_b8`. +- The enum form is preferred. For compatibility, canonical strings such as + `"R"`, `"SAT"`, and `"EVEN"` are also accepted. +- VPTO `part` supports two families: `Part` (`EVEN`/`ODD`) for ordinary + width-changing conversions (e.g. `32 -> 16`, `16 -> 32`), and `Part_T` + (`P0`–`P3`) for 4-way packed placement (e.g. `32 -> 8`, fp8/fp4 flows). + + | Mode | VPTO spelling | Family | Description | TileLang DSL v1 status | + |------|---------------|--------|-------------|------------------------| + | `EVEN` | `PART_EVEN` | `Part` | Output to even-indexed lanes | Exposed as `pto.VcvtPartMode.EVEN` | + | `ODD` | `PART_ODD` | `Part` | Output to odd-indexed lanes | Exposed as `pto.VcvtPartMode.ODD` | + | `P0` | `PART_P0` | `Part_T` | Output to sub-part 0 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P0` | + | `P1` | `PART_P1` | `Part_T` | Output to sub-part 1 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P1` | + | `P2` | `PART_P2` | `Part_T` | Output to sub-part 2 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P2` | + | `P3` | `PART_P3` | `Part_T` | Output to sub-part 3 in 4-way packed placement | Exposed as `pto.VcvtPartMode.P3` | +- Only backend-supported source/destination type pairs are legal. For the full + A5 `vcvt` type matrix, width-changing packing rules, and attribute-sensitive + forms, refer to + [`../vpto_spec/vpto-spec-current.md`](../vpto_spec/vpto-spec-current.md). +- Attribute requirements are type-pair specific. The DSL enforces the same + per-form contract as VPTO, so some pairs require attributes while others + reject them. +- Examples: + `f32 -> si32` requires `rnd` and `sat`; + `f16 -> si32` requires `rnd` and `part`, and rejects `sat`; + `bf16 -> f16` requires `rnd` and `sat`; + `f16 -> f32` requires `part`; + `f32 -> f16` requires `rnd`, `sat`, and `part`; + `si32 -> f32` requires `rnd`. +- VPTO does not define a `mask_b64` form. Conversions that produce `si64` + results still use the typed mask granularity of the source vector family. +- Width-changing conversions continue to follow VPTO packing semantics even on + the simplified DSL surface. For example, `f16 -> f32` uses an `f16`-family + `mask_b16`, because the mask is attached to the source vector family. +- A common `tcvt`-style pair is: + `f16 -> f32`: `pto.vlds(..., dist=pto.VLoadDist.UNPK_B16)` + `pto.vcvt(..., part=pto.VcvtPartMode.EVEN)`; + `f32 -> f16`: `pto.vcvt(..., rnd=..., sat=..., part=pto.VcvtPartMode.EVEN)` + `pto.vsts(..., dist=pto.VStoreDist.PK_B32)`. +- In those `tcvt` flows, the `vcvt` mask still follows the source vector family: + `f16 -> f32` uses `mask_b16`, while `f32 -> f16` uses `mask_b32`. +- The follow-on `vsts` mask is checked against the store `dist`, not the narrowed element dtype alone. For example, `pto.vsts(vec_f16, ..., mask32, dist=pto.VStoreDist.PK_B32)` is valid and expected for `f32 -> f16` rowwise `tcvt`. + +**Example**: +```python +mask16 = pto.make_mask(pto.f16, PAT.ALL) +vec_f16 = pto.vlds(src, 0) +vec_f32 = pto.vcvt(vec_f16, pto.f32, mask16) + +mask32 = pto.make_mask(pto.f32, PAT.ALL) +vec_i32 = pto.vcvt(vec_f32, pto.si32, mask32) + +vec_i32_wide = pto.vcvt( + vec_f16, + pto.si32, + mask16, + rnd=pto.VcvtRoundMode.R, + part=pto.VcvtPartMode.EVEN, +) + +vec_f16_from_bf16 = pto.vcvt( + vec_bf16, + pto.f16, + mask16, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, +) + +vec_f16_narrow = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD, +) + +# Rowwise tcvt-style widening from f16 to f32 +vec_f16_unpacked = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) +vec_f32_from_f16 = pto.vcvt( + vec_f16_unpacked, + pto.f32, + mask16, + part=pto.VcvtPartMode.EVEN, +) + +# Rowwise tcvt-style narrowing from f32 to f16 +vec_f16_packed = pto.vcvt( + vec_f32, + pto.f16, + mask32, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, +) +pto.vsts(vec_f16_packed, dst, 0, mask32, dist=pto.VStoreDist.PK_B32) +``` + +#### `pto.vbitsort(dest: ptr, src: ptr, indices: ptr, repeat_times: index) -> None` [Advanced Tier] + +**Description**: Sort 32 region proposals by score and materialize sorted proposal +records into UB memory. This is a UB helper and not a `vreg -> vreg` operation. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src` | `ptr` | Source score pointer in UB memory space | +| `indices` | `ptr` | Source index pointer in UB memory space | +| `repeat_times` | `index` | Repeat count; each repeat processes the next adjacent group of 32 scores and 32 indices | + +**Returns**: +None. The op writes UB memory directly. + +**Constraints**: +- `dest`, `src`, and `indices` must be UB-backed pointers +- Scores are sorted in descending order +- Equal-score ties preserve the earlier input proposal first +- Output records occupy 8 bytes each: upper 4 bytes for the index and lower 4 bytes for the score + +#### `pto.vmrgsort4(dest: ptr, src0: ptr, src1: ptr, src2: ptr, src3: ptr, count: pto.i64, config: pto.i64) -> None` [Advanced Tier] + +**Description**: Merge-sort 4 pre-sorted UB inputs. This op writes UB memory +directly and does not return a vector SSA value. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `dest` | `ptr` | Destination pointer in UB memory space | +| `src0` | `ptr` | First pre-sorted input pointer in UB memory space | +| `src1` | `ptr` | Second pre-sorted input pointer in UB memory space | +| `src2` | `ptr` | Third pre-sorted input pointer in UB memory space | +| `src3` | `ptr` | Fourth pre-sorted input pointer in UB memory space | +| `count` | `pto.i64` | Number of valid input elements participating in the merge | +| `config` | `pto.i64` | Operation control word encoding sort behavior | + +**Returns**: +None. The op writes UB memory directly. + +**Constraints**: +- `dest` and `src0` through `src3` must be UB-backed pointers +- Inputs must already be sorted according to the order encoded by `config` + +#### `pto.get_vms4_sr() -> (pto.i16, pto.i16, pto.i16, pto.i16)` [Advanced Tier] + +**Description**: Read `VMS4_SR` after exhausted `pto.vmrgsort4` and return the +finished element counts for source lists 0 through 3. + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `list0` | `pto.i16` | Finished count from `VMS4_SR[15:0]` | +| `list1` | `pto.i16` | Finished count from `VMS4_SR[31:16]` | +| `list2` | `pto.i16` | Finished count from `VMS4_SR[47:32]` | +| `list3` | `pto.i16` | Finished count from `VMS4_SR[63:48]` | + +**Example**: +```python +list0, list1, list2, list3 = pto.get_vms4_sr() +``` + +**Order Mode Enum**: The `OrderMode` enum provides type-safe order selection for `pto.vci` operations. `ASC` and `DESC` are supported. + +#### `pto.vci(index: ScalarType, order: OrderMode = OrderMode.ASC) -> VRegType` + +**Description**: Generate a lane-index vector from a scalar seed/index value (DSA/SFU operation). + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `index` | `ScalarType` | Scalar seed or base index value | +| `order` | `OrderMode` | Order mode enum (default: `OrderMode.ASC`; supported values: `ASC`, `DESC`) | + +**Returns**: +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Generated index vector | + +**Constraints**: +- This is an index-generation family, not a numeric conversion +- The `order` parameter and result element type together determine how indices are generated +- Supported order modes are ascending (`OrderMode.ASC`) and descending (`OrderMode.DESC`) + +**Example**: +```python +# Generate ascending indices starting from 0 +indices = pto.vci(pto.i32(0), OrderMode.ASC) + +# Generate descending indices starting from the seed value +indices_desc = pto.vci(pto.i32(63), OrderMode.DESC) + +# Keyword form for the optional order argument is also supported +indices_kw = pto.vci(pto.i32(0), order=OrderMode.ASC) +``` diff --git a/tilelang-dsl/docs/user_guide/12-cube-operations.md b/tilelang-dsl/docs/user_guide/12-cube-operations.md new file mode 100644 index 000000000..275039838 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/12-cube-operations.md @@ -0,0 +1,454 @@ +# Cube Matrix Multiply Operations + +Cube operations target the AIC (Cube) hardware unit for matrix multiplication and +staged data movement. They are only available inside `@pto.ckernel` function +bodies. All Cube operands use `pto.ptr` raw pointers — no +`vecscope` execution scope is used. + +## Address Spaces + +Cube operations use the following address spaces via the `MemorySpace` enum. +The IR type column shows the canonical `!pto.ptr` spelling. Older +`mat`/`left`/`right`/`acc`/`bias`/`scaling` pointer spellings are accepted as +parser aliases and print back as `l1`/`l0a`/`l0b`/`l0c`/`bt`/`fb`. + +| Address Space | Enum Value | Canonical IR Type | Legacy ptr alias | Description | +|--------------|------------|-------------------|------------------|-------------| +| `GM` | `MemorySpace.GM` | `!pto.ptr` | - | Global memory | +| `MAT` | `MemorySpace.MAT` | `!pto.ptr` | `mat` | L1 buffer (cbuf) | +| `LEFT` | `MemorySpace.LEFT` | `!pto.ptr` | `left` | L0A left-operand buffer | +| `RIGHT` | `MemorySpace.RIGHT` | `!pto.ptr` | `right` | L0B right-operand buffer | +| `ACC` | `MemorySpace.ACC` | `!pto.ptr` | `acc` | L0C accumulator buffer | +| `BIAS` | `MemorySpace.BIAS` | `!pto.ptr` | `bias` | Bias table | +| `UB` | `MemorySpace.UB` | `!pto.ptr` | `vec` | Unified buffer (Vector side) | + +## Shared Infrastructure + +Cube operations reuse general tile and pointer facilities documented elsewhere: + +| Facility | Description | Reference | +|----------|-------------|-----------| +| `pto.Tile` | Allocate a tile buffer with address space | [Type System — Tile Type Definition](05-type-system.md#tile-type-definition) | +| `.as_ptr()` | Get raw pointer from Tile / TensorView | [Frontend Operations — Pointer Construction](07-frontend-operations.md#pointer-construction-advanced-tier) | +| `pto.addptr` | Element-offset a pointer | [Frontend Operations — Pointer Construction](07-frontend-operations.md#pointer-construction-advanced-tier) | + +--- + +## Matrix Compute Operations + +### `pto.mad` — zero-init matmul + +#### `pto.mad(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Zero-init cube matrix multiply. Clears the accumulator and computes +`dst = lhs * rhs`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `lhs` | `pto.ptr` | L0A left operand | +| `rhs` | `pto.ptr` | L0B right operand | +| `dst` | `pto.ptr` | L0C accumulator destination | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `k` | `int` | K dimension size | +| `unit_flag_ctrl` | `int` | Accumulator control flag (0 / 2 / 3) | +| `disable_gemv` | `bool` | GEMV disable control | + +**Constraints**: +- `lhs` must be in `l0a` address space. +- `rhs` must be in `l0b` address space. +- `dst` must be in `l0c` address space. + +**Example**: +```python +pto.mad(l0a, l0b, l0c, 16, 16, 64) +``` + +--- + +### `pto.mad_acc` — accumulating matmul + +#### `pto.mad_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Accumulating cube matrix multiply. Computes `dst += lhs * rhs`. + +**Parameters**: Same as `pto.mad`. + +**Example**: +```python +pto.mad_acc(l0a, l0b, l0c, 16, 16, 64, unit_flag_ctrl=2) +``` + +--- + +### `pto.mad_bias` — bias-init matmul + +#### `pto.mad_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Bias-init cube matrix multiply. Computes `dst = lhs * rhs + bias`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `bias` | `pto.ptr` | Bias table pointer | + +Other parameters are the same as `pto.mad`. + +**Constraints**: +- `bias` must be in `bt` address space. + +**Example**: +```python +pto.mad_bias(l0a, l0b, l0c, bt, 16, 16, 64) +``` + +--- + +### `pto.mad_mx` — zero-init MX matmul + +#### `pto.mad_mx(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Zero-init MX (micro-scaling) cube matrix multiply. Same semantics +as `pto.mad`, for MX-capable dtypes such as `f8E4M3FN`. + +**Parameters**: Same as `pto.mad`. + +**Example**: +```python +pto.mad_mx(l0a, l0b, l0c, 16, 16, 64) +``` + +--- + +### `pto.mad_mx_acc` — accumulating MX matmul + +#### `pto.mad_mx_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: Accumulating MX cube matrix multiply. Computes `dst += lhs * rhs`. + +**Parameters**: Same as `pto.mad`. + +--- + +### `pto.mad_mx_bias` — MX bias-init matmul + +#### `pto.mad_mx_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int, *, unit_flag_ctrl: int = 0, disable_gemv: bool = False) -> None` + +**Description**: MX bias-init cube matrix multiply. Computes `dst = lhs * rhs + bias`. + +**Parameters**: Same as `pto.mad_bias`. + +--- + +## Data Movement Operations + +### `pto.cube_load` — GM → L1 (cbuf) + +#### `pto.cube_load(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured GM-to-L1 (`cbuf` / `l1`) data movement wrapper. Lowers +to loop/stride setup plus `pto.copy_gm_to_cbuf`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | Global memory source pointer | +| `dst` | `pto.ptr` | L1 (cbuf) destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop params, each `(count_i, src_stride_i, dst_stride_i)` | + +**Constraints**: +- `src` must be in `gm` address space. +- `dst` must be in `l1` address space. + +**Example**: +```python +pto.cube_load(a_ptr, l1_a.as_ptr(), 16, nburst=(1, 0, 0)) +``` + +--- + +### `pto.cube_store` — L1 (cbuf) → UB + +#### `pto.cube_store(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured L1 (`cbuf`) to UB data movement wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | UB destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop params | + +**Example**: +```python +pto.cube_store(l1_src.as_ptr(), ub_dst.as_ptr(), 16, nburst=(1, 0, 0)) +``` + +--- + +### `pto.cube_load_frac` — fractal load + +#### `pto.cube_load_frac(src: PtrType, dst: PtrType, mode: pto.FractalMode, *, shape: tuple[int, int], src_layout: tuple[int, int], dst_group: tuple[int, int, int, int], ctrl: tuple[int, bool]) -> None` + +**Description**: Structured fractal-load wrapper for `nd2nz` and `dn2nz` modes. +Lowers to `set_mte2_nz_para` plus `copy_gm_to_cbuf_multi_nd2nz` or +`copy_gm_to_cbuf_multi_dn2nz`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | Global memory source pointer | +| `dst` | `pto.ptr` | L1 destination pointer | +| `mode` | `pto.FractalMode` | `pto.FractalMode.ND2NZ` or `pto.FractalMode.DN2NZ` | +| `shape` | `tuple[int, int]` | `(n_value, d_value)` | +| `src_layout` | `tuple[int, int]` | `(inner_stride, outer_stride)` | +| `dst_group` | `tuple[int, int, int, int]` | `(group_count, loop2_stride, loop3_stride, loop4_stride)` | +| `ctrl` | `tuple[int, bool]` | `(l2_cache_ctrl, smallc0_en)` | + +**Constraints**: +- `src` must be in `gm` address space. +- `dst` must be in `l1` address space. + +**Example**: +```python +pto.cube_load_frac(a_ptr, l1_a.as_ptr(), pto.FractalMode.ND2NZ, + shape=(16, 16), src_layout=(4, 8), + dst_group=(1, 0, 0, 0), ctrl=(0, False)) +``` + +--- + +### `pto.bias_load` — L1 (cbuf) → bias table + +#### `pto.bias_load(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0)) -> None` + +**Description**: Structured L1 (`cbuf`) to bias-table load wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | Bias table destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_gap, dst_gap)` | + +**Constraints**: +- Supported source/destination type pairs: `f32→f32`, `i32→i32`, `f16→f32`, `bf16→f32`. + +**Example**: +```python +pto.bias_load(l1_bias.as_ptr(), bt.as_ptr(), 16, nburst=(1, 0, 0)) +``` + +--- + +### `pto.left_load` — L1 (cbuf) → L0A + +#### `pto.left_load(src: PtrType, dst: PtrType, m: int, k: int) -> None` + +**Description**: Structured L1-to-L0A wrapper. Lowers to `pto.load_cbuf_to_ca`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | L0A destination pointer | +| `m` | `int` | M dimension size | +| `k` | `int` | K dimension size | + +**Constraints**: +- `src` must be in `l1` address space. +- `dst` must be in `l0a` address space. + +**Example**: +```python +pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), 16, 64) +``` + +--- + +### `pto.right_load` — L1 (cbuf) → L0B + +#### `pto.right_load(src: PtrType, dst: PtrType, k: int, n: int) -> None` + +**Description**: Structured L1-to-L0B wrapper. Lowers to `pto.load_cbuf_to_cb`. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L1 source pointer | +| `dst` | `pto.ptr` | L0B destination pointer | +| `k` | `int` | K dimension size | +| `n` | `int` | N dimension size | + +**Constraints**: +- `src` must be in `l1` address space. +- `dst` must be in `l0b` address space. + +**Example**: +```python +pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), 64, 16) +``` + +--- + +### `pto.left_load_mx` — MX L1 → L0A + +#### `pto.left_load_mx(src: PtrType, dst: PtrType, m: int, k: int) -> None` + +**Description**: MX-mode L1-to-L0A wrapper. Lowers to `pto.load_cbuf_to_ca_mx`. + +**Parameters**: Same as `pto.left_load`. + +--- + +### `pto.right_load_mx` — MX L1 → L0B + +#### `pto.right_load_mx(src: PtrType, dst: PtrType, k: int, n: int) -> None` + +**Description**: MX-mode L1-to-L0B wrapper. Lowers to `pto.load_cbuf_to_cb_mx`. + +**Parameters**: Same as `pto.right_load`. + +--- + +## Result Writeback Operations + +### `pto.acc_store` — L0C (acc) → L1 (cbuf) + +#### `pto.acc_store(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (`l0c`) to L1 (`cbuf`) writeback wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L0C source pointer | +| `dst` | `pto.ptr` | L1 (cbuf) destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `mode` | `pto.FractalMode` | Layout mode: `NZ2ND` / `NZ2DN` / `NZ2NZ` | + +Mode-dependent parameters: + +| Mode | Required | Not Accepted | +|------|----------|--------------| +| `pto.FractalMode.NZ2ND` | (none) | — | +| `pto.FractalMode.NZ2DN` | `loop0_src_stride` | — | +| `pto.FractalMode.NZ2NZ` | `split` | `loop3` | + +Optional for `pto.FractalMode.NZ2ND` and `pto.FractalMode.NZ2DN`: +`loop3=(count, src_stride3, dst_stride3)`. + +**Example**: +```python +pto.acc_store(l0c.as_ptr(), l1_out.as_ptr(), + 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) +``` + +--- + +### `pto.acc_store_gm` — L0C (acc) → GM + +#### `pto.acc_store_gm(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, sid: int = 0, l2_cache_ctrl: int = 0, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (`l0c`) to GM writeback wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L0C source pointer | +| `dst` | `pto.ptr` | GM destination pointer | +| `sid` | `int` | Stream ID | +| `l2_cache_ctrl` | `int` | L2 cache control | + +Other parameters are the same as `pto.acc_store`. + +**Example**: +```python +pto.acc_store_gm(l0c.as_ptr(), c_ptr, 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) +``` + +--- + +### `pto.acc_store_ub` — L0C (acc) → UB + +#### `pto.acc_store_ub(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, dual_dst_mode: int = 0, sub_blockid: int = 0, mode: pto.FractalMode = pto.FractalMode.NZ2ND, loop0_src_stride: int | None = None, channel_split_en: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (`l0c`) to UB writeback wrapper. + +**Parameters**: +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `pto.ptr` | L0C source pointer | +| `dst` | `pto.ptr` | UB destination pointer | +| `dual_dst_mode` | `int` | Dual destination mode | +| `sub_blockid` | `int` | Sub-block ID | +| `channel_split_en` | `int` or `None` | Channel split enable (required for `mode=pto.FractalMode.NZ2NZ`) | + +Other parameters are the same as `pto.acc_store`. + +**Example**: +```python +pto.acc_store_ub(l0c.as_ptr(), ub_out.as_ptr(), + 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) +``` + +--- + +## Quick Reference + +### By Data Flow + +| Data Flow | Operation | Src Space | Dst Space | +|-----------|-----------|-----------|-----------| +| GM → L1 | `pto.cube_load` | gm | l1 | +| GM → L1 (fractal) | `pto.cube_load_frac` | gm | l1 | +| L1 → UB | `pto.cube_store` | l1 | ub | +| L1 → L0A | `pto.left_load` | l1 | l0a | +| L1 → L0B | `pto.right_load` | l1 | l0b | +| L1 → L0A (MX) | `pto.left_load_mx` | l1 | l0a | +| L1 → L0B (MX) | `pto.right_load_mx` | l1 | l0b | +| L1 → Bias | `pto.bias_load` | l1 | bt | +| L0A×L0B → L0C | `pto.mad` | l0a, l0b | l0c | +| L0A×L0B → L0C (acc) | `pto.mad_acc` | l0a, l0b | l0c | +| L0A×L0B+Bias → L0C | `pto.mad_bias` | l0a, l0b, bt | l0c | +| L0C → L1 | `pto.acc_store` | l0c | l1 | +| L0C → GM | `pto.acc_store_gm` | l0c | gm | +| L0C → UB | `pto.acc_store_ub` | l0c | ub | + +### MX Variants + +| Base Op | MX Variant | Description | +|---------|------------|-------------| +| `pto.mad` | `pto.mad_mx` | Zero-init MX matmul | +| `pto.mad_acc` | `pto.mad_mx_acc` | Accumulating MX matmul | +| `pto.mad_bias` | `pto.mad_mx_bias` | Bias-init MX matmul | + +--- + +## Template Slot Support + +Cube operations support `pto.tpl()` template-slot dispatch, consistent with the +Vector DSL mechanism. See [Template Kernels](04-template-kernels.md) for general +`pto.tpl()` usage. + +**Constraints**: Variants within the same slot must have identical parameter +signatures. For example, `mad` and `mad_acc` can share a slot, but `mad_bias` +(which adds a `bias` parameter) requires a separate slot. + +--- + +## See Also + +- [Kernel Declaration](03-kernel-declaration.md) — `@pto.ckernel` decorator specification +- [Examples](13-examples.md) — full Cube kernel code examples +- [Design doc](../../../docs/designs/tilelang-cube-dsl-design.md) — Cube DSL design details diff --git a/tilelang-dsl/docs/user_guide/13-examples.md b/tilelang-dsl/docs/user_guide/13-examples.md new file mode 100644 index 000000000..16105b853 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/13-examples.md @@ -0,0 +1,417 @@ +## Examples + +### Template-based Kernel Examples + +#### Unified Arithmetic Operations + +A single kernel implementing multiple arithmetic operations using templates: + +```python +T = pto.TypeVar('T') + +@pto.vkernel( + target="a5", + ops=["tadd", "tsub", "tmul", "tdiv"], + dtypes=[(T, T, T)], + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + } +) +def elementwise_arithmetic(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + """Single implementation for four arithmetic operations.""" + dtype = dst.element_type + rows, cols = dst.valid_shape + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, pto.elements_per_vreg(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) +``` + +#### Multiple Templates with Postprocess + +Kernel using separate templates for arithmetic and postprocess operations: + +```python +@pto.vkernel( + target="a5", + ops=["add_relu", "sub_relu", "add_abs", "sub_abs"], + dtypes=[(T, T, T)], + templates={ + "arithmetic": { + "add_relu": "vadd", + "sub_relu": "vsub", + "add_abs": "vadd", + "sub_abs": "vsub", + }, + "postprocess": { + "add_relu": "vrelu", + "sub_relu": "vrelu", + "add_abs": "vabs", + "sub_abs": "vabs", + } + } +) +def elementwise_with_postprocess(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + + for row in range(0, rows, 1): + remained = cols + for col in range(0, cols, pto.elements_per_vreg(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + + # Use arithmetic template + arith_result = pto.tpl("arithmetic", lhs, rhs, mask) + + # Apply postprocess template + activated = pto.tpl("postprocess", arith_result, mask) + + pto.vsts(activated, dst[row, col:], mask) +``` + +#### Compile-time Substitution + +Template substitution happens before semantic analysis and lowering: + +```python +selected = pto.select_kernel("a5", "tadd", (ptype, ptype, ptype)) +# frontend resolves: +# pto.tpl("core", lhs, rhs, mask) +# into: +# pto.vadd(lhs, rhs, mask) +``` + +#### Benefits of Template-based Authoring + +1. **Code Reuse**: Single implementation serves multiple operations +2. **Maintenance**: Bug fixes and optimizations apply to all related operations +3. **Consistency**: Ensures uniform behavior across operation families +4. **Reduced Boilerplate**: Eliminates duplicate control flow and data movement code +5. **Type Safety**: Type variables ensure consistent operand types + +### Simple Vector Copy + +```python +@pto.vkernel(...) +def vector_copy(src: pto.Tile, dst: pto.Tile): + all_mask: pto.mask_b32 = pto.make_mask(pto.f32, PAT.ALL) + for offset in range(0, 256, 64): + vec = pto.vlds(src, offset) + pto.vsts(vec, dst, offset, all_mask) +``` + +### Conditional Computation + +```python +@pto.vkernel(...) +def conditional_scale(src: pto.ptr(pto.f32, MemorySpace.GM), + dst: pto.ptr(pto.f32, MemorySpace.GM), + threshold: pto.f32): + # ... setup ... + + with pto.strict_vecscope(ub_in, ub_out, threshold) as (vin, vout, thresh): + for i in range(0, 1024, 64): + vec = pto.vlds(vin, i) + + # Compare with threshold + mask = pto.pge_b32(vec, thresh) + + # Scale values above threshold + scaled = pto.vmuls(vec, pto.f32(2.0), mask) + + # Keep original values below threshold + result = pto.vsel(scaled, vec, mask) + + pto.vsts(result, vout, i, all_mask) +``` + +### Loop with Carry + +```python +@pto.vkernel(...) +def prefix_sum(src: pto.ptr(pto.i32, MemorySpace.UB), + dst: pto.ptr(pto.i32, MemorySpace.UB)): + all_mask = pto.make_mask(pto.i32, PAT.ALL) + carry = all_mask + + for i in range(0, 256, 64): + vec = pto.vlds(src, i) + result, carry = pto.vaddcs(vec, vec, carry, all_mask) + pto.vsts(result, dst, i, all_mask) +``` + +--- + +## Cube Kernel Examples + +Cube kernels target the AIC (Cube) hardware unit for matrix multiplication. GM data is expressed through `PartitionTensorView`, while hardware buffers in specific address spaces are constructed via `pto.Tile`. + +### Basic GEMM + +A full-pipeline matrix multiplication C = A × B: + +```python +from tilelang_dsl import ckernel, Tile, MemorySpace + +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm", +) +def gemm(a_tv: pto.PartitionTensorView, # [M, K] in GM + b_tv: pto.PartitionTensorView, # [K, N] in GM + c_tv: pto.PartitionTensorView, # [M, N] in GM, output + M: int, K: int, N: int): + # Get GM pointers from PartitionTensorViews + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + # Allocate L1 (MAT) tile buffers + l1_a_tile = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b_tile = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + + # Allocate L0 tile buffers + l0a_tile = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b_tile = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c_tile = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # GM → L1 + pto.cube_load(a_ptr, l1_a_tile.as_ptr(), K, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b_tile.as_ptr(), N, nburst=(1, 0, 0)) + + # L1 → L0 + pto.left_load(l1_a_tile.as_ptr(), l0a_tile.as_ptr(), M, K) + pto.right_load(l1_b_tile.as_ptr(), l0b_tile.as_ptr(), K, N) + + # Compute: C = A × B + pto.mad(l0a_tile.as_ptr(), l0b_tile.as_ptr(), l0c_tile.as_ptr(), M, N, K) + + # L0C → GM writeback + pto.acc_store_gm(l0c_tile.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### Split-K GEMM + +Matrix multiplication with K-dimension splitting for large K values: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_splitk", +) +def gemm_splitk(a_tv: pto.PartitionTensorView, # [M, K] + b_tv: pto.PartitionTensorView, # [K, N] + c_tv: pto.PartitionTensorView, # [M, N] + M: int, K: int, N: int, BASEK: int): + iters = K // BASEK + + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + # Allocate buffers sized for one split-K step + l1_a = pto.Tile([M, BASEK], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([BASEK, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, BASEK], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([BASEK, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + for k_step in range(iters): + k_off = k_step * BASEK + + # Offset GM pointers for this K-slice + a_k = pto.addptr(a_ptr, k_off) + b_k = pto.addptr(b_ptr, k_off) + + # GM → L1 → L0 + pto.cube_load(a_k, l1_a.as_ptr(), BASEK, nburst=(1, 0, 0)) + pto.cube_load(b_k, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, BASEK) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), BASEK, N) + + # First step: zero-init; subsequent steps: accumulate + if k_step == 0: + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) + else: + pto.mad_acc(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, BASEK) + + # L0C → GM + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### GEMM with Bias + +Matrix multiplication with bias addition C = A × B + bias: + +```python +@pto.ckernel( + target="a5", + op="pto.mad_bias", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_bias", +) +def gemm_bias(a_tv: pto.PartitionTensorView, + b_tv: pto.PartitionTensorView, + c_tv: pto.PartitionTensorView, + bias_tv: pto.PartitionTensorView, + M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + bias_ptr = bias_tv.as_ptr() + + # L1 buffers + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l1_bias = pto.Tile([1, N], pto.f32, MemorySpace.MAT) + + # L0 buffers + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # Bias table + bt = pto.Tile([1, N], pto.f32, MemorySpace.BIAS) + + # Data movement + pto.cube_load(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.cube_load(bias_ptr, l1_bias.as_ptr(), N, nburst=(1, 0, 0)) + pto.bias_load(l1_bias.as_ptr(), bt.as_ptr(), N, nburst=(1, 0, 0)) + + # L1 → L0 + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) + + # Compute: C = A × B + bias + pto.mad_bias(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), bt.as_ptr(), M, N, K) + + # Writeback + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### Fractal Load (nd2nz) Example + +Using fractal load for ND-layout to NZ-fractal data loading: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_frac", +) +def gemm_frac(a_tv: pto.PartitionTensorView, + b_tv: pto.PartitionTensorView, + c_tv: pto.PartitionTensorView, + M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + # Fractal load: ND → NZ + pto.cube_load_frac(a_ptr, l1_a.as_ptr(), "nd2nz", + shape=(M, K), + src_layout=(K,), + dst_group=(1, 0, 0, 0), + ctrl=(0, False)) + pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) + pto.mad(l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +### Pure-Compute Kernel (Pre-Allocated Tiles) + +When tiles are pre-allocated externally, the kernel only performs computation: + +```python +@pto.ckernel( + target="a5", + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="matmul_compute", +) +def matmul_compute(a_left: pto.Tile, # Pre-allocated LEFT tile (L0A) + b_right: pto.Tile, # Pre-allocated RIGHT tile (L0B) + c_acc: pto.Tile, # Pre-allocated ACC tile (L0C) + M: int, K: int, N: int): + pto.mad(a_left.as_ptr(), b_right.as_ptr(), c_acc.as_ptr(), M, N, K) +``` + +### Template-based Multi-Op Cube Kernel + +Reusing a single template body for multiple Cube matmul variants: + +```python +@pto.ckernel( + target="a5", + ops=["mad", "mad_acc"], + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="gemm_template", + templates={ + "compute": {"mad": "mad", "mad_acc": "mad_acc"}, + }, +) +def gemm_template(a_tv: pto.PartitionTensorView, + b_tv: pto.PartitionTensorView, + c_tv: pto.PartitionTensorView, + M: int, K: int, N: int): + a_ptr = a_tv.as_ptr() + b_ptr = b_tv.as_ptr() + c_ptr = c_tv.as_ptr() + + l1_a = pto.Tile([M, K], pto.f16, MemorySpace.MAT) + l1_b = pto.Tile([K, N], pto.f16, MemorySpace.MAT) + l0a = pto.Tile([M, K], pto.f16, MemorySpace.LEFT) + l0b = pto.Tile([K, N], pto.f16, MemorySpace.RIGHT) + l0c = pto.Tile([M, N], pto.f32, MemorySpace.ACC) + + pto.cube_load(a_ptr, l1_a.as_ptr(), K, nburst=(1, 0, 0)) + pto.cube_load(b_ptr, l1_b.as_ptr(), N, nburst=(1, 0, 0)) + pto.left_load(l1_a.as_ptr(), l0a.as_ptr(), M, K) + pto.right_load(l1_b.as_ptr(), l0b.as_ptr(), K, N) + + # Template slot: resolved at specialization time + pto.tpl("compute", l0a.as_ptr(), l0b.as_ptr(), l0c.as_ptr(), M, N, K) + + pto.acc_store_gm(l0c.as_ptr(), c_ptr, M, N, + src_stride=N, dst_stride=N, mode="nz2nd") +``` + +Usage: + +```python +k_mad = pto.select_kernel("a5", "gemm_template", selected_op="mad") +k_acc = pto.select_kernel("a5", "gemm_template", selected_op="mad_acc") +``` diff --git a/tilelang-dsl/docs/user_guide/14-common-errors.md b/tilelang-dsl/docs/user_guide/14-common-errors.md new file mode 100644 index 000000000..46abe09b9 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/14-common-errors.md @@ -0,0 +1,51 @@ +## Common Errors + +### Typed Mask Mismatch + +``` +Error: f32 vector operation cannot consume mask_b16 +``` + +**Solution:** Ensure mask granularity matches vector element size: +- `f32` vectors use `mask_b32` +- `f16` vectors use `mask_b16` +- `i8` vectors use `mask_b8` + +### Strict Scope Implicit Capture + +``` +Error: strict_vecscope body cannot capture outer value 'ub_in' implicitly +``` + +**Solution:** Pass all required values in the capture list: + +```python +# Wrong: +with pto.strict_vecscope() as (): + vec = pto.vlds(ub_in, offset) # ub_in from outer scope + +# Correct: +with pto.strict_vecscope(ub_in) as (ub): + vec = pto.vlds(ub, offset) +``` + +### Untyped Loop Carried State + +``` +Error: loop-carried value must have explicit machine type +``` + +**Solution:** Add type annotations to loop-carried variables: + +```python +# Wrong: +remaining = 1024 # Plain Python int +for i in range(0, N, step): + mask, remaining = pto.make_mask(pto.f32, remaining) + +# Correct: +remaining: pto.i32 = 1024 +# or +remaining = pto.i32(1024) +``` + diff --git a/tilelang-dsl/docs/user_guide/15-compatibility-notes.md b/tilelang-dsl/docs/user_guide/15-compatibility-notes.md new file mode 100644 index 000000000..defcf704c --- /dev/null +++ b/tilelang-dsl/docs/user_guide/15-compatibility-notes.md @@ -0,0 +1,9 @@ +## Compatibility Notes + +The current experimental implementation in `python/pto/dialects/pto.py` differs from this specification in several ways: + +1. **Mask types**: The experimental version uses untyped `mask` instead of `mask_b8`/`mask_b16`/`mask_b32` +2. **Barrier operation**: Uses `pto.barrier()` instead of `pto.pipe_barrier()` +3. **Operation coverage**: Implements only a subset of operations + +When implementing new code, follow this specification. The experimental implementation will be updated to match over time. diff --git a/tilelang-dsl/docs/user_guide/16-next-steps.md b/tilelang-dsl/docs/user_guide/16-next-steps.md new file mode 100644 index 000000000..2fe63b9a4 --- /dev/null +++ b/tilelang-dsl/docs/user_guide/16-next-steps.md @@ -0,0 +1,7 @@ +## Next Steps + +- Explore the ISA documentation in `docs/isa/` for detailed operation semantics +- Check `test/samples/` for example kernels +- Refer to `docs/vpto-spec.md` for the underlying VPTO instruction specification + +For compiler developers, see `docs/PTO_IR_manual.md` for MLIR-level details. diff --git a/tilelang-dsl/docs/v1-lowering.md b/tilelang-dsl/docs/v1-lowering.md new file mode 100644 index 000000000..c8ae1d82d --- /dev/null +++ b/tilelang-dsl/docs/v1-lowering.md @@ -0,0 +1,146 @@ +# TileLang DSL v1 Authoring Lowering + +## Scope + +This document records the implemented TileLang DSL v1 lowering contract for +`add-tilelang-dsl-authoring-vpto-lowering`. + +It covers: +- the current v1 lowering support matrix +- dynamic-bound and shape-profile behavior +- examples that match the implemented surface +- minimal validation commands, including the repo `ptoas` legality path + +It does not define: +- matcher-driven dispatch +- raw pointer authoring surface +- advanced vector-family lowering beyond the fixed v1 matrix + +For migration from that original v1 lowering boundary to the current matcher +and advanced-surface implementation, see +`tilelang-dsl/docs/matcher-and-advanced-surface-migration.md`. + +## Source Of Truth + +The implemented lowering surface lives under: +- `tilelang-dsl/python/tilelang_dsl/` +- `tilelang-dsl/tests/` +- `tilelang-dsl/examples/` +- `tilelang-dsl/docs/` + +OpenSpec source of truth for this capability: +- `openspec/changes/add-tilelang-dsl-authoring-vpto-lowering/` + +## Implemented v1 Support Matrix + +The current v1 lowering contract supports: +- fixed-rank 5D `TensorView` descriptors +- 1D/2D `Tile` +- `dma_load` +- `dma_store` +- `make_mask(dtype, PAT.*)` +- `make_mask(dtype, remaining)` +- `vlds` +- `vsts` +- unary vector family: `vabs`, `vrelu`, `vexp`, `vnot` +- binary vector family: `vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor` +- vector-scalar family: `vadds`, `vsubs`, `vmuls`, `vdivs`, `vmaxs`, `vmins` +- `for range(lb, ub, step)` +- `if/else` +- `set_flag`, `wait_flag`, `pipe_barrier` + +Current lowering shape: +- emits stable `func.func + arith/scf + pto.*` authoring-form VPTO modules +- defaults to memref-first function/tile authoring when the target VPTO family supports memref operands +- keeps `copy_*` family on typed `!pto.ptr` +- infers dedicated `pto.vecscope` for stable vector-active runs +- lowers `pto.strict_vecscope` buffer captures through ptr-form region ABI so the current emission-boundary ptr rewrite stays legal +- only accepts explicit `pto.strict_vecscope` in `advanced=True` kernels +- rejects support-matrix-external surface in the frontend + +## Dynamic-Bound Profile + +The implemented shape profile is: +- Tile physical shape must stay static +- TensorView parameters stay in authoring IR as `!pto.tensor_view<...>` +- TensorView shape access lowers through `pto.get_tensor_view_dim` +- TensorView stride access lowers through `pto.get_tensor_view_stride` +- TensorView slice bounds may be dynamic +- TensorView slice spelling may omit leading axes; written axes are right-aligned + onto the trailing physical axes of the 5D descriptor +- loop bounds may be dynamic +- tail `remaining` values may be dynamic + +The current DMA lowering still uses the static physical Tile shape when the +TensorView slice extent is dynamic. This keeps v1 inside the current +authoring-form contract without introducing fully dynamic Tile allocation or +tail-DMA semantics. + +Although the descriptor rank is 5D, the current DMA-oriented slicing/lowering +path still only supports rank-2 TensorView slices. + +## Examples + +Examples aligned with the implemented surface: +- `tilelang-dsl/examples/v1_elementwise_tail_demo.py` + - emits a guide-style elementwise authoring kernel + - covers DMA, advanced-only explicit `strict_vecscope`, dynamic loop bound, and typed tail mask +- `tilelang-dsl/examples/v1_verify_smoke.py` + - emits a minimal module that is expected to pass the current repo + `ptoas --pto-backend=vpto` legality path + +Typical usage from the repository root: + +```bash +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py /tmp/tilelang_v1_elementwise.mlir + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_verify_smoke.py +``` + +## Historical Deferred Features + +The following remained outside the original v1 lowering boundary and were +assigned to follow-up changes: +- implicit vecscope inference +- matcher registry and deterministic selection +- raw pointer / low-level DMA / `copy_ubuf_to_ubuf` authoring surface +- compare/select, predicate movement, carry, rearrangement, reduction families +- wildcard / type-variable dtypes +- multiple `dtypes` signatures + +Primary follow-up change: +- `extend-tilelang-dsl-matcher-and-advanced-surface` + +In the current package head, that follow-up has implemented matcher dispatch, +implicit vecscope inference, raw pointer / low-level DMA authoring, and +compare/select + predicate movement + carry + rearrangement families. +Reduction remains deferred because the repo still does not expose a public +authoring-form VPTO reduction op for TileLang DSL to target directly. + +## Minimal Validation + +The minimal validation set for the implemented v1 lowering is: + +```bash +python3 -m py_compile tilelang-dsl/python/tilelang_dsl/*.py + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 -m unittest $PWD/tilelang-dsl/tests/test_tilelang_dsl_v1.py + +PYTHONPATH=$PWD/tilelang-dsl/python \ + python3 tilelang-dsl/examples/v1_verify_smoke.py /tmp/tilelang_v1_verify.mlir + +build/tools/ptoas/ptoas --pto-arch a5 --pto-backend=vpto --emit-vpto \ + /tmp/tilelang_v1_verify.mlir -o /tmp/tilelang_v1_verify.checked.mlir +``` + +What these commands confirm: +- the standalone source-tree package imports and compiles +- the focused unittest suite passes for lowering, diagnostics, and verify behavior +- a generated TileLang DSL v1 module can be emitted to MLIR +- the emitted verify-smoke module passes the repo VPTO authoring-stage legality path diff --git a/tilelang-dsl/docs/v1-surface.md b/tilelang-dsl/docs/v1-surface.md new file mode 100644 index 000000000..dcf169cd1 --- /dev/null +++ b/tilelang-dsl/docs/v1-surface.md @@ -0,0 +1,254 @@ +# TileLang DSL v1 Surface + +## Scope + +This document records the implemented v1 boundary for the standalone +`tilelang_dsl` package introduced by +`add-tilelang-dsl-core-foundation`. + +It covers: +- package entrypoints +- supported `@vkernel` decorator metadata +- parameter typing rules +- Tile specialization requirements +- current frontend diagnostics boundary +- deferred features that belong to follow-up changes + +It does not define: +- DSL to VPTO lowering details +- matcher and priority semantics +- advanced vector-family surface +- implicit vecscope inference + +For implemented lowering details, examples, and `verify()` behavior, see +`tilelang-dsl/docs/v1-lowering.md`. +For migration from the original v1 core boundary to the current matcher and +advanced-surface package capabilities, see +`tilelang-dsl/docs/matcher-and-advanced-surface-migration.md`. + +## Source Of Truth + +TileLang DSL v1 source of truth lives under: +- `tilelang-dsl/python/tilelang_dsl/` +- `tilelang-dsl/tests/` +- `tilelang-dsl/examples/` +- `tilelang-dsl/docs/` + +`python/pto/dialects/pto.py` is not the source of truth for TileLang DSL v1. +That file still exists for PTO dialect bindings and the legacy experimental VPTO +Python DSL surface. Root-level wiring into build, install, and test is allowed, +but new TileLang DSL core behavior must land under `tilelang-dsl/`. + +## Package Entry + +Examples and tests should import the standalone package: + +```python +import tilelang_dsl as pto +``` + +The package currently exports: +- `vkernel` +- `VKernelDescriptor` +- `BoundKernelParameter` +- `MaterializedMLIRModule` +- `TileLangFrontendError` +- `TensorView` +- `Tile` +- `VRegType` +- `MaskType` +- scalar dtypes such as `f16`, `bf16`, `f32`, `i8`, `i16`, `i32`, `i64` +- type helpers such as `vreg(...)`, `ptr(...)`, `mask_b8`, `mask_b16`, `mask_b32`, `MemorySpace`, `TileConfig`, `TileSpecialization` + +The package does not expose a DSL-level `pto.memref(...)` constructor. MemRef +only appears in generated/lowered IR, not in the public authoring type surface. + +## v1 Decorator Surface + +The supported v1 decorator surface is: + +```python +@pto.vkernel( + target="a5", + op="some_op_name", + dtypes=[(pto.f32, pto.f16, pto.i32)], + name="optional_name", + verify=True, +) +def kernel(...): + ... +``` + +Current rules: +- `target` only accepts `"a5"` +- `op` is required and must be a non-empty string +- `dtypes` must contain exactly one monomorphic signature tuple +- `name` is optional and defaults to the Python function name +- `verify` is optional and must be a bool + +The descriptor keeps these metadata fields: +- `target` +- `op` +- `dtypes` +- `name` +- `verify` + +## Parameter Typing + +v1 accepts these parameter categories: +- bare `TensorView` +- bare `Tile` +- scalar annotations such as `pto.i32`, `pto.f16`, `pto.f32`, `pto.AnyType`, or `pto.TypeVar("T")` + +Binding rules: +- the single `dtypes` signature binds parameter element types positionally +- `TensorView` parameters get their element dtype from the same position in + `dtypes` +- `Tile` parameters get their element dtype from the same position in `dtypes` +- scalar parameters must use a TileLang scalar-style annotation +- scalar annotations may be concrete scalar dtypes, wildcard dtypes, or + `TypeVar(...)` +- concrete scalar annotations must exactly match the dtype at the same position + in `dtypes` +- wildcard scalar annotations must accept the dtype at the same position in + `dtypes` +- `TypeVar(...)` scalar annotations bind to the selected dtype at the same + position in `dtypes` + +Example: + +```python +@pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.bf16, pto.i32)]) +def kernel(inp: pto.TensorView, tmp: pto.Tile, scale: pto.i32): + return None +``` + +In this example: +- `inp` binds to `f32` +- `tmp` binds to `bf16` +- `scale` binds to `i32` + +## Tile Specialization + +Bare `Tile` parameters are incomplete until descriptor-level specialization is +provided. + +The only supported completion path is: + +```python +specialized = descriptor.specialize( + tmp=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping({"layout": "row_major"}), + ) +) +``` + +Current v1 Tile profile rules: +- Tile physical shape must be static +- Tile dimensions must be positive integers +- Tile rank must be 1D or 2D +- Tile memory space must be `MemorySpace.UB` +- `config` may be omitted, provided as `TileConfig`, or built from a dict + +Before all bare `Tile` parameters are specialized, the descriptor must reject: +- `mlir_text()` +- `mlir_module()` +- `verify()` +- `emit(path)` + +## Materialization API + +After all bare `Tile` parameters are specialized, the descriptor exposes: +- `mlir_text()` +- `mlir_module()` +- `verify()` +- `emit(path)` + +At this stage of the workflow, these APIs provide a stable descriptor/materialization +surface for the new package. They do not yet define the final TileLang DSL to +VPTO lowering behavior; that work belongs to +`add-tilelang-dsl-authoring-vpto-lowering`. + +## Frontend Diagnostics + +The v1 frontend fails fast for: +- unsupported decorator matcher features +- unsupported Python syntax +- arbitrary external calls +- unsupported `pto.*` op surface +- missing Tile specialization +- dynamic physical Tile shape +- illegal Tile profile + +Diagnostics are frontend errors, not deferred verifier failures. When source is +available, errors include file, line, and column information. + +## Minimal Validation + +The following commands are the minimal validation set for +`add-tilelang-dsl-core-foundation`: + +```bash +cmake --build build --target TileLangDSLPackage +python3 -c "import sys; sys.path.insert(0, 'build/python'); import tilelang_dsl; print(tilelang_dsl.__file__)" +ctest --test-dir build -R tilelang_dsl_import --output-on-failure +ctest --test-dir build -R tilelang_dsl_unittest --output-on-failure +``` + +What these commands confirm: +- the standalone `tilelang_dsl` package is staged into `build/python/` +- Python can import the staged package directly +- the dedicated import smoke test passes +- the focused unittest suite passes for descriptor API, specialization, and + diagnostics coverage + +For a direct source-location diagnostics smoke, run: + +```bash +tmp=$(mktemp /tmp/tilelang_dsl_diag_XXXX.py) +cat > "$tmp" <<'PY' +import tilelang_dsl as pto + +try: + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + while True: + return None +except pto.TileLangFrontendError as exc: + print(exc) +PY +PYTHONPATH=build/python python3 "$tmp" +rm -f "$tmp" +``` + +Expected output shape: + +```text +/tmp/tilelang_dsl_diag_XXXX.py:6:5: unsupported Python syntax `while` in TileLang DSL v1 +``` + +This confirms diagnostics are emitted against the authored DSL source file +rather than an internal lowering location. + +## Historical Deferred Features + +The following were intentionally out of scope for the original v1 core boundary +and were assigned to follow-up changes: +- multiple `dtypes` signatures +- `constraints` +- `priority` +- `AnyFloat`, `AnyInt`, `AnyType`, `AnyMask` +- `TypeVar` +- matcher registry and deterministic selection +- implicit vecscope inference +- raw pointer authoring surface +- advanced vector-family support + +Matcher-related extensions are deferred to +`extend-tilelang-dsl-matcher-and-advanced-surface`. +That follow-up capability is now implemented in the current package head; use +`tilelang-dsl/docs/matcher-and-advanced-surface-migration.md` for the updated +surface boundary instead of reading the list above as a statement about current +head behavior. diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md new file mode 100644 index 000000000..39075df26 --- /dev/null +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-current.md @@ -0,0 +1,5383 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg, !pto.mask -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 10 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.set_mov_pad_val`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Pad Value Configuration + +##### `pto.set_mov_pad_val` + +- **syntax:** `pto.set_mov_pad_val %value : T` +- **supported `T`:** `i8`, `i16`, `i32`, `f16`, `bf16`, `f32` +- **semantics:** Configure the pad fill value used by GM→UB DMA when `data_select_bit = true`. + +This op programs the hardware pad register consumed by `pto.copy_gm_to_ubuf`. The operand is a typed scalar. Its raw bit pattern is encoded into the underlying hardware configuration payload: + +- integer inputs use their zero-extended bit pattern +- floating-point inputs use their bitcast-to-integer bit pattern, then zero-extend to `i64` + +This configuration affects only the GM→UB padding path. UB→GM DMA ignores the pad value. + +**Parameter Table:** + +| Parameter | Description | +|-----------|-------------| +| `%value` | Pad fill scalar. Must be one of `i8/i16/i32/f16/bf16/f32`. | + +**Example:** + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +``` + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +%pad = arith.constant 0 : i16 +pto.set_mov_pad_val %pad : i16 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | +| `INTLV` | `b8`, `b16`, `b32` | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** si32, i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Signed or signless 32-bit integer and + floating-point element types are legal on the current A5 surface described + here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +#### Movement + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +##### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | Y | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdif` + +- **syntax:** `%result = pto.vexpdif %input, %max, %mask, "EVEN|ODD" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha, %mask : !pto.vreg, !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, + `%alpha` is the scalar multiplier, and `%mask` selects active lanes. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdif %logits, %max_bc, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md new file mode 100644 index 000000000..6c06c4c07 --- /dev/null +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.2.md @@ -0,0 +1,5072 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/u8 | 32 | 256 | +| i16/u16/f16/bf16 | 16 | 128 | +| i32/u32/f32 | 8 | 64 | +| i64/u64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +- `vreg`: `!pto.vreg` + Fixed-width VPTO vector type with total width exactly 256 bytes. +- `mask`: `!pto.mask` + Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. +- `align`: `!pto.align` +- `buf`: buffer-like LLVM pointer type accepted by the dialect +- `buf_like`: `memref<...>` or `!llvm.ptr` for stateless/predicate + `vld*/vst*` families +- `idx`: `index` +- `i32`: `i32` +- `i64`: `i64` + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `s8` / `u8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `s16` / `u16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `s32` / `u32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `s64` / `u64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | +| `f8e4m3` | 8 | FP8 (4-bit exponent, 3-bit mantissa) | +| `f8e5m2` | 8 | FP8 (5-bit exponent, 2-bit mantissa) | + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through pointer construction, pointer arithmetic, structured control flow, and PTO memory ops: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out, %base_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} : !pto.vreg -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/u8 +// N = 128 for i16/u16/f16/bf16 +// N = 64 for i32/u32/f32 +// N = 32 for i64/u64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"ROUND_MODE"` | Rounding mode: `ROUND_R \| ROUND_A \| ROUND_F \| ROUND_C \| ROUND_Z` | +| `"SAT_MODE"` | Saturation: `RS_ENABLE \| RS_DISABLE` | +| `"PART_MODE"` | Half selector: `PART_EVEN \| PART_ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldx2`, `pto.vgather2`, `pto.vsts`, `pto.vstx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 7 | `pto.plds`, `pto.pld`, `pto.pldi`, `pto.psts`, `pto.pst`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 9 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrec`, `pto.vrelu`, `pto.vnot`, `pto.vbcnt`, `pto.vcls` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 8 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 3 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 5 | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr`, `pto.vselrv2` | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 4 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 5 | `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %mode : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %mode : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %mode : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %mode : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf %bufid_buf[%pp], "PIPE_MTE2" + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf %bufid_buf[%pp], "PIPE_MTE2" + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf %bufid_buf[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf %bufid_buf[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | **8** = 2 buffers × 2 dirs × 2 (fwd+rev) | 1 fwd + 1 rev per buffer (shared global pool) | +| Total HW IDs | 8 per pipe-pair, grows with buffers | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | None | +| Post-loop teardown | `wait_flag` to drain all primed signals | None | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, no overhead | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV_B32` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM_B32` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV_*`** on **`RV_VSTI`** are **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV_B32` | `RV_VLDI` | **9** | +| `DINTLV_B16` | `RV_VLDI` | **9** | +| `DINTLV_B8` | `RV_VLDI` | **9** | +| `BRC_B32` | `RV_VLD` | **9** | +| `BRC_B8` | `RV_VLD` | **9** | +| `BRC_B16` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV_B32` | `RV_VSTI` | **12** | +| `INTLV_B16` | `RV_VSTI` | **12** | +| `INTLV_B8` | `RV_VSTI` | **12** | +| `UNPK_B8` | `RV_VLD` | **9** | +| `UNPK_B16` | `RV_VLD` | **9** | +| `UNPK_B32` | `RV_VLD` | **9** | +| `NORM_B32` | `RV_VSTI` | **9** | +| `NORM_B16` | `RV_VSTI` | **9** | +| `NORM_B8` | `RV_VSTI` | **9** | +| `PK_B32` | `RV_VSTI` | **9** | +| `PK_B16` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK_B8`, `UNPK_B16`, `UNPK_B32` | **9** cycles | +| `DINTLV_B32` | **9** cycles (`RV_VLDI`) | +| `DINTLV_B16`, `DINTLV_B8` | **9** cycles (same `RV_VLDI` + `dist:DINTLV_*` path as `DINTLV_B32`) | +| `BRC_B32` | **9** cycles | +| `BRC_B8`, `BRC_B16` | **9** cycles (`RV_VLD`) | +| `BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US_*`, `DS_*`, `SPLT*` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM_B8`, `NORM_B16`, `NORM_B32` | **9** cycles (`RV_VSTI`) | +| `PK_B16`, `PK_B32` | **9** cycles | +| `INTLV_B32` (`pto.vstx2`) | **12** cycles | +| `INTLV_B16`, `INTLV_B8` | **12** cycles (same interleave store path as `INTLV_B32`) | +| `MRG4CHN_B8`, `MRG2CHN_*` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM` | Contiguous 256B load | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC_B32` | Broadcast single element | `dst[i] = UB[base]` for all i | **9** cycles | +| `BRC_B8`, `BRC_B16` | Broadcast first lane element | Same idea at B8/B16 width | **9** cycles | +| `US_B8/B16` | Upsample (duplicate each element) | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS_B8/B16` | Downsample (every 2nd element) | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK_B8/B16/B32` | Unpack (zero-extend to wider type) | `dst_i32[i] = (uint32_t)UB_i16[base + 2*i]` | **9** cycles | +| `SPLT4CHN_B8` | Split 4-channel (RGBA → R plane) | Extract every 4th byte | **9** cycles | +| `SPLT2CHN_B8/B16` | Split 2-channel | Extract every 2nd element | **9** cycles | +| `DINTLV_B32` | Deinterleave 32-bit | Even elements only | **9** cycles | +| `DINTLV_B16`, `DINTLV_B8` | Deinterleave 16-bit / 8-bit | Pair lanes from interleaved UB | **9** cycles | +| `BDINTLV` | Block deinterleave | (see PTO headers for exact tiling) | **9** cycles | +| `BLK` | Block load | Blocked / tiled access pattern (see PTO headers) | **9** cycles (`dist:BRC_BLK` on `RV_VLD`) | + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out, %base_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align, !pto.ptr` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value, `%align_out` is the updated alignment + state, and `%base_out` is the post-update base pointer state exposed in SSA + form. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. Both the alignment state and the base address + advance across the stream, and the PTO micro Instruction representation exposes those updates as SSA results. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2, %ub2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align, !pto.ptr +``` + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldx2` + +- **syntax:** `%low, %high = pto.vldx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. +- **Latency:** **`DINTLV_B32` → 9** cycles on `RV_VLDI`. **`DINTLV_B16` / `DINTLV_B8` → 9** cycles on `RV_VLDI`. **`BDINTLV` → 9** cycles on `RV_VLDI`. + +**Distribution modes:** `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` + +```c +// DINTLV_B32: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldx2 %ub[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +--- + +#### Strided Loads + +##### `pto.vsld` + +- **syntax:** `%result = pto.vsld %source[%offset], "STRIDE" : !pto.ptr -> !pto.vreg` +- **semantics:** Strided load with fixed stride pattern. +- **inputs:** + `%source` is the UB base pointer and `%offset` is the displacement encoded + with the selected fixed stride mode. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + This is a deprecated compatibility family. The selected stride token + determines which sub-elements are read from each source block. +- **Latency:** **9** cycles. + +**Stride modes:** `STRIDE_S3_B16`, `STRIDE_S4_B64`, `STRIDE_S8_B32`, `STRIDE_S2_B64` + +--- + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %offset, %mask : !pto.ptr, i32, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer, `%offset` is the packed stride/control word, + and `%mask` controls which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + `%offset` is not a plain byte displacement; it encodes the block stride and + repeat pattern. If a block is masked off, the corresponding destination block + is zeroed and MUST NOT raise an address overflow exception for that block. +- **Latency:** **9** cycles. + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%active_lanes` bounds how many lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only the first `%active_lanes` indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %active_lanes : !pto.ptr, !pto.vreg, index -> !pto.vreg` +- **semantics:** Byte-granularity indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains per-block byte offsets, + and `%active_lanes` bounds the number of active gathered blocks. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a block gather, not a byte-per-lane gather. `%source` MUST be 32-byte + aligned, each participating offset MUST describe a 32-byte-aligned block, and + inactive blocks are zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int i = 0; i < active_lanes; i++) + dst[i] = UB[base + offsets[i]]; // byte-addressed +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. Narrowing/packing modes may only preserve a subset of the + source bits. Merge-channel modes reinterpret the source vector as channel + planes and interleave them on store. + +**Distribution modes:** + +| Mode | Description | C Semantics | Latency | +|------|-------------|-------------|---------------------| +| `NORM_B8/B16/B32` | Contiguous store | `UB[base + i] = src[i]` | **9** cycles | +| `PK_B16/B32` | Pack/narrowing store | `UB_i16[base + 2*i] = truncate_16(src_i32[i])` | **9** cycles | +| `MRG4CHN_B8` | Merge 4 channels (R,G,B,A → RGBA) | Interleave 4 planes | **9** cycles | +| `MRG2CHN_B8/B16` | Merge 2 channels | Interleave 2 planes | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstx2` + +- **syntax:** `pto.vstx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. +- **Latency:** **`INTLV_B32` / `INTLV_B16` / `INTLV_B8` → 12** cycles on `RV_VSTI`. + +**Distribution modes:** `INTLV_B8`, `INTLV_B16`, `INTLV_B32` + +```c +// INTLV_B32: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +--- + +#### Strided Stores + +##### `pto.vsst` + +- **syntax:** `pto.vsst %value, %dest[%offset], "STRIDE" : !pto.vreg, !pto.ptr` +- **semantics:** Strided store with fixed stride pattern. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, and `%offset` + / `STRIDE` select the fixed strided layout. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + This is a deprecated compatibility family. The stride token, not the vector + lane number alone, determines which destination elements are written. +- **Latency:** **9** cycles. + +--- + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %offset, %mask : !pto.vreg, !pto.ptr, i32, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the packed stride/control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + `%offset` is a control word, not a plain byte displacement. This is a + deprecated compatibility family kept for surface coverage. +- **Latency:** **9** cycles. + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %active_lanes : !pto.vreg, !pto.ptr, !pto.vreg, index` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%active_lanes` bounds the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < active_lanes; i++) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vsta` + +- **syntax:** `pto.vsta %value, %dest[%offset] : !pto.align, !pto.ptr, index` +- **semantics:** Flush alignment state to memory. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base pointer, + and `%offset` is the flush displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The flush address MUST match the post-updated address expected by the + preceding unaligned-store stream. After the flush, the corresponding store + alignment state is consumed. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family uses the same buffered-tail semantics as `pto.vsta` but keeps the + scalar-offset form explicit. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstu` +- **syntax:** `%align_out, %base_out = pto.vstu %align_in, %base_in, %value, %dest, %mode : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, index -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with explicit threaded alignment/base state. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%base_in` is the current + stream base, `%value` is the vector to store, `%dest` is the UB base pointer, + and `%mode` selects the post-update behavior. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the + post-update base pointer state. +- **constraints and limitations:** + This op models a stateful unaligned-store sequence in SSA form. A final + `pto.vsta` / `pto.vstas` / `pto.vstar` is still required to flush the trailing + buffered bytes. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstus` +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %base_in, %value, %dest, %offset : !pto.align, !pto.ptr, !pto.vreg, !pto.ptr, i32 -> !pto.align, !pto.ptr` +- **semantics:** Scalar-offset unaligned store with threaded state. +- **inputs:** + Same roles as `pto.vstu`, but `%offset` is provided explicitly as the scalar + displacement. +- **outputs:** + Updated alignment state and base state. +- **constraints and limitations:** + The same final flush requirement and state-threading constraints as + `pto.vstu` apply. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` +- **syntax:** `%align_out = pto.vstur %align_in, %value, %dest : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Register-update unaligned store form. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%dest` is the UB base pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This op updates only the residual alignment state. A matching flush op is + still required to emit the trailing bytes. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstu` + +- **syntax:** `%align_out, %offset_out = pto.vstu %align_in, %offset_in, %value, %base, "MODE" : !pto.align, index, !pto.vreg, !pto.ptr -> !pto.align, index` +- **semantics:** Unaligned store with align + offset state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset_in` is the current + logical byte/element displacement, `%value` is the vector being stored, and + `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated alignment/tail state and `%offset_out` is the + next offset after applying the selected post-update rule. +- **constraints and limitations:** + The alignment state MUST be threaded in program order. A terminating flush + form such as `pto.vstar`/`pto.vstas` is still required to commit the buffered + tail bytes. +- **Latency:** **9** cycles. + +**Mode tokens:** `POST_UPDATE`, `NO_POST_UPDATE` + +--- + +##### `pto.vstus` + +- **syntax:** `%align_out, %base_out = pto.vstus %align_in, %offset, %value, %base, "MODE" : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Unaligned store with scalar offset and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state and `%base_out` is the next + base pointer when the lowering chooses a post-update form. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width and update mode MUST match the selected form, and a later + flush op is still required. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, and `%base` is the UB base pointer. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + This form exposes only the evolving state; it does not by itself guarantee + that all buffered bytes have reached memory. A compatible final flush is still + required unless the surrounding sequence is known to be complete. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is paired with `f32` +vector compares or selects. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.mask` +- **semantics:** Load predicate register with scalar offset. + +**Distribution modes:** `NORM`, `US`, `DS` + +**Example:** +```mlir +%mask = pto.plds %ub[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask +``` + +--- + +##### `pto.pld` + +- **syntax:** `%result = pto.pld %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with areg offset. + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source, %offset, "DIST" : !pto.ptr, i32 -> !pto.mask` +- **semantics:** Load predicate register with immediate offset. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset] : !pto.mask, !pto.ptr` +- **semantics:** Store predicate register with scalar offset. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0] : !pto.mask, !pto.ptr +``` + +--- + +##### `pto.pst` + +- **syntax:** `pto.pst %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with areg offset. + +**Distribution modes:** `NORM`, `PK` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest, %offset, "DIST" : !pto.mask, !pto.ptr, i32` +- **semantics:** Store predicate register with immediate offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align state update. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0] : !pto.mask, !pto.ptr + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0] {dist = "NORM"} : !pto.ptr -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input {position = "POSITION"} : T|!pto.vreg -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source element or scalar position is duplicated. The + current PTO micro Instruction representation models that selector as an attribute rather than a + separate operand. + +```c +for (int i = 0; i < N; i++) + dst[i] = input_scalar_or_element; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate predicate from pattern. + +**Patterns:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PAT_*" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PAT_*" : !pto.mask` +- **semantics:** Generate tail mask — first N lanes active. + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate predicate state together with updated scalar state. + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +**Part tokens:** `LOWER`, `HIGHER` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] & src1[i]; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] | src1[i]; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src0[i] ^ src1[i]; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = ~src[i]; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +#### Predicate Movement + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src[i]; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. + +--- + +##### `pto.pdintlv_b8` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate deinterleave. + +--- + +##### `pto.pintlv_b16` + +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Predicate interleave. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrsqrt` | `RV_VSQRT` / `RV_VDIV` | **17** / **17** | **22** / **22** | — | +| `pto.vrec` | `RV_VDIV` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. Integer + overflow on the most-negative signed value follows the target-defined + behavior. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vrsqrt` + +- **syntax:** `%result = pto.vrsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds reciprocal-square-root values per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +##### `pto.vrec` + +- **syntax:** `%result = pto.vrec %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = 1.0f / src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the reciprocal per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Active inputs containing `+0` or `-0` follow the target's divide-style + exceptional behavior. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vbcnt` + +- **syntax:** `%result = pto.vbcnt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = __builtin_popcount(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the population count for each active lane. +- **constraints and limitations:** Integer element types only. The count is + over the source element width, not over the full vector register. + +--- + +##### `pto.vcls` + +- **syntax:** `%result = pto.vcls %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = count_leading_sign_bits(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the leading-sign-bit count per active lane. +- **constraints and limitations:** Integer element types only. This operation is + sign-aware, so signed interpretation matters. + +--- + +#### Movement + +##### `pto.vmov` + +- **syntax:** `%result = pto.vmov %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Vector register copy. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is a copy of the source vector. +- **constraints and limitations:** Predicated `pto.vmov` behaves like a masked + copy, while the unpredicated form behaves like a full-register copy. + +--- + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Reciprocal for division +%sum_rcp = pto.vrec %sum, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/u8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/u8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, it SHOULD be treated as an unsigned integer + operation. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %borrow = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + borrow[i] = (src0[i] < src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%borrow` marks lanes + that borrowed. +- **constraints and limitations:** This operation SHOULD be treated as an + unsigned 32-bit carry-chain family unless and until the verifier states + otherwise. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each active lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Inactive lanes follow the predication + behavior defined for this family. On the current surface, inactive lanes are + treated as zeroing lanes. + +--- + +##### `pto.vsubs` + +- **syntax:** `%result = pto.vsubs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] - scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Integer or floating-point legality depends on + the selected type family in lowering. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common numeric cases. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vands` + +- **syntax:** `%result = pto.vands %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] & scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vors` + +- **syntax:** `%result = pto.vors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] | scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxors` + +- **syntax:** `%result = pto.vxors %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] ^ scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **constraints and limitations:** This is the scalar-extended carry-chain + family. Treat it as an unsigned integer operation unless the verifier states a + wider legal domain. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %borrow = pto.vsubcs %lhs, %rhs, %borrow_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with borrow-in and borrow-out. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - borrow_in[i]; + borrow_out[i] = (src0[i] < src1[i] + borrow_in[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%borrow_in` is the + incoming borrow predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%borrow` is the + borrow-out predicate. +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and SHOULD be treated as an unsigned integer operation. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%result` is the destination vector register value. +- `round_mode`, `sat`, and `part` control rounding, saturation, and lane-part + selection in attribute form. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input {round_mode = "ROUND_MODE", sat = "SAT_MODE", part = "PART_MODE"} : !pto.vreg -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + dst[i] = convert(src[i], T0, T1, round_mode); +``` + +- **inputs:** + `%input` is the source vector; attributes select rounding, saturation, and + even/odd placement when the conversion changes width. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. `PART_EVEN` / + `PART_ODD` is only meaningful for width-changing forms that pack two source + streams into one destination register. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `ROUND_R` | Round to nearest, ties to even (default) | +| `ROUND_A` | Round away from zero | +| `ROUND_F` | Round toward negative infinity (floor) | +| `ROUND_C` | Round toward positive infinity (ceil) | +| `ROUND_Z` | Round toward zero (truncate) | +| `ROUND_O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `RS_ENABLE` | Saturate on overflow | +| `RS_DISABLE` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes (for width-changing conversions) + +| Mode | Description | +|------|-------------| +| `PART_EVEN` | Output to even-indexed lanes | +| `PART_ODD` | Output to odd-indexed lanes | + +--- + +##### A5 Supported Conversions + +**Float-Float (vcvtff):** +- f32 ↔ f16 +- f32 ↔ bf16 +- f16 ↔ bf16 + +**Float-Int (vcvtfi):** +- f16 → i16, f16 → i32 +- f32 → i16, f32 → i32 +- bf16 → i32 + +**Int-Float (vcvtif):** +- i16 → f16 +- i32 → f32 + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_EVEN"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1 {round_mode = "ROUND_R", sat = "RS_ENABLE", part = "PART_ODD"} + : !pto.vreg<64xf32> -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, "ROUND_MODE" : !pto.vreg -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], round_mode); +``` + +- **inputs:** + `%input` is the floating-point source vector and `ROUND_MODE` selects the + truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `ROUND_O` is supported for avoiding + double-rounding errors during staged conversions. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, "ROUND_R" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled {round_mode = "ROUND_R", sat = "RS_ENABLE"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input {round_mode = "ROUND_R"} + : !pto.vreg<128xbf16> -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, "ROUND_F" : !pto.vreg<64xf32> -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored {round_mode = "ROUND_Z"} + : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. Result value + index in lane 0. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst_val[0] = mx; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** This family computes both the extremum and + location information, but the exact packing of that information into the + destination vector depends on the chosen form. If all predicate bits are zero, + the result follows the zero-filled convention. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. Result value + index in lane 0. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst_val[0] = mn; +dst_idx[0] = idx; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` carries the reduction result in the low destination + positions. +- **constraints and limitations:** As with `pto.vcmax`, the exact value/index + packing depends on the chosen form and MUST be preserved consistently. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; // reversed from vsel +``` + +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This family preserves reversed-select + semantics. If the concrete lowering uses an implicit predicate source, that + predicate source MUST be documented by the surrounding IR pattern. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Slide / Shift + +##### `pto.vslide` + +- **syntax:** `%result = pto.vslide %src0, %src1, %amt : !pto.vreg, !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Concatenate two vectors and extract N-element window at offset. + +```c +// Conceptually: tmp[0..2N-1] = {src1, src0} +// dst[i] = tmp[amt + i] +if (amt >= 0) + for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src0[i - amt] : src1[N - amt + i]; +``` + +**Use case:** Sliding window operations, shift register patterns. + +- **inputs:** `%src0` and `%src1` provide the concatenated source window and + `%amt` selects the extraction offset. +- **outputs:** `%result` is the extracted destination window. +- **constraints and limitations:** `pto.vslide` operates on the logical + concatenation of `%src1` and `%src0`. The source order and extraction offset + MUST be preserved exactly. + +--- + +##### `pto.vshift` + +- **syntax:** `%result = pto.vshift %src, %amt : !pto.vreg, i16 -> !pto.vreg` +- **semantics:** Single-source slide (shift with zero fill). + +```c +for (int i = 0; i < N; i++) + dst[i] = (i >= amt) ? src[i - amt] : 0; +``` + +- **inputs:** `%src` is the source vector and `%amt` is the slide amount. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** This surface represents the single-source + slide/shift family. Zero-fill versus other fill behavior MUST match the + selected form. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %mask : !pto.mask -> !pto.vreg` +- **semantics:** Expand — scatter front elements to active positions. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[i] = src_front[j++]; + else dst[i] = 0; +``` + +- **inputs:** `%mask` is the expansion/placement predicate. +- **outputs:** `%result` is the expanded vector image. +- **constraints and limitations:** The source-front stream is implicit in the + current surface. Lane placement for active and inactive positions MUST be + preserved exactly. + +--- + +#### Permutation + +##### `pto.vperm` + +- **syntax:** `%result = pto.vperm %src, %index : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** In-register permute (table lookup). **Not** memory gather. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[index[i] % N]; +``` + +**Note:** This operates on register contents, unlike `pto.vgather2` which reads from UB memory. + +- **inputs:** `%src` is the source vector and `%index` supplies per-lane source + indices. +- **outputs:** `%result` is the permuted vector. +- **constraints and limitations:** This is an in-register permutation family. + `%index` values outside the legal range follow the wrap/clamp behavior of the + selected form. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Register select with reversed mask semantics. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src1[i] : src0[i]; +``` + +- **inputs:** `%src0` and `%src1` are source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src0, %src1, %part : !pto.vreg, !pto.vreg, index -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrowing pack — two wide vectors to one narrow vector. + +```c +// e.g., two vreg<64xi32> → one vreg<128xi16> +for (int i = 0; i < N; i++) { + dst[i] = truncate(src0[i]); + dst[N + i] = truncate(src1[i]); +} +``` + +- **inputs:** `%src0` and `%src1` are wide source vectors and `%part` selects + the packing submode. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion. Source + values that do not fit the destination width follow the truncation semantics + of the selected packing mode. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Sliding window sum +%prev_window = pto.vslide %curr, %prev, %c1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, i16 -> !pto.vreg<64xf32> +%window_sum = pto.vadd %curr, %prev_window, %all + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide0_i32, %wide1_i32, %c0 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, index -> !pto.vreg<128xi16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdif` + +- **syntax:** `%result = pto.vexpdif %input, %max : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. +- **outputs:** `%result` is the fused `exp(input - max)` vector. +- **constraints and limitations:** Floating-point element types only. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaddrelu` + +- **syntax:** `%result = pto.vaddrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused add + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] + src1[i], 0); +``` + +- **inputs:** `%lhs` and `%rhs` are the two addends. +- **outputs:** `%result` is the fused add-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vsubrelu` + +- **syntax:** `%result = pto.vsubrelu %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Fused sub + ReLU. + +```c +for (int i = 0; i < N; i++) + dst[i] = max(src0[i] - src1[i], 0); +``` + +- **inputs:** `%lhs` is the minuend and `%rhs` is the subtrahend. +- **outputs:** `%result` is the fused sub-then-ReLU result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha : !pto.vreg, !pto.vreg, T -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, and + `%alpha` is the scalar multiplier. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +##### `pto.vaddreluconv` + +- **syntax:** `%result = pto.vaddreluconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused add + ReLU + type conversion (HW fusion). + +```c +// f32→f16 variant: +for (int i = 0; i < 64; i++) + dst_f16[i] = f32_to_f16(max(src0_f32[i] + src1_f32[i], 0)); + +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(max(src0_f16[i] + src1_f16[i], 0)); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused add/ReLU/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. Rounding, saturation, and packing rules follow the + semantics of this fused operation, not an arbitrary sequence of standalone + ops. + +--- + +##### `pto.vmulconv` + +- **syntax:** `%result = pto.vmulconv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Fused mul + type conversion (HW fusion). + +```c +// f16→i8 variant: +for (int i = 0; i < 128; i++) + dst_i8[i] = f16_to_i8(src0_f16[i] * src1_f16[i]); +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors. +- **outputs:** `%result` is the fused mul/convert result. +- **constraints and limitations:** Only backend-supported source/destination + type pairs are legal. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/u32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### UB-to-UB Operations + +##### `pto.vtranspose` + +- **syntax:** `pto.vtranspose %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** UB-to-UB transpose operation (not vreg-to-vreg). + +**Note:** This operates on UB memory directly, not on vector registers. + +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is not a `vreg -> vreg` op even though + it lives in the `pto.v*` namespace. Its correctness depends on the control + word and UB layout contract. + +--- + +#### Sorting Operations + +##### `pto.vsort32` + +- **syntax:** `pto.vsort32 %dest, %src, %config : !pto.ptr, !pto.ptr, i64` +- **semantics:** Sort 32 elements in UB. +- **inputs:** `%dest` and `%src` are UB pointers and `%config` is the ISA + control/config word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** This is a UB-to-UB accelerator helper, not a + pure vector-register op. + +--- + +##### `pto.vmrgsort` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr x4, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. This page uses the shorter mnemonic + `pto.vmrgsort`, while the current implementation summary still refers to + `pto.vmrgsort4`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Fused residual add + ReLU +%residual = pto.vaddrelu %conv_out, %skip_connection : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `u8` | 8 | 256 | Signed/unsigned 8-bit integer | +| `i16` / `u16` | 16 | 128 | Signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `u32` | 32 | 64 | Signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `u64` | 64 | 32 | Signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdif %logits, %max_bc : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC_B32"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +// Fused add + ReLU +%fused = pto.vaddrelu %a, %b : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldx2 %ub_xy[%offset], "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstx2 %x, %y, %ub_xy[%offset], "INTLV_B32", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC_*` dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM_*` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv`, `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md new file mode 100644 index 000000000..0a280204e --- /dev/null +++ b/tilelang-dsl/docs/vpto_spec/vpto-spec-v0.3.md @@ -0,0 +1,5351 @@ +# PTO micro Instruction Spec — Draft (A5) + +- v0.3: Add runtime block query and vector-interval legality notes; Normalize load/store distribution families; Update get_buf/rls_buf details +- v0.2: Update micro Instruction latency and throughput +- v0.1: Doc Init + +[toc] + +--- + +## Part I: Architecture Overview + +### Overview + +This document defines the PTO micro Instruction, a compiler-internal and externally facing specification designed to represent vector compute kernels within the PTO architecture. Much like NVVM provides a robust IR for GPU architectures, the PTO micro Instruction serves as the direct bridge between high-level programming models and the underlying hardware ISA, providing a precise, low-level representation of vector workloads explicitly designed for the Ascend 950 architecture. + +#### Position in the Stack and Layer Modeled + +The PTO micro Instruction operates as a very low-level intermediate representation within the PTO compiler stack. It is uniquely designed to accurately and comprehensively express all architectural information of the Ascend 950 hardware. It specifically models the bare-metal vector execution layer, making hardware-specific capabilities and constraints, such as exact vector lane configurations, memory space hierarchies, and hardware-specific fusion semantics, fully transparent and controllable. + +#### PTO Instruction Modes and Compilation Flows + +Within the end-to-end PTO software stack, PTO instructions may appear in three closely related authoring or lowering modes: + +- **PTO Tile Instruction**: tile-oriented PTO code that serves as a nano-kernel encapsulation of Tile operations, primarily expressing computation and data movement in terms of tile buffers, tile shapes, and tile-local layout. +- **PTO micro Instruction**: vector-execution-oriented PTO code that makes DMA setup, vector registers, masks, synchronization, and `__VEC_SCOPE__` boundaries explicit. This document is centered on this mode. +- **PTO Tile+micro Instruction**: a hybrid PTO form that keeps tile-level orchestration while embedding explicit micro-instruction regions where direct vector-pipeline control is required. + +From these PTO instruction forms, the stack can proceed along two main compilation flows: + +- **CCE generation flow**: PTO ISA is lowered into a CCE-oriented representation, which is then compiled by the BiSheng toolchain into Ascend device binaries. +- **Bytecode generation flow**: PTO ISA is emitted as bytecode, which is then compiled by the BiSheng toolchain into Ascend device binaries. + +```text +High-level frameworks / DSLs / library kernels + | + v + +----------------------------------+ + | PTO ISA layer | + | | + | (1) PTO Tile Instruction | + | (2) PTO micro Instruction | + | (3) PTO Tile+micro Instruction | + +----------------+-----------------+ + | + +------------+------------+ + | | + v v + +-------------------------+ +-------------------------+ + | Path A: generate CCE | | Path B: generate | + | (CCE-oriented form) | | bytecode | + +------------+------------+ +------------+------------+ + | | + v v + +-----------------------------------------------+ + | BiSheng compiler | + +---------------------------+-------------------+ + | + v + +-----------------------------+ + | Ascend device binaries | + +-----------------------------+ +``` + +#### Why External Developers Read or Author PTO micro Instruction + +While the majority of users will interact with the PTO architecture via higher-level frameworks, external developers may need to read or author PTO micro Instruction directly for several key reasons: + +- Custom Toolchain Development: build custom compiler frontends or domain-specific languages (DSLs) that target the Ascend 950 architecture with maximum hardware utilization. +- Performance Engineering: inspect the output of high-level compiler passes, verify fine-grained optimization behaviors, and pinpoint performance bottlenecks at the architectural level. +- Micro-Optimization: hand-author highly optimized, critical mathematical kernels using a stable, precise IR when higher-level abstractions cannot achieve the theoretical peak performance of the hardware. + +#### Relationship to CCE + +The PTO micro Instruction is designed to express the full semantic capabilities of the Compute Cube Engine (CCE), but with significant structural and pipeline advantages for compiler development. + +- Bypassing the C/Clang Pipeline: while CCE heavily relies on C/C++ extensions parsed by Clang, the PTO micro Instruction operates entirely independently of the C language frontend. By bypassing Clang AST generation and frontend processing, utilizing the PTO micro Instruction significantly reduces overall compilation time and memory overhead. +- Enhanced IR Verification: because the PTO micro Instruction is a strongly typed, SSA-based (Static Single Assignment) compiler IR rather than a C-wrapper API, it provides a much more rigorous and detailed IR verification process. Structural inconsistencies, invalid memory access patterns, and operand type mismatches are caught immediately with precise, explicit diagnostic feedback, providing developers with much higher visibility into kernel correctness than traditional CCE error reporting. + +#### Intended Audience + +This document is written for compiler engineers, library writers, and advanced performance architects. We expect the reader to have a working understanding of modern compiler infrastructure, specifically MLIR, the principles of Static Single Assignment (SSA) form, and a deep understanding of the vector-processing capabilities of the Ascend 950 architecture. + +### Getting Started + +The PTO micro Instruction is architected as a performance-critical layer within the compiler stack, specifically designed to exploit the **Decoupled Access-Execute** (DAE) nature of the Ascend 950 hardware. + +#### Hardware Pipeline Modeling + +The IR is structured to mirror the three primary hardware pipelines of the Ascend 950 architecture. Correct PTO micro Instruction authoring requires managing the interaction between these asynchronous units: + +**MTE2** (Memory Transfer Engine - Inbound): Responsible for moving data from Global Memory (GM) to the Unified Buffer (UB). + +**Vector Core** (Computation): The primary engine for executing SIMD operations on data stored in UB. + +**MTE3** (Memory Transfer Engine - Outbound): Responsible for moving processed data from UB back to GM. + +#### Architecture Detail: Vector Lane (VLane) + +The vector register is organized as **8 VLanes** of 32 bytes each. A VLane is the atomic unit for group reduction operations. + +``` +vreg (256 bytes total): +┌─────────┬─────────┬─────────┬─────┬─────────┬─────────┐ +│ VLane 0 │ VLane 1 │ VLane 2 │ ... │ VLane 6 │ VLane 7 │ +│ 32B │ 32B │ 32B │ │ 32B │ 32B │ +└─────────┴─────────┴─────────┴─────┴─────────┴─────────┘ +``` + +Elements per VLane by data type: + +| Data Type | Elements/VLane | Total Elements/vreg | +|-----------|---------------|-------------------| +| i8/si8/ui8 | 32 | 256 | +| i16/si16/ui16/f16/bf16 | 16 | 128 | +| i32/si32/ui32/f32 | 8 | 64 | +| i64/si64/ui64 | 4 | 32 | + +#### Memory and Synchronization Model + +The PTO micro Instruction enforces a strict memory hierarchy. The Unified Buffer (UB) is the only valid operand source for vector compute instructions. Consequently, the architecture of a PTO micro Instruction program is defined by the explicit management of data movement: + +**Address Space Isolation**: The IR uses `!pto.ptr` to distinguish between GM (`!pto.ptr`) and UB (`!pto.ptr`). The verifier ensures that vector compute operations do not access GM directly; data must first be moved into UB. + +**UB Capacity**: The Unified Buffer provides 256KB of on-chip SRAM (also referred to as "vecTile"). + +**Data Flow**: + +``` +┌─────────────────────────────────────────────┐ +│ Global Memory (GM) │ +│ (Off-chip HBM/DDR) │ +└─────────────────────┬───────────────────────┘ + │ DMA (MTE2 inbound / MTE3 outbound) +┌─────────────────────▼───────────────────────┐ +│ Unified Buffer (UB) │ +│ (On-chip SRAM, 256KB) │ +└─────────────────────┬───────────────────────┘ + │ Vector Load/Store (PIPE_V) +┌─────────────────────▼───────────────────────┐ +│ Vector Register File (VRF) │ +│ vreg (256B each) + mask (256-bit each) │ +└─────────────────────────────────────────────┘ +``` + +1. **GM → UB**: DMA transfer via MTE2 (`pto.copy_gm_to_ubuf`) +2. **UB → vreg**: Vector Load instructions (`pto.vlds`, `pto.vldsx2`, etc.) +3. **vreg → vreg**: Compute instructions (`pto.vadd`, `pto.vmul`, etc.) +4. **vreg → UB**: Vector Store instructions (`pto.vsts`, `pto.vstsx2`, etc.) +5. **UB → GM**: DMA transfer via MTE3 (`pto.copy_ubuf_to_gm`) + +**Load/Store Access Patterns**: + +For UB↔vreg data movement, besides contiguous load/store, the architecture provides rich access pattern support including strided access, pack/unpack, interleave/deinterleave, broadcast, upsample/downsample, channel split/merge, gather/scatter, and squeeze/expand operations. For detailed instruction syntax and distribution modes, refer to the [Vector Load/Store](#isa-03-vector-load-store) group in the ISA specification. + +#### Synchronization Model + +The Ascend 950 architecture employs a cluster-based design with a 1:2 ratio of Cube cores to Vector cores. The PTO micro Instruction provides multiple levels of synchronization to manage concurrent execution across pipelines and cores: + +**Inter-Core Synchronization (within a cluster):** + +Synchronization between cores within the same cluster is achieved via the core sync mechanism using `pto.set_intra_core` and `pto.wait_intra_core` operations. This enables coordination between Cube and Vector cores sharing the same cluster resources. + +**Vector Core Pipeline Synchronization:** + +Within a single core, multiple pipelines operate asynchronously: + +- **MTE2 (PIPE_MTE2)**: DMA copy-in from GM to UB +- **MTE3 (PIPE_MTE3)**: DMA copy-out from UB to GM +- **Vector Compute (PIPE_V)**: Vector ALU operations +- **Scalar (PIPE_S)**: Scalar unit running the kernel program + +Pipeline synchronization can be achieved through two mechanisms: + +1. **Flag/Event mechanism**: `pto.set_flag` and `pto.wait_flag` operations resolve Read-After-Write (RAW) and Write-After-Read (WAR) hazards between pipelines. + +2. **Buffer-ID mechanism**: `pto.get_buf` and `pto.rls_buf` provide finer-grained synchronization through buffer acquisition and release semantics for producer-consumer coordination. + +**Intra-Pipeline Memory Barriers (within `__VEC_SCOPE__`):** + +Within the vector execution scope, the hardware does not track UB address aliasing between reg↔UB accesses. When UB addresses overlap or alias between vector load/store operations, explicit memory barriers are required: + +```c +pto.mem_bar "VV_ALL" // All prior vector ops complete before subsequent +pto.mem_bar "VST_VLD" // All prior vector stores visible before subsequent loads +pto.mem_bar "VLD_VST" // All prior vector loads complete before subsequent stores +``` + +Without proper barriers, loads may see stale data or stores may be reordered incorrectly. + +#### Execution Scopes (__VEC_SCOPE__) + +`__VEC_SCOPE__` is the IR-level representation of a Vector Function (VF) launch. In the PTO architecture, it defines the hardware interface between the Scalar Unit and the Vector Thread. + +In PTO micro Instruction source IR, vector execution scopes are modeled as dedicated region ops. The default form is `pto.vecscope`; when the scope body must reject implicit capture and require explicit region arguments, use `pto.strict_vecscope`. + +**Scalar-Vector Interface:** + +The execution model follows non-blocking fork semantics: + +- Scalar invocation: the scalar processor invokes a vector thread by calling a VF. Once the launch command is issued, the scalar unit does not stall and continues executing subsequent instructions in the pipeline. +- Vector execution: after invocation, the vector thread independently fetches and executes the instructions defined within the VF scope. +- Parallelism: this decoupled execution allows the scalar and vector units to run in parallel, so the scalar unit can prepare addresses or manage control flow while the vector unit performs heavy SIMD computation. + +**Launch Mechanism And Constraints:** + +- Parameter buffering: all arguments required by the VF must be staged in hardware-specific buffers. +- Launch overhead: launching a VF incurs a latency of a few cycles. Very small VFs should account for this overhead because launch cost can rival useful computation time. + +**MLIR Representation:** + +```mlir +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +**Strict MLIR Representation:** + +```mlir +pto.strict_vecscope(%ub, %ub_out, %lane) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index): + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index) -> () +``` + +`pto.strict_vecscope` is the strict form of `pto.vecscope`. + +- `pto.vecscope` allows the body to use surrounding SSA values directly. +- `pto.strict_vecscope` requires every external value used by the body to be passed through the op operand list and received as a body block argument. +- `pto.strict_vecscope` rejects implicit capture from the surrounding scope. +- both ops still represent one explicit VPTO vector interval. +- regardless of whether the source form uses `pto.vecscope`, + `pto.strict_vecscope`, or a lowered carrier loop with + `llvm.loop.aivector_scope`, every op that produces or consumes `!pto.vreg`, + `!pto.mask<...>`, or `!pto.align` must be enclosed by exactly one vector + interval +- nested vector intervals are not part of the legal VPTO surface; ordinary + nested `scf.for` structure is fine, but one vector interval may not contain + another vector interval + +### Example: VecScope + +```mlir +pto.set_loop2_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.copy_gm_to_ubuf %7, %2, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c0_i64, + %false, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i1, i64, i64, i64 + +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + scf.for %lane = %c0 to %9 step %c64 { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %2[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %8[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } +} + +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c4096_i64, %c4096_i64 : i64, i64 +pto.copy_ubuf_to_gm %8, %14, %3, %3, %c0_i64, %c32_i64, %4, %c0_i64, %c128_i64, %c128_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64, i64, i64 +``` + +### Example: Strict VecScope + +```mlir +pto.strict_vecscope(%ub_in, %ub_out, %lane, %remaining) { +^bb0(%in: !pto.ptr, %out: !pto.ptr, %iv: index, %rem: i32): + %mask, %next_remaining = pto.plt_b32 %rem : i32 -> !pto.mask, i32 + %v = pto.vlds %in[%iv] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %out[%iv], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} : (!pto.ptr, !pto.ptr, index, i32) -> () +``` + +Use `pto.strict_vecscope` when the source form should make all vector-scope inputs explicit in the region signature instead of relying on surrounding SSA visibility. The scope op itself only defines the vector-interval boundary and region argument contract. + +### Scope + +This document is the interface specification centered on the `mlir::pto` dialect and the shared MLIR surface used alongside it in PTO micro Instruction programs. + +It only describes: + +- operation names +- operand and result lists +- operand and result types +- important attributes +- C-style semantics for each operation + +It does not describe lowering strategy. + +PTO micro Instruction source programs are not restricted to `pto` operations alone. In practice they also use shared MLIR dialect ops, most notably the full scalar operation surface of `arith` together with structured control-flow ops from `scf`, to express scalar constants, scalar arithmetic, type conversion, comparisons, and structured control flow around PTO vector or tile regions. These shared-dialect ops are part of the supported PTO micro Instruction source surface and should be regarded as part of PTO-ISA alongside `pto` dialect operations. + +### Shared MLIR Dialects + +- `arith`: the full scalar `arith` surface is supported in PTO micro Instruction programs, covering scalar integer, floating-point, boolean, and `index` operations. In current samples the most common uses are still constants, offset/bounds arithmetic, casts, compares, and selects. +- `scf`: structured control flow used to model counted loops, conditional regions, loop-carried state, and break-like control around PTO compute and data-movement ops. +- Shared dialect ops remain in standard MLIR form so that PTO analyses and backend passes can reason about control flow and scalar state without re-encoding them as PTO-specific instructions. + +### BlockDim Query Operations + +These ops expose the current kernel instance's execution coordinates to scalar code. They are the PTO-level equivalent of runtime queries such as `GetBlockIdx()` and `GetBlockNum()` in kernel programming models. + +Use them when the same kernel body is launched across multiple blocks or subblocks and each execution instance must figure out which slice of the global workload it owns. + +A common pattern is: + +- split the full input/output tensor into `block_num` disjoint block-sized regions +- let each block compute its own starting offset from `block_idx` +- within one block, further tile the local region and drive the tile loop with ordinary scalar `arith` / `scf` ops + +For example, if a tensor is split evenly across 8 blocks and each block handles `block_length = 2048` elements, then block `b` owns the global range `[b * block_length, (b + 1) * block_length)`. The per-block GM base pointer can be formed by adding `block_idx * block_length` elements to the original base pointer. + +At the PTO micro Instruction level, these runtime-query ops are pure scalar producers. They do not perform data movement, do not allocate memory, and do not by themselves create tiling or double buffering. Instead, they provide the scalar values used by surrounding address computation and structured control flow. + +#### Example: block-level data partitioning + +```mlir +%block = pto.get_block_idx +%block_num = pto.get_block_num +%block_len = arith.constant 2048 : index +%base = arith.index_cast %block : i64 to index +%offset = arith.muli %base, %block_len : index +%block_in = pto.addptr %gm_in, %offset : !pto.ptr -> !pto.ptr +%block_out = pto.addptr %gm_out, %offset : !pto.ptr -> !pto.ptr +``` + +In this pattern, all blocks execute the same kernel body, but each block sees a different `%block` value and therefore computes a different GM window. + +#### `pto.get_block_idx` + +- **syntax:** `%block = pto.get_block_idx` +- **result:** `i64` +- **semantics:** Return the current block ID in the range `[0, pto.get_block_num())`. + +```c +block = block_idx(); +``` + +#### `pto.get_subblock_idx` + +- **syntax:** `%subblock = pto.get_subblock_idx` +- **result:** `i64` +- **semantics:** Return the current subblock ID in the range `[0, pto.get_subblock_num())`. + +```c +subblock = subblock_idx(); +``` + +#### `pto.get_block_num` + +- **syntax:** `%block_num = pto.get_block_num` +- **result:** `i64` +- **semantics:** Return the total number of launched blocks visible to the current kernel instance. + +```c +block_num = block_num(); +``` + +#### `pto.get_subblock_num` + +- **syntax:** `%subblock_num = pto.get_subblock_num` +- **result:** `i64` +- **semantics:** Return the total number of visible subblocks for the current execution instance. + +```c +subblock_num = subblock_num(); +``` + +Typical usage: + +```mlir +%block = pto.get_block_idx +%subblock = pto.get_subblock_idx +%block_num = pto.get_block_num +%subblock_num = pto.get_subblock_num +``` + +### Core Types + +### Element Types +`vreg`: `!pto.vreg` Fixed-width PTO micro Instruction vector type with total width exactly 256 bytes (2048 bits). `N` is the lane count, `T` is the element type, and `N * bitwidth(T) = 2048`. + +| Type | Bits | Description | +|------|------|-------------| +| `i8` / `si8` / `ui8` | 8 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | Signless/signed/unsigned 16-bit integer | +| `i32` / `si32` / `ui32` | 32 | Signless/signed/unsigned 32-bit integer | +| `i64` / `si64` / `ui64` | 64 | Signless/signed/unsigned 64-bit integer | +| `f16` | 16 | IEEE 754 half precision | +| `bf16` | 16 | Brain floating point | +| `f32` | 32 | IEEE 754 single precision | + +### Mask Types + +`mask`: `!pto.mask` Typed predicate-register view. `G` is one of `b8`, `b16`, `b32` and records the byte-granularity interpretation used by VPTO ops and verifiers. + +Typed masks are also the primary legality contract for predicated VPTO code: + +- vector ops over `f32`, `i32`, `si32`, and `ui32` consume `!pto.mask` +- vector ops over `f16`, `bf16`, `i16`, `si16`, and `ui16` consume + `!pto.mask` +- vector ops over 8-bit element families consume `!pto.mask` +- compare families keep seed-mask and result-mask granularity aligned with the + compared vector family +- carry families keep carry-in, carry-out, and execution-mask granularity + aligned with the data-vector family +- mask-only ops that do not explicitly change granularity preserve the same `G` + +### Address Space Conventions + +PTO micro Instruction memory operands use `!pto.ptr`. This specification models the following memory-space attributes: + +| Space | Interpretation | +|-------|----------------| +| `gm` | Global Memory (GM), off-chip HBM/DDR storage | +| `ub` | Unified Buffer (UB), on-chip vector buffer | + +Typical pointer construction and pointer arithmetic follow the same `!pto.ptr<..., space>` form: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +``` + +### `!pto.ptr` + +`!pto.ptr` is the typed pointer form used for explicit memory operands in PTO micro Instruction. + +- `T` is the element type associated with the pointed-to storage. +- `space` is the memory domain, typically `gm` or `ub` in this specification. +- A `pto.ptr` value carries an address plus its element-type / memory-space interpretation, but it does not carry tensor shape or stride metadata by itself. +- Tensor semantics are introduced separately through view-building operations such as `pto.make_tensor_view`. +- Pointer arithmetic is element-based rather than byte-based. + +Typical examples: + +- `!pto.ptr` +- `!pto.ptr` +- `!pto.ptr` + +### Pointer Operations + +#### `pto.castptr` + +- **syntax:** `%result = pto.castptr %addr : i64 -> !pto.ptr` +- **semantics:** Reinterpret a scalar address value as a typed PTO pointer in the target memory space. + +```c +result = (ptr)addr; +``` + +`pto.castptr` is a pointer-construction operation. It does not perform data movement and does not by itself imply any load/store side effect. + +#### `pto.addptr` + +- **syntax:** `%result = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr` +- **semantics:** Compute a new pointer by advancing the base pointer by an element offset. + +```c +result = ptr + offset; // offset counted in elements, not bytes +``` + +`pto.addptr` preserves both the element type `T` and the memory-space tag `space`. + +#### `pto.load_scalar` + +- **syntax:** `%value = pto.load_scalar %ptr[%offset] : !pto.ptr -> T` +- **semantics:** Load one scalar element from a pointer-like operand. + +```c +value = ptr[offset]; +``` + +- **inputs:** + `%ptr` is a typed PTO pointer `!pto.ptr`, and `%offset` is an + `index` displacement counted in elements. +- **outputs:** + `%value` is the loaded scalar element. +- **constraints and limitations:** + The result type MUST match the element type of `%ptr`. This op is a scalar + memory helper; unlike `pto.vlds`, it does not produce a `vreg` result and + does not participate in vector load `dist` families. + +#### `pto.store_scalar` + +- **syntax:** `pto.store_scalar %value, %ptr[%offset] : !pto.ptr, T` +- **semantics:** Store one scalar element to a pointer-like operand. + +```c +ptr[offset] = value; +``` + +- **inputs:** + `%value` is the scalar value to store. `%ptr` is a typed PTO pointer + `!pto.ptr`, and `%offset` is an `index` displacement counted in + elements. +- **constraints and limitations:** + The stored value type MUST match the element type of `%ptr`. This op is a + scalar memory helper; unlike `pto.vsts`, it does not consume a mask and does + not target vector-store `dist` families. + +#### Pointer-Based Vector Access Example + +The following lowered-style fragment shows how typed PTO pointers flow through +pointer construction, pointer arithmetic, structured control flow, and PTO +memory ops. Scalar memory access is expressed on `!pto.ptr` in +general, but the common VPTO pattern here is UB-local scalar access alongside +UB vector loads/stores: + +```mlir +%0 = pto.castptr %c0 : i64 -> !pto.ptr +%1 = pto.addptr %0, %c1024 : !pto.ptr -> !pto.ptr +pto.vecscope { + %16 = scf.for %arg3 = %c0 to %11 step %c64 iter_args(%arg4 = %12) -> (i32) { + %mask, %scalar_out = pto.plt_b32 %arg4 : i32 -> !pto.mask, i32 + %s = pto.load_scalar %1[%c4] : !pto.ptr -> f32 + pto.store_scalar %s, %1[%c8] : !pto.ptr, f32 + %17 = pto.vlds %1[%arg3] : !pto.ptr -> !pto.vreg<64xf32> + %18 = pto.vabs %17, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %18, %10[%arg3], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + scf.yield %scalar_out : i32 + } +} +``` + +In this pattern, `pto.castptr` materializes a typed UB pointer, `pto.addptr` shifts the base by 1024 `f32` elements, and the subsequent `[%arg3]` indexing on `pto.vlds` / `pto.vsts` applies an additional element offset relative to that base. + +### Special Types + +#### `!pto.mask` + +`!pto.mask` models an A5 predicate register (256-bit) under a typed granularity view, not an integer vector. + +`G` is part of the type and MUST be one of: + +- `b32` +- `b16` +- `b8` + +All three forms describe the same physical 256-bit predicate-register class. The type parameter does not encode how many lanes are currently active. Instead, it records how VPTO interprets the register when matching mask-producing ops, mask-consuming ops, and verifier legality rules. + +In the ISA chapters below, this document uses `!pto.mask` as shorthand when a +family is generic over granularity. For op families whose names already encode +the granularity, such as `pset_b32`, `pge_b16`, `plt_b8`, +`pdintlv_b8`, and `pintlv_b16`, examples use the corresponding concrete typed +mask. + +**Mask Granularity:** + +The predicate register is 256 bits in length, where each bit controls 1 byte of data. `G` therefore describes how many bytes form one logical element slot: + +| Mask Type | Bytes / Element Slot | Typical Element Family | Derived Logical Lanes | +|-----------|----------------------|------------------------|-----------------------| +| `!pto.mask` | 4 | `f32` / `i32` | 64 | +| `!pto.mask` | 2 | `f16` / `bf16` / `i16` | 128 | +| `!pto.mask` | 1 | 8-bit element family | 256 | + +This is intentionally different from a lane-vector model such as `mask<64xi1>`: + +- `!pto.mask` still denotes a 256-bit predicate register; +- `64` is only the derived logical lane count for the `b32` view; +- value-level patterns such as `PAT_VL32` describe which lanes are active, not a different type. +- `pto.vaddc`, `pto.vsubc`, `pto.vaddcs`, and `pto.vsubcs` use `!pto.mask` + to carry their per-lane carry results, interpreted with this same + granularity. + +**Predication Behavior (Zero-Merge):** + +The native hardware predication mode is **ZEROING** — inactive lanes produce zero: + +```c +dst[i] = mask[i] ? op(src0[i], src1[i]) : 0 // ZEROING mode +``` + +```mlir +// Predicated add: inactive lanes produce zero +%mask = pto.pset_b32 "PAT_VL32" : !pto.mask // first 32 logical b32 lanes active +%result = pto.vcmp %a, %b, %mask, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +``` + +```mlir +// Compare and select: generate mask from comparison, use for conditional select +%mask = pto.vcmp %lhs, %rhs, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%out = pto.vsel %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +#### `!pto.align` + +`!pto.align` models the A5 vector-align carrier state. It is not payload data. + +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align_out = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align + +%store_align = pto.init_align : !pto.align +%next_align = pto.vstus %store_align, %offset, %vec, %ub + : !pto.align, i32, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +``` + +--- + +## Part II: Notation Convention + +This section defines the MLIR syntax patterns and C-style semantic notation used throughout the ISA reference (Part III). + +### MLIR Op Syntax Patterns + +All PTO micro Instruction operations follow standard MLIR syntax. The common patterns are: + +**Unary (one vector in, one vector out):** + +```mlir +%result = pto. %input : !pto.vreg -> !pto.vreg +``` + +**Binary (two vectors in, one vector out):** + +```mlir +%result = pto. %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg +``` + +**Vec-Scalar (one vector + one scalar in, one vector out):** + +```mlir +%result = pto. %input, %scalar : !pto.vreg, T -> !pto.vreg +``` + +**Load (memory to register):** + +```mlir +%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg +``` + +**Store (register to memory):** + +```mlir +pto.vsts %value, %destination[%offset] {dist = "DIST"} : !pto.vreg, !pto.ptr +``` + +**Dual Load (one load, two results — deinterleave):** + +```mlir +%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg +``` + +**Dual Store (two inputs, one interleaved store):** + +```mlir +pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask +``` + +**Compare (two vectors + seed mask in, mask out):** + +```mlir +%mask = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask +``` + +**Conversion (one vector in, different-typed vector out):** + +```mlir +%result = pto.vcvt %input, %mask {rnd = "R", sat = "SAT", part = "EVEN"} : !pto.vreg, !pto.mask -> !pto.vreg +``` + +**Predicate construction:** + +```mlir +%mask = pto.pset_b32 "PAT_ALL" : !pto.mask +%tail = pto.pge_b32 "PAT_VL16" : !pto.mask +``` + +**Sync operations:** + +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +**Pointer construction and arithmetic:** + +```mlir +%ptr = pto.castptr %addr : i64 -> !pto.ptr +%ptr2 = pto.addptr %ptr, %offset : !pto.ptr -> !pto.ptr +``` + +### Shared Dialect Syntax Patterns + +PTO micro Instruction programs may interleave PTO ops with standard MLIR `arith` and `scf` ops. +The examples below emphasize common index-heavy patterns, but `arith` support is not limited to index arithmetic. + +**Scalar / index constant:** + +```mlir +%c0 = arith.constant 0 : index +%zero = arith.constant 0.0 : f32 +``` + +**Scalar arithmetic (integer / float / boolean-style bitwise):** + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%bits = arith.andi %flags0, %flags1 : i32 +``` + +**Scalar compare and select:** + +```mlir +%cond = arith.cmpi eq, %lhs, %rhs : index +%bound = arith.select %cond, %a, %b : index +``` + +**Counted loop with loop-carried values:** + +```mlir +%result = scf.for %iv = %lb to %ub step %step + iter_args(%acc = %init) -> (index) { + %next = arith.addi %acc, %iv : index + scf.yield %next : index +} +``` + +**Structured conditional region:** + +```mlir +%selected = scf.if %cond -> (index) { + scf.yield %then_value : index +} else { + scf.yield %else_value : index +} +``` + +**Structured while loop:** + +```mlir +%state:2 = scf.while (%iv = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %keep_going = arith.cmpi slt, %iv, %limit : index + scf.condition(%keep_going) %iv, %alive : index, i1 +} do { +^bb0(%iv_in: index, %alive_in: i1): + %iv_next = arith.addi %iv_in, %c1 : index + scf.yield %iv_next, %alive_in : index, i1 +} +``` + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +### C-Style Semantics Convention + +For each ISA operation in Part III, semantics are expressed as C code. The convention: + +```c +// Vector register contents as arrays: +T dst[N]; // destination +T src0[N]; // first source +T src1[N]; // second source (binary ops) +T scalar; // scalar operand (vec-scalar ops) +int mask[N]; // per-lane predicate (0 or 1) + +// N = lane count determined by type: +// N = 256 for i8/si8/ui8 +// N = 128 for i16/si16/ui16/f16/bf16 +// N = 64 for i32/si32/ui32/f32 +// N = 32 for i64/si64/ui64 +``` + +**Example — pto.vadd semantics:** + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +**Example — pto.vcgadd (group reduction per VLane) semantics:** + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +### Template Placeholder Conventions + +| Placeholder | Meaning | +|-------------|---------| +| `"SRC_PIPE"`, `"DST_PIPE"` | Pipeline identifiers: `"PIPE_MTE2"`, `"PIPE_V"`, `"PIPE_MTE3"` | +| `"EVENT_ID"` | Event identifier: `"EVENT_ID0"` etc. | +| `"DIST"` | Distribution mode string (see the relevant load/store ISA group in Part III) | +| `"CMP_MODE"` | Compare predicate: `eq \| ne \| lt \| le \| gt \| ge` | +| `"RND"` | Rounding mode: `R \| A \| F \| C \| Z \| O` | +| `"SAT"` | Saturation: `SAT \| NOSAT` | +| `"PART"` | Half selector: `EVEN \| ODD` | +| `"PAT_*"` | Predicate pattern literal | +| `T` | Element type (f32, f16, bf16, i32, i16, i8, etc.) | +| `N` | Lane count (`N * bitwidth(T) = 2048`) | + +--- + +## Part III: ISA Instruction Reference — Summary + +This section provides a categorized overview of all PTO micro Instruction operations plus the shared MLIR `arith` and `scf` ops that may appear in PTO micro Instruction programs. Detailed documentation for each group is included later in this merged document. + +--- + +## Instruction Groups + +| # | Group | Description | Count | Details | +|---|-------|-------------|-------|---------| +| 1 | [Pipeline Sync](#isa-01-pipeline-sync) | Intra-core pipeline synchronization | 5 | `pto.set_flag`, `pto.wait_flag`, `pto.pipe_barrier`, `pto.get_buf`, `pto.rls_buf` | +| 2 | [DMA Copy Programming](#isa-02-dma-copy) | DMA configuration and transfer between GM↔UB | 9 | `pto.set_loop*_stride_*`, `pto.set_loop_size_*`, `pto.copy_gm_to_ubuf`, `pto.copy_ubuf_to_ubuf`, `pto.copy_ubuf_to_gm` | +| 3 | [Vector Load/Store](#isa-03-vector-load-store) | UB↔vreg data movement with various access patterns | ~20 | `pto.vlds`, `pto.vldsx2`, `pto.vgather2`, `pto.vsts`, `pto.vstsx2`, `pto.vscatter`, etc. | +| 4 | [Predicate Load/Store](#isa-04-predicate-load-store) | UB↔mask register movement | 5 | `pto.plds`, `pto.pldi`, `pto.psts`, `pto.psti`, `pto.pstu` | +| 5 | [Materialization & Predicate Ops](#isa-05-materialization-predicate) | Scalar broadcast, predicate generation and manipulation | ~17 | `pto.vbr`, `pto.vdup`, `pto.pset_b*`, `pto.pge_b*`, `pto.plt_b*`, `pto.ppack`, `pto.punpack`, `pto.pnot`, `pto.psel`, etc. | +| 6 | [Unary Vector Ops](#isa-06-unary-vector-ops) | Single-input element-wise operations | 6 | `pto.vabs`, `pto.vexp`, `pto.vln`, `pto.vsqrt`, `pto.vrelu`, `pto.vnot` | +| 7 | [Binary Vector Ops](#isa-07-binary-vector-ops) | Two-input element-wise operations | 13 | `pto.vadd`, `pto.vsub`, `pto.vmul`, `pto.vdiv`, `pto.vmax`, `pto.vmin`, `pto.vand`, `pto.vor`, `pto.vxor`, `pto.vshl`, `pto.vshr`, `pto.vaddc`, `pto.vsubc` | +| 8 | [Vec-Scalar Ops](#isa-08-vec-scalar-ops) | Vector-scalar operations | 9 | `pto.vadds`, `pto.vmuls`, `pto.vmaxs`, `pto.vmins`, `pto.vlrelu`, `pto.vshls`, `pto.vshrs`, `pto.vaddcs`, `pto.vsubcs` | +| 9 | [Conversion Ops](#isa-09-conversion-ops) | Type conversion with rounding/saturation control | 2 | `pto.vcvt`, `pto.vtrc` | +| 10 | [Reduction Ops](#isa-10-reduction-ops) | Vector reductions | 7 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin`, `pto.vcgadd`, `pto.vcgmax`, `pto.vcgmin`, `pto.vcpadd` | +| 11 | [Compare & Select](#isa-11-compare-select) | Comparison and conditional selection | 4 (+1 not A5) | `pto.vcmp`, `pto.vcmps`, `pto.vsel`, `pto.vselr` (`pto.vselrv2` removed: not A5) | +| 12 | [Data Rearrangement](#isa-12-data-rearrangement) | In-register data movement and permutation | 2 (+2 not A5) | `pto.vintlv`, `pto.vdintlv` (`pto.vintlvv2`, `pto.vdintlvv2` removed: not A5) | +| 13 | [DSA/SFU Ops](#isa-13-dsa-sfu-ops) | Specialized ops, index generation, and sorting helpers | 9 | `pto.vlrelu`, `pto.vprelu`, `pto.vexpdif`, `pto.vaxpy`, `pto.vmull`, `pto.vmula`, `pto.vci`, `pto.vbitsort`, `pto.vmrgsort4` | +| 14 | [Arith (Shared MLIR Dialect)](#isa-14-shared-arith) | Full scalar `arith` surface used around PTO ops; the companion page lists categories and representative examples | all scalar ops | `arith.constant`, `arith.addi`, `arith.addf`, `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.index_cast`, `arith.extsi`, `arith.trunci`, `arith.andi`, `arith.shli`, etc. | +| 15 | [SCF (Shared MLIR Dialect)](#isa-15-shared-scf) | Structured loops, branches, and loop-carried state around PTO regions | 5 | `scf.for`, `scf.if`, `scf.while`, `scf.condition`, `scf.yield` | + +--- + +## Detailed ISA Group Reference + +This section inlines the 15 ISA group documents so the architectural overview, notation, summary table, and per-group semantics can be read in a single file. + + + +### 1. Pipeline Synchronization + +> **Category:** Synchronization primitives for coordinating pipeline execution +> **Pipelines:** MTE2 (GM→UB), PIPE_V (Vector), MTE3 (UB→GM) + +The PTO micro Instruction model operates on the Ascend 950's **Decoupled Access-Execute** architecture. The MTE and Vector pipelines run asynchronously, requiring explicit synchronization to prevent data hazards. + +--- + +#### Intra-Core Pipeline Sync + +These ops coordinate data flow between pipelines within a single vector core. + +##### `pto.set_flag` + +- **syntax:** `pto.set_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Signal event from source pipe to destination pipe. + +```c +set_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** After MTE2 completes GM→UB transfer, signal Vector pipe: +```mlir +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.wait_flag` + +- **syntax:** `pto.wait_flag["SRC_PIPE", "DST_PIPE", "EVENT_ID"]` +- **semantics:** Block destination pipe until source pipe signals event. + +```c +wait_flag(src_pipe, dst_pipe, event_id); +``` + +**Example:** Vector pipe waits for MTE2 data to arrive: +```mlir +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] +``` + +--- + +##### `pto.pipe_barrier` + +- **syntax:** `pto.pipe_barrier "PIPE_*"` +- **semantics:** Drain all pending ops in the specified pipe. All previously issued operations on that pipe complete before any subsequent operation begins. + +```c +pipe_barrier(pipe); +``` + +**Pipe identifiers:** `PIPE_MTE2`, `PIPE_V`, `PIPE_MTE3` + +**Example:** Two back-to-back `copy_ubuf_to_gm` calls writing to the same GM address. Without a barrier, MTE3 may reorder them and the final GM value is non-deterministic: + +```mlir +// Both stores target the same GM address — order matters! +pto.copy_ubuf_to_gm %ub_partial_0, %gm_result, ... +// Without pipe_barrier, MTE3 could execute the second copy before the first +// completes, producing a non-deterministic result at %gm_result. +pto.pipe_barrier "PIPE_MTE3" +// After barrier: first copy is guaranteed complete. Second copy overwrites deterministically. +pto.copy_ubuf_to_gm %ub_partial_1, %gm_result, ... +``` + +--- + +##### `pto.get_buf` + +- **syntax:** `pto.get_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Acquire buffer slot for inter-pipeline double-buffering coordination. + +```c +get_buf(pipe, buf_id, mode); +``` + +--- + +##### `pto.rls_buf` + +- **syntax:** `pto.rls_buf "PIPE_*", %buf_id, %mode : i64, i64` +- **semantics:** Release buffer slot to allow other pipeline to proceed. + +```c +rls_buf(pipe, buf_id, mode); +``` + +--- + +##### Mode Parameter for `get_buf` / `rls_buf` + +The `mode` parameter controls how `get_buf` and `rls_buf` interact with pipeline execution and dependency tracking: + +| Mode | `get_buf` Behavior | `rls_buf` Behavior | Use Case | +|------|-------------------|-------------------|----------| +| **0** (default) | **Blocking acquire**: waits for all previous `rls_buf` with same `buf_id` from all pipelines (in program order) before the specified pipe can proceed | **Immediate release**: signals completion for only the instructions related to the specified pipe | **Automatic ping/pong dependency** — recommended for double/multi-buffering | +| **1** | **Non-blocking acquire**: does not wait; pipe execution proceeds immediately | **Deferred release**: waits for all instructions across all pipelines with same `buf_id` to retire before signaling | **Backward compatibility** with `set_flag`/`wait_flag` semantics | + +**Mode 0 (Default — Recommended):** +- `get_buf`: The specified pipeline blocks until all previous `rls_buf` operations for the same buffer ID (from any pipeline) have completed, respecting program order. +- `rls_buf`: Immediately signals that the specified pipeline has finished using the buffer — only waits for that pipe's related instructions. +- This mode provides **automatic RAW/WAR/WAW dependency resolution** based on buffer ID and program order, making it ideal for ping/pong and N-buffer patterns. + +**Mode 1 (Legacy Compatibility):** +- `get_buf`: Does not block — the pipeline proceeds immediately without waiting. +- `rls_buf`: Waits for **all** previous instructions across **all** pipelines with the same buffer ID to retire before signaling release. +- This mode emulates `set_flag`/`wait_flag` behavior and is provided for backward compatibility with existing code patterns. + +> **Note:** A5 supports both `set_flag`/`wait_flag` and `get_buf`/`rls_buf` mechanisms. Mode 1 is rarely needed since mode 0 provides a more programmer-friendly approach for buffer-based synchronization. + +--- + +##### `pto.mem_bar` + +- **syntax:** `pto.mem_bar "BARRIER_TYPE"` +- **semantics:** Intra-vector-pipe memory fence within `__VEC_SCOPE__`. Required when UB addresses alias between vector load/store operations. + +```c +mem_bar(barrier_type); +``` + +**Barrier types:** + +| Type | Semantics | +|------|-----------| +| `VV_ALL` | All prior vector ops complete before subsequent | +| `VST_VLD` | All prior vector stores visible before subsequent loads | +| `VLD_VST` | All prior vector loads complete before subsequent stores | + +**Example:** Ensure stores are visible before loads to same UB region: +```mlir +pto.vsts %v0, %ub[%c0] : !pto.vreg<64xf32>, !pto.ptr +pto.mem_bar "VST_VLD" +%v1 = pto.vlds %ub[%c0] : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +#### Why `get_buf` / `rls_buf` is More Programmer-Friendly + +The buffer-based synchronization (`get_buf`/`rls_buf`) provides the **same functional capability** as `set_flag`/`wait_flag` for maintaining correct ordering of RAW/WAR/WAW dependencies across pipelines, but with significant usability advantages: + +##### 1. No Manual Priming or Draining + +With `set_flag`/`wait_flag`, ping/pong loops require: +- **Pre-loop priming**: 4× `set_flag` to initialize reverse-dependency signals (otherwise first iteration deadlocks) +- **Post-loop draining**: 4× `wait_flag` to consume leftover signals from final iterations + +With `get_buf`/`rls_buf`: +- **First iteration**: Buffer is initially free, so `get_buf` proceeds immediately — no priming needed +- **Final iteration**: Last `rls_buf` simply completes — no draining required + +##### 2. No Loop Peeling for Complex Dependencies + +For non-1:1 producer-consumer ratios (e.g., 1 MTE2 load : N Vector compute slices), `set_flag`/`wait_flag` requires **peeling the set_flag outside the loop**: + +```mlir +// set_flag/wait_flag: 1 MTE2 load, 8 Vector computes on slices +// MTE2 loads large tile once +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST be outside loop + +// Vector consumes in 8 slices — but wait_flag can only fire ONCE +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_TILE_READY"] // ◀ MUST peel before loop +scf.for %slice = %c0 to %c8 step %c1 { + // compute on %ub_tile[%slice] + // Cannot put wait_flag here — would deadlock on iteration 1+ +} +``` + +With `get_buf`/`rls_buf`, acquire/release can be **inside the loop** — no peeling needed: + +```mlir +// get_buf/rls_buf: same 1:8 pattern, acquire/release inside loop works fine +// MTE2 loads large tile +pto.get_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 +pto.copy_gm_to_ubuf %gm_ptr, %ub_tile, ... +pto.rls_buf "PIPE_MTE2", %bufid_tile, %c0 : i64, i64 + +// Vector acquires/releases per slice — all 8 iterations work correctly +scf.for %slice = %c0 to %c8 step %c1 { + pto.get_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 // iteration 0: blocks until MTE2 done + // iteration 1-7: proceeds immediately (already acquired) + // compute on %ub_tile[%slice] + pto.rls_buf "PIPE_V", %bufid_tile, %c0 : i64, i64 +} +// No peeling required — get_buf handles the MTE2→V dependency automatically +``` + +##### 3. Simpler Mental Model + +| Aspect | `set_flag`/`wait_flag` | `get_buf`/`rls_buf` | +|--------|------------------------|---------------------| +| **Dependency tracking** | Manual: track event IDs, signal directions, pair every set with wait | Automatic: buffer ID + program order | +| **Event ID management** | **8 IDs per pipe-pair direction** (HW limit); each buffer occupies 1 ID per direction | **1 buffer ID per shared resource** (HW limit: 32 global); same ID used across all pipelines | +| **Error-prone areas** | Forgetting prime/drain, mismatched IDs, wrong direction | Forgetting release (but compile-time checkable) | + +##### Quick Example Comparison + +**Problem:** MTE2 loads into `buf[i%2]`, Vector processes, MTE3 stores — standard ping/pong. + +**set_flag/wait_flag approach:** +```mlir +// BEFORE loop: prime 4 reverse-dep signals +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] + +scf.for %i = ... { + // 4 set_flag + 4 wait_flag inside loop + // Must track 4 IDs: 2 pipe-pair directions × 2 ping/pong buffers +} + +// AFTER loop: drain 4 signals +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] +``` + +**get_buf/rls_buf approach:** +```mlir +scf.for %i = ... { + pto.get_buf %bufid_in[%pp], "PIPE_MTE2" + // ... MTE2 work ... + pto.rls_buf %bufid_in[%pp], "PIPE_MTE2" + + pto.get_buf %bufid_in[%pp], "PIPE_V" + pto.get_buf %bufid_out[%pp], "PIPE_V" + // ... Vector work ... + pto.rls_buf %bufid_in[%pp], "PIPE_V" + pto.rls_buf %bufid_out[%pp], "PIPE_V" + + pto.get_buf %bufid_out[%pp], "PIPE_MTE3" + // ... MTE3 work ... + pto.rls_buf %bufid_out[%pp], "PIPE_MTE3" +} +// Done. No prime. No drain. Dependencies resolved by buffer ID + program order. +``` + +--- + +#### Intra-Core Sync Patterns & Examples + +##### Example 1: `set_flag` / `wait_flag` (Explicit Events) + +Each cross-pipeline data dependency requires an explicit signal/wait pair. The programmer must manually insert `set_flag` after the producer and `wait_flag` before the consumer. + +```mlir +// ─── Stage 1: MTE2 loads data from GM into UB ─── +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... + +// MTE2 signals: "UB data is ready for Vector pipe" +pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +// ─── Stage 2: Vector pipe consumes UB data ─── +// Vector waits until MTE2's signal arrives +pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + +pto.vecscope { + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector signals: "UB output is ready for MTE3" +pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +// ─── Stage 3: MTE3 stores result from UB back to GM ─── +// MTE3 waits until Vector's signal arrives +pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +``` + +**Key property:** Every cross-pipeline edge is an explicit `(set_flag, wait_flag)` pair. Simple for straight-line code, but gets verbose in loops (see Example 3). + +--- + +##### Example 2: `get_buf` / `rls_buf` (Resource-Based) + +Instead of naming events, each pipeline declares when it **acquires** (`get_buf`) and **releases** (`rls_buf`) a shared UB buffer. Cross-pipeline RAW/WAR dependencies are resolved implicitly by program order — if MTE2 releases `buf_A` and Vector later acquires `buf_A`, the hardware ensures the acquire cannot proceed until the release completes. + +```mlir +// ─── Stage 1: MTE2 loads data into UB ─── +// MTE2 acquires ub_ptr — blocks if Vector hasn't released it from a prior iteration +pto.get_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 // mode=0 (default) +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, ... +// MTE2 done writing ub_ptr — release it so Vector can consume +pto.rls_buf "PIPE_MTE2", %bufid_ub_ptr, %c0 : i64, i64 + +// ─── Stage 2: Vector computation ─── +// Vector acquires ub_ptr (input) — blocks until MTE2 releases it (RAW: MTE2 write → V read) +pto.get_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector acquires ub_out (output) — blocks until MTE3 releases it from a prior iteration (WAR: MTE3 read → V write) +pto.get_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub_ptr[%lane] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} + +// Vector done reading ub_ptr — release so MTE2 can reuse it in next iteration +pto.rls_buf "PIPE_V", %bufid_ub_ptr, %c0 : i64, i64 +// Vector done writing ub_out — release so MTE3 can consume +pto.rls_buf "PIPE_V", %bufid_ub_out, %c0 : i64, i64 + +// ─── Stage 3: MTE3 stores result to GM ─── +// MTE3 acquires ub_out — blocks until Vector releases it (RAW: V write → MTE3 read) +pto.get_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +pto.copy_ubuf_to_gm %ub_out, %gm_out, ... +// MTE3 done reading ub_out — release so Vector can reuse it in next iteration +pto.rls_buf "PIPE_MTE3", %bufid_ub_out, %c0 : i64, i64 +``` + +**Key property:** No event IDs needed. Dependencies are implicit from program order of `get_buf`/`rls_buf` on the same buffer ID. This becomes much more convenient in multi-iteration loops (see Example 3). + +--- + +##### Example 3: Ping/Pong Double-Buffering Loop + +Double-buffering overlaps DMA and compute by using two UB buffers alternately. All three stages (MTE2, Vector, MTE3) appear in the **same iteration** — the hardware pipelines them across iterations because different iterations operate on different buffers (`buf[i%2]`). + +###### Event ID scheme (`set_flag` / `wait_flag`) + +With 2 ping/pong buffers and 2 pipeline pairs (MTE2↔V, V↔MTE3), `set_flag`/`wait_flag` needs **8 event IDs** = 2 pipe-pairs × 2 buffers × (forward + reverse): + +**MTE2 ↔ Vector (input buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_IN_FWD_0` | MTE2 → V | RAW: buf_in[0] data ready | +| `EVT_IN_FWD_1` | MTE2 → V | RAW: buf_in[1] data ready | +| `EVT_IN_REV_0` | V → MTE2 | WAR: Vector done reading buf_in[0] | +| `EVT_IN_REV_1` | V → MTE2 | WAR: Vector done reading buf_in[1] | + +**Vector ↔ MTE3 (output buffers):** + +| Event ID | Direction | Purpose | +|----------|-----------|---------| +| `EVT_OUT_FWD_0` | V → MTE3 | RAW: buf_out[0] result ready | +| `EVT_OUT_FWD_1` | V → MTE3 | RAW: buf_out[1] result ready | +| `EVT_OUT_REV_0` | MTE3 → V | WAR: MTE3 done reading buf_out[0] | +| `EVT_OUT_REV_1` | MTE3 → V | WAR: MTE3 done reading buf_out[1] | + +###### 3a. `set_flag` / `wait_flag` version + +```mlir +// ═══ Pre-loop: prime ALL reverse-dependency signals ═══ +// Both input and output buffers start unused. We must pre-send +// reverse-dep signals so the first iteration's wait_flags don't deadlock. +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_0"] // ◀ PRIME: buf_in[0] "free" +pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_1"] // ◀ PRIME: buf_in[1] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_0"] // ◀ PRIME: buf_out[0] "free" +pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_1"] // ◀ PRIME: buf_out[1] "free" + +scf.for %i = %c0 to %N step %c1 { + // ── All 3 stages in same iteration, indexed by i%2 ── + // %pp = i % 2 (ping/pong selector for buffer & event IDs) + + // ── MTE2: load tile[i] into buf_in[i%2] ── + // WAR: wait until Vector has released buf_in[i%2] from iteration i-2 + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_in[%pp], ... + // RAW: signal Vector that buf_in[i%2] data is ready + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + + // ── Vector: compute buf_in[i%2] → buf_out[i%2] ── + // RAW: wait for MTE2 to finish loading buf_in[i%2] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVT_IN_FWD_{pp}"] + // WAR: wait for MTE3 to finish reading buf_out[i%2] from iteration i-2 + pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_in[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // WAR: tell MTE2 "done reading buf_in[i%2]" + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{pp}"] + // RAW: tell MTE3 "buf_out[i%2] result ready" + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + + // ── MTE3: store result from buf_out[i%2] to GM ── + // RAW: wait for Vector to finish writing buf_out[i%2] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVT_OUT_FWD_{pp}"] + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + // WAR: tell Vector "done reading buf_out[i%2]" + pto.set_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{pp}"] +} + +// ═══ Post-loop: drain — match every pre-loop prime with a wait ═══ +// Each priming set_flag must be paired. The last loop iteration's +// set_flags are consumed by wait_flags that will never fire inside the +// loop (there is no iteration i+2). Drain them here. +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVT_IN_REV_{(N-2)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-1)%2}"] // ◀ DRAIN +pto.wait_flag["PIPE_MTE3", "PIPE_V", "EVT_OUT_REV_{(N-2)%2}"] // ◀ DRAIN +``` + +**What `set_flag`/`wait_flag` requires outside the loop:** +- **Before the loop (4 × `set_flag`):** Prime every reverse-dependency event ID — one per buffer per pipe-pair. Without this, the first iteration's `wait_flag` for reverse deps would deadlock (no signal was ever sent). +- **After the loop (4 × `wait_flag`):** Drain the matching reverse-dep signals from the last iterations. Every `set_flag` must be paired with a `wait_flag` — the last loop iterations produce signals that no subsequent iteration consumes, so they must be drained explicitly. + +###### 3b. `get_buf` / `rls_buf` version + +Same ping/pong double-buffering, but **no pre-loop priming or post-loop draining needed.** Buffer acquire/release semantics handle everything. + +```mlir +scf.for %i = %c0 to %N step %c1 { + // %pp = i % 2 (ping/pong selector) + + // ── MTE2: load tile[i] into buf[i%2] ── + // Acquires buf[i%2] — on first iteration, buffer is free so proceeds immediately. + // On later iterations, blocks until Vector releases buf[i%2] (WAR: automatic). + pto.get_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 // mode=0 + pto.copy_gm_to_ubuf %gm_ptr[%i], %ub_buf[%pp], ... + pto.rls_buf "PIPE_MTE2", %bufid_buf[%pp], %c0 : i64, i64 + + // ── Vector: compute on buf[i%2] ── + // Acquires buf[i%2] — blocks until MTE2 releases it (RAW: automatic) + pto.get_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.get_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + scf.for %dummy = %c0 to %c1 step %c1 { + %v = pto.vlds %ub_buf[%pp][%lane] : !pto.ptr -> !pto.vreg<64xf32> + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%pp][%lane], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } {llvm.loop.aivector_scope} + // Release buf[i%2] — MTE2 can reuse in iteration i+2 (WAR resolved) + pto.rls_buf "PIPE_V", %bufid_buf[%pp], %c0 : i64, i64 + pto.rls_buf "PIPE_V", %bufid_out[%pp], %c0 : i64, i64 + + // ── MTE3: store result ── + // Acquires out[i%2] — blocks until Vector releases it (RAW: automatic) + pto.get_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 + pto.copy_ubuf_to_gm %ub_out[%pp], %gm_out[%i], ... + pto.rls_buf "PIPE_MTE3", %bufid_out[%pp], %c0 : i64, i64 +} +// No post-loop drain needed — last rls_buf completes the pipeline. +``` + +**No priming, no draining, no event IDs.** The acquire/release protocol on buffer IDs indexed by `i%2` implicitly resolves all cross-pipeline dependencies: +- **RAW** (MTE2→V): Vector's `get_buf` blocks until MTE2's `rls_buf` on `buf[i%2]` +- **WAR** (V→MTE2): MTE2's `get_buf` in iteration `i+2` blocks until Vector's `rls_buf` in iteration `i` (same buffer) +- **First iteration:** Buffer is initially free, so `get_buf` proceeds without blocking — no priming needed + +--- + +#### Comparison Summary + +| Aspect | `set_flag` / `wait_flag` | `get_buf` / `rls_buf` | +|--------|--------------------------|------------------------| +| Dependency model | Explicit event signals | Implicit via buffer acquire/release | +| IDs per pipe-pair | 2 IDs per buffer: 1 for forward (e.g., MTE2→V) + 1 for reverse (V→MTE2) | **1 ID per buffer** (handles both directions automatically) | +| Total HW IDs | **8 per pipe-pair** (hardware limit) | **32 global** across all pipes | +| Reverse (WAR) deps | Extra `set_flag`/`wait_flag` pair per buffer | Handled automatically | +| Pre-loop setup | `set_flag` to prime each reverse dep | **None** | +| Post-loop teardown | `wait_flag` to drain all primed signals | **None** | +| Loop peeling for complex deps | Required for non-1:1 or nested loops | **Not required** | +| Straight-line code | Simple, clear | Slightly more verbose (bracket each stage) | +| Ping/pong loops | 8 event IDs + 4 prime + 4 drain | Same pattern, **no overhead** | +| Best used for | Simple pipelines, fine-grained control | Double/multi-buffering, complex loops | + +--- + +#### Inter-Core Sync + +> **Note:** Inter-core sync is only needed for **mixed Cube+Vector tasks** where Cube produces data that Vector consumes (or vice versa). **Vec-only tasks can ignore this section entirely.** + +These ops coordinate execution across the Cube block and Vector subblocks within a cluster. Each core cluster consists of **1 Cube block : 2 Vector subblocks**, each with its own **SU (Sequencer Unit)** running independent instruction streams. + +``` +Core Cluster (1:2 ratio) +┌─────────────────────────────────────────────┐ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ AIC (Cube) │ │ AIV0 (Vec) │ │ +│ │ ┌────────┐ │ │ ┌────────┐ │ │ +│ │ │ SU │──┼────┼──│ SU │ │ │ +│ │ └────────┘ │ │ └────────┘ │ │ +│ │ CUBE pipe │ │ MTE2/V/MTE3 │ │ +│ │ L0C buffer │ │ UB (256KB) │ │ +│ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ │ +│ │ AIV1 (Vec) │ │ +│ │ ┌────────┐ │ │ +│ │ │ SU │ │ │ +│ │ └────────┘ │ │ +│ │ MTE2/V/MTE3 │ │ +│ │ UB (256KB) │ │ +│ └──────────────┘ │ +└─────────────────────────────────────────────┘ +``` + +##### Platform Comparison + +| Aspect | A2A3 (Ascend 910) | A5 (Ascend 950) | +|--------|-------------------|-----------------| +| **Signal op** | `set_cross_core` (mode2) | `set_intra_block` | +| **Wait op** | `wait_flag_dev` | `wait_intra_core` | +| **Wait behavior** | SU-level blocking (entire core stalls) | Per-pipeline (only named pipe stalls) | +| **Semaphore pool** | 16 IDs per cluster, 4-bit counter | 16 IDs, but 32-ID address space (see below) | +| **C→V** | **Broadcast**: one `set` reaches both AIV0+AIV1 | **1:1**: separate `set` per subblock required | +| **V→C** | **Reduce**: Cube waits for both subblocks in one `wait` | **1:1**: Cube needs separate `wait` per subblock | + +##### A2A3: `set_cross_core` / `wait_flag_dev` + +```c +// mode2 broadcast/reduce semantics for 1:2 cluster +set_cross_core(pipe, semaphore_id); // pipe: VEC/MTE2/CUBE/FIX +wait_flag_dev(semaphore_id); // SU-level blocking +``` + +``` +C→V Broadcast (one set reaches both): + AIC ──set_cross_core──┬──> AIV0 sema++ + └──> AIV1 sema++ + +V→C Reduce (one wait for both): + AIV0 ──set_cross_core──┐ + ├──> AIC wait_flag_dev (blocks until BOTH) + AIV1 ──set_cross_core──┘ +``` + +##### `pto.set_cross_core` + +- **syntax:** `pto.set_cross_core %core_id, %event_id : i64, i64` +- **semantics:** Signal event to another core. Uses **mode2** for 1:2 cluster on A2A3. + +##### `pto.wait_flag_dev` + +- **syntax:** `pto.wait_flag_dev %core_id, %event_id : i64, i64` +- **semantics:** Wait for event from another core. **SU-level blocking** — entire core stalls. + +##### A5: `set_intra_block` / `wait_intra_core` + +```c +set_intra_block(trigger_pipe, semaphore_id); +wait_intra_core(wait_pipe, semaphore_id); // only named pipe stalls +``` + +**A5 semaphore address space:** The hardware has **16 physical semaphore IDs** but exposes a **32-ID address space** to support 1:1 signaling to each subblock: + +| ID Range | Target | +|----------|--------| +| 0–15 | AIV0 (subblock 0) | +| 16–31 (+15 offset) | AIV1 (subblock 1) | + +This means C→V requires **separate `set_intra_block` calls** per subblock (no broadcast), and V→C requires **separate `wait_intra_core` calls** per subblock (no hardware reduce). + +``` +C→V on A5 (1:1, no broadcast — need two sets): + AIC ──set_intra_block(pipe, sema_id)────> AIV0 + AIC ──set_intra_block(pipe, sema_id+15)──> AIV1 + +V→C on A5 (1:1, no reduce — need two waits): + AIV0 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id) + AIV1 ──set_intra_block──> AIC wait_intra_core(pipe, sema_id+15) // extra wait +``` + +##### `pto.set_intra_block` + +- **syntax:** `pto.set_intra_block %block_id, %event_id : i64, i64` +- **semantics:** Signal event within a block (A5). Specifies **trigger pipe**. 1:1 per subblock. + +##### `pto.wait_intra_core` + +- **syntax:** `pto.wait_intra_core %block_id, %event_id : i64, i64` +- **semantics:** Wait for event within block (A5). Specifies **which pipeline should wait** — only that pipe stalls, SU and other pipes continue. + +##### Wait Granularity Comparison + +``` +A2A3 wait_flag_dev (SU-level stall): + SU ──┬── PIPE_MTE2 ───╳ ALL STALLED + ├── PIPE_V ───╳ ALL STALLED + └── PIPE_MTE3 ───╳ ALL STALLED + +A5 wait_intra_core "PIPE_MTE2" (per-pipe stall): + SU ──┬── PIPE_MTE2 ───╳ STALLED (waiting for Cube) + ├── PIPE_V ─── ✓ RUNNING + └── PIPE_MTE3 ─── ✓ RUNNING +``` + + + +### 2. DMA Copy Programming + +> **Category:** DMA transfer configuration and execution +> **Pipelines:** MTE2 (GM→UB), MTE3 (UB→GM) + +DMA transfers move data between Global Memory (GM) and Unified Buffer (UB). The MTE engines operate asynchronously from the Vector core, requiring explicit sync (see [Pipeline Sync](#isa-01-pipeline-sync)). + +The MTE2/MTE3 DMA engine executes a **multi-level nested loop** transfer. Before issuing the copy instruction, stride and loop-size registers must be configured. + +--- + +#### Loop Stride Configuration (GM→UB) + +These ops configure the MTE2 DMA engine's hardware loops for GM→UB transfers. They must be set **before** calling `pto.copy_gm_to_ubuf`. + +##### `pto.set_loop_size_outtoub` + +- **syntax:** `pto.set_loop_size_outtoub %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +When not using multi-level looping, set both to 1. + +--- + +##### `pto.set_loop2_stride_outtoub` + +- **syntax:** `pto.set_loop2_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop2 iteration (bytes) | + +After each loop2 iteration, the DMA engine advances the GM read pointer by `%src_stride` and UB write pointer by `%dst_stride`. + +--- + +##### `pto.set_loop1_stride_outtoub` + +- **syntax:** `pto.set_loop1_stride_outtoub %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for GM→UB DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 40 bits | GM source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 21 bits | UB destination pointer advance per loop1 iteration (bytes) | + +--- + +#### Loop Stride Configuration (UB→GM) + +These ops configure the MTE3 DMA engine's hardware loops for UB→GM transfers. They must be set **before** calling `pto.copy_ubuf_to_gm`. + +Note: UB stride fields are 21 bits (sufficient for 256KB UB address space), GM stride fields are 40 bits (full GM address range). + +##### `pto.set_loop_size_ubtoout` + +- **syntax:** `pto.set_loop_size_ubtoout %loop1_count, %loop2_count : i64, i64` +- **semantics:** Configure HW loop iteration counts for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%loop1_count` | 21 bits | Inner HW loop iteration count | +| `%loop2_count` | 21 bits | Outer HW loop iteration count | + +--- + +##### `pto.set_loop2_stride_ubtoout` + +- **syntax:** `pto.set_loop2_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure outer loop (loop2) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop2 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop2 iteration (bytes) | + +--- + +##### `pto.set_loop1_stride_ubtoout` + +- **syntax:** `pto.set_loop1_stride_ubtoout %src_stride, %dst_stride : i64, i64` +- **semantics:** Configure inner loop (loop1) pointer advance for UB→GM DMA. + +**Parameter Table:** + +| Parameter | Width | Description | +|-----------|-------|-------------| +| `%src_stride` | 21 bits | UB source pointer advance per loop1 iteration (bytes) | +| `%dst_stride` | 40 bits | GM destination pointer advance per loop1 iteration (bytes) | + +--- + +#### DMA Transfer Execution + +##### `pto.copy_gm_to_ubuf` + +- **syntax:** +```mlir +pto.copy_gm_to_ubuf %gm_src, %ub_dst, + %sid, %n_burst, %len_burst, %left_padding, %right_padding, + %data_select_bit, %l2_cache_ctl, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` +- **semantics:** DMA transfer from Global Memory (`!pto.ptr`) to Unified Buffer (`!pto.ptr`). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%gm_src` | GM source pointer (`!pto.ptr`) | +| `%ub_dst` | UB destination pointer (`!pto.ptr`, 32B-aligned) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows (innermost loop count) | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%left_padding` | Left padding count (bytes) | +| `%right_padding` | Right padding count (bytes) | +| `%data_select_bit` | Padding / data-select control bit (`i1`) | +| `%l2_cache_ctl` | L2 cache allocate control (TBD — controls whether DMA allocates in L2 cache) | +| `%src_stride` | GM source stride: start-to-start distance between consecutive burst rows (bytes) | +| `%dst_stride` | UB destination stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_gm` + +- **syntax:** +```mlir +pto.copy_ubuf_to_gm %ub_src, %gm_dst, + %sid, %n_burst, %len_burst, %reserved, %dst_stride, %src_stride + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` +- **semantics:** DMA transfer from Unified Buffer (`!pto.ptr`) to Global Memory (`!pto.ptr`). MTE3 reads only `len_burst` bytes from each UB row (de-padding). + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%ub_src` | UB source pointer (`!pto.ptr`, 32B-aligned) | +| `%gm_dst` | GM destination pointer (`!pto.ptr`) | +| `%sid` | Stream ID (usually 0) | +| `%n_burst` | Number of burst rows | +| `%len_burst` | Contiguous bytes transferred per burst row | +| `%reserved` | Reserved field (set to 0) | +| `%dst_stride` | GM destination stride: start-to-start distance between consecutive burst rows (bytes) | +| `%src_stride` | UB source stride: start-to-start distance between consecutive burst rows (bytes, 32B-aligned) | + +--- + +##### `pto.copy_ubuf_to_ubuf` + +- **syntax:** +```mlir +pto.copy_ubuf_to_ubuf %source, %dest, %sid, %n_burst, %len_burst, %src_stride, %dst_stride + : !pto.ptr, !pto.ptr, i64 x5 +``` +- **semantics:** Copy within Unified Buffer. + +**Parameters:** + +| Parameter | Description | +|-----------|-------------| +| `%source` | UB source pointer | +| `%dest` | UB destination pointer | +| `%sid` | Stream ID | +| `%n_burst` | Number of bursts | +| `%len_burst` | Length per burst | +| `%src_stride` | Source stride | +| `%dst_stride` | Destination stride | + +--- + +#### Burst / Stride / Pad Model + +All A5 DMA addresses are **stride-based**: stride is the distance from the start of one row to the start of the next row (`stride >= lenBurst`). There is no separate "gap" parameter. + +##### Key Terms + +``` +burst = lenBurst contiguous bytes transferred per row +stride = distance (bytes) from start of row[r] to start of row[r+1] +pad = ub_stride - lenBurst, padded to the 32B alignment boundary +``` + +##### Alignment Constraints + +- **UB addresses** (both source and destination) must be **32-byte aligned**. +- **GM→UB padding**: When `data_select_bit = true`, each UB row is padded from `lenBurst` up to the **32B-aligned boundary** of `ub_stride` with `pad_val` (set via `set_mov_pad_val`). This ensures every UB row starts at a 32B-aligned offset. +- **UB→GM de-padding**: MTE3 reads `lenBurst` bytes from each 32B-aligned UB row (skipping any padding that was added during load), writing only valid data to GM. This effectively strips padding on store. + +##### 2D Diagram: GM→UB (pto.copy_gm_to_ubuf) + +``` +GM (source, `!pto.ptr`): + + |<--- src_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +UB (destination, `!pto.ptr`, 32B-aligned): + + |<---------- dst_stride (32B-aligned) ---------->| + |<- len_burst ->|<- pad (to 32B boundary) ->| | +Row 0: [##DATA########][000000 PAD 000000000000000] +Row 1: [##DATA########][000000 PAD 000000000000000] +Row 2: [##DATA########][000000 PAD 000000000000000] + ... +Row N-1: [##DATA########][000000 PAD 000000000000000] + +N = n_burst +stride = start of row[r] to start of row[r+1] +pad = filled with pad_val to 32B boundary (data_select_bit=true) +[DATA] = valid data transferred by DMA +[PAD] = pad_val fill (set via set_mov_pad_val) +``` + +##### 2D Diagram: UB→GM (pto.copy_ubuf_to_gm) + +``` +UB (source, `!pto.ptr`, 32B-aligned start addr): + + |<---------- src_stride (32B-aligned) --------->| + |<- len_burst ->|<-- pad (ignored on read) -->| | +Row 0: [##DATA########][000 pad 000000000000000000] +Row 1: [##DATA########][000 pad 000000000000000000] +Row 2: [##DATA########][000 pad 000000000000000000] + ... +Row N-1: [##DATA########][000 pad 000000000000000000] + +GM (destination, `!pto.ptr`): + + |<--- dst_stride (start-to-start) --->| + |<- len_burst ->| | +Row 0: [##DATA########]......................| +Row 1: [##DATA########]......................| +Row 2: [##DATA########]......................| + ... +Row N-1: [##DATA########] + +N = n_burst +MTE3 reads only len_burst bytes from each UB row (de-padding). +Only len_burst bytes are written to each GM row. +``` + +--- + +#### Multi-Level Loop Semantics (C Code) + +The full DMA transfer is a nested loop. The HW loop registers (set before the copy) control the outer levels, and the copy instruction parameters control the innermost burst level. + +##### GM→UB Full Loop + +```c +// C equivalent of what the HW executes: +for (int j = 0; j < loop2_count; j++) { // HW outer loop + uint8_t *gm1 = gm_src + j * loop2_src_stride; + uint8_t *ub1 = ub_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { // HW inner loop + uint8_t *gm2 = gm1 + k * loop1_src_stride; + uint8_t *ub2 = ub1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { // burst engine + memcpy(ub2 + r * dst_stride, // UB dest row + gm2 + r * src_stride, // GM src row + len_burst); // contiguous bytes + if (data_select_bit) + memset(ub2 + r * dst_stride + len_burst, + pad_val, dst_stride - len_burst); + } + } +} +``` + +##### UB→GM Full Loop + +```c +// C equivalent: +for (int j = 0; j < loop2_count; j++) { + uint8_t *ub1 = ub_src + j * loop2_src_stride; + uint8_t *gm1 = gm_dst + j * loop2_dst_stride; + + for (int k = 0; k < loop1_count; k++) { + uint8_t *ub2 = ub1 + k * loop1_src_stride; + uint8_t *gm2 = gm1 + k * loop1_dst_stride; + + for (int r = 0; r < n_burst; r++) { + memcpy(gm2 + r * dst_stride, // GM dest row + ub2 + r * src_stride, // UB src row + len_burst); // contiguous bytes + } + } +} +``` + +--- + +#### Example 1: GM→UB — Load a 32×32 f32 Tile (Simple Case) + +Load a 32×32 f32 tile from GM into UB. This matches the `abs_kernel_2d` test case. + +``` +GM layout (32 × 32 f32, contiguous): + + |<- len_burst = 128B (32 × 4) ->| + |<- src_stride = 128B --------->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + +UB layout (32 × 32 f32, 32B-aligned, contiguous): + + |<- dst_stride = 128B (32B-aligned) ->| + +--[#######TILE#######]--+ row 0 + +--[#######TILE#######]--+ row 1 + ... + +--[#######TILE#######]--+ row 31 + + len_burst = 32 × 4 = 128 bytes + src_stride = 128 bytes (contiguous rows) + dst_stride = 128 bytes (already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 + +pto.copy_gm_to_ubuf %arg0, %ub_in, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 (32 rows) + %c128_i64, // len_burst = 128 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c128_i64, // src_stride = 128 bytes + %c128_i64 // dst_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 2: GM→UB — Load a 2D Tile from a Larger Matrix + +Load a 64×128 tile (f16) from a 1024×512 matrix in GM into UB. + +``` +GM layout (1024 × 512 f16): + + col 0 col 128 col 512 + | | | + +--[###TILE###]+.....................+ row R + +--[###TILE###]+.....................+ row R+1 + ... + +--[###TILE###]+.....................+ row R+63 + + |<--------- src_stride = 1024B ----------->| + |<-len_burst=256B->| + + len_burst = 128 × 2 = 256 bytes (128 f16 elements) + src_stride = 512 × 2 = 1024 bytes (start-to-start, full GM row) + +UB layout (64 × 128 f16, 32B-aligned, contiguous): + + +--[###TILE###]--+ row 0 (256 bytes, 32B-aligned, no pad) + +--[###TILE###]--+ row 1 + ... + +--[###TILE###]--+ row 63 + + dst_stride = 256 bytes (= len_burst, already 32B-aligned, no padding) +``` + +```mlir +// Simple 2D load — no multi-level loops needed +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 (64 rows) + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c1024_i64, // src_stride = 1024 bytes (full matrix row) + %c256_i64 // dst_stride = 256 bytes (tile row) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 3: GM→UB — Load with Padding + +Load 100 valid columns from GM into a 128-wide UB tile (f16). The remaining 28 columns are zero-padded. + +``` +GM (100 cols valid, contiguous): + + |<-len_burst=200B->| + |<- src_stride=200B (start-to-start) ->| + +--[####DATA####]-+ row 0 + +--[####DATA####]-+ row 1 + ... + +--[####DATA####]-+ row 63 + +UB (128 cols wide, 32B-aligned, padded): + + |<--------- dst_stride = 256B (32B-aligned) --------->| + |<-len_burst=200B->|<---- pad = 56B to 32B boundary ->| + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 0 + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 1 + ... + +--[####DATA####]-+[0000000 PAD 0000000000000000000000]+ row 63 + + len_burst = 100 × 2 = 200 bytes + src_stride = 200 bytes (start-to-start, contiguous in GM) + dst_stride = 128 × 2 = 256 bytes (32B-aligned tile width in UB) + pad = 256 - 200 = 56 bytes (padded to 32B boundary with pad_val) +``` + +```mlir +pto.set_loop_size_outtoub %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_outtoub %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c200_i64, // len_burst = 200 bytes + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %true, // data_select_bit = true (enable padding) + %c0_i64, // l2_cache_ctl = 0 + %c200_i64, // src_stride = 200 bytes + %c256_i64 // dst_stride = 256 bytes (32B-aligned) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +--- + +#### Example 4: UB→GM — Store a 32×32 f32 Tile (Simple Case) + +Store a 32×32 f32 tile from UB back to GM. This matches the `abs_kernel_2d` test case. + +``` +UB (source, 32B-aligned, 32 × 32 f32): + + |<- src_stride = 128B (32B-aligned) ->| + |<- len_burst = 128B ->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 + + (no padding here — len_burst == src_stride) + +GM (dest, 32 × 32 f32): + + |<- dst_stride = 128B ->| + |<- len_burst = 128B -->| + +--[#######TILE#######]---+ row 0 + +--[#######TILE#######]---+ row 1 + ... + +--[#######TILE#######]---+ row 31 +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_out, %arg1, + %c0_i64, // sid = 0 + %c32_i64, // n_burst = 32 + %c128_i64, // len_burst = 128 bytes + %c0_i64, // reserved = 0 + %c128_i64, // dst_stride = 128 bytes + %c128_i64 // src_stride = 128 bytes + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 5: UB→GM — Store a 2D Tile Back to a Larger Matrix + +Store a 64×128 tile (f16) from UB back to a 1024×512 GM matrix at an offset. + +``` +UB (source, 32B-aligned, 64 × 128 f16): + + |<- src_stride = 256B (32B-aligned) ->| + |<- len_burst = 256B ->| + +--[#####TILE#####]---+ row 0 + +--[#####TILE#####]---+ row 1 + ... + +--[#####TILE#####]---+ row 63 + + (no padding here — len_burst == src_stride) + +GM (dest, into 1024 × 512 matrix): + + |<----------- dst_stride = 1024B (start-to-start) --------->| + |<- len_burst = 256B ->| | + col 0 col 128 col 512 + +--[#####TILE#####]---+.............................+ row R + +--[#####TILE#####]---+.............................+ row R+1 + ... + +--[#####TILE#####]---+.............................+ row R+63 + + MTE3 reads len_burst bytes from each 32B-aligned UB row, + writes only len_burst bytes per GM row (stride controls row spacing). +``` + +```mlir +// Configure MTE3 strides +pto.set_loop_size_ubtoout %c1_i64, %c1_i64 : i64, i64 +pto.set_loop1_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 +pto.set_loop2_stride_ubtoout %c0_i64, %c0_i64 : i64, i64 + +pto.copy_ubuf_to_gm %ub_ptr, %gm_ptr, + %c0_i64, // sid = 0 + %c64_i64, // n_burst = 64 + %c256_i64, // len_burst = 256 bytes + %c0_i64, // reserved = 0 + %c1024_i64, // dst_stride = 1024 bytes (GM row) + %c256_i64 // src_stride = 256 bytes (UB row) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 +``` + +--- + +#### Example 6: GM→UB with Multi-Level Loop (Batch of Tiles) + +Load 4 batches of 8×128 tiles from a [4, 8, 128] f16 tensor using loop1. + +``` +GM [4, 8, 128] f16 (contiguous): UB (4 tiles laid out sequentially): + + batch 0: 8 rows × 256 bytes [batch 0: 8×128][batch 1: 8×128] + batch 1: 8 rows × 256 bytes [batch 2: 8×128][batch 3: 8×128] + batch 2: 8 rows × 256 bytes + batch 3: 8 rows × 256 bytes loop1 src_stride = 2048 bytes (8 × 256) + loop1 dst_stride = 2048 bytes (8 × 256) + Each batch = 8 × 256 = 2048 bytes loop1_count = 4 (iterate over batches) +``` + +```mlir +// loop1_count = 4 batches, loop2_count = 1 (not used) +pto.set_loop_size_outtoub %c4_i64, %c1_i64 : i64, i64 + +// loop1 stride: advance by one batch (2048 bytes) in both GM and UB +pto.set_loop1_stride_outtoub %c2048_i64, %c2048_i64 : i64, i64 +pto.set_loop2_stride_outtoub %c0_i64, %c0_i64 : i64, i64 + +pto.copy_gm_to_ubuf %gm_ptr, %ub_ptr, + %c0_i64, // sid = 0 + %c8_i64, // n_burst = 8 rows per batch + %c256_i64, // len_burst = 256 bytes per row + %c0_i64, // left_padding = 0 + %c0_i64, // right_padding = 0 + %false, // data_select_bit = false + %c0_i64, // l2_cache_ctl = 0 + %c256_i64, // src_stride = 256 (contiguous rows) + %c256_i64 // dst_stride = 256 (contiguous rows) + : !pto.ptr, !pto.ptr, i64, i64, i64, + i64, i64, i1, i64, i64, i64 +``` + +Execution trace: + +``` +loop1 iter 0: gm_ptr + 0×2048 → ub_ptr + 0×2048, DMA 8 rows × 256B +loop1 iter 1: gm_ptr + 1×2048 → ub_ptr + 1×2048, DMA 8 rows × 256B +loop1 iter 2: gm_ptr + 2×2048 → ub_ptr + 2×2048, DMA 8 rows × 256B +loop1 iter 3: gm_ptr + 3×2048 → ub_ptr + 3×2048, DMA 8 rows × 256B +``` + + + +### 3. Vector Load/Store + +> **Category:** UB ↔ Vector Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Vector loads move data from Unified Buffer (UB) to vector registers (`vreg`). Vector stores move data from `vreg` back to UB. All vector compute operates only on `vreg` — UB is the staging area between DMA and compute. + +#### Common Operand Model + +- `%source` / `%dest` is the base address operand in SSA form. The base pointer + MUST address the Vector tile buffer / UB space. +- `%offset` is the displacement operand in SSA form. The exact encoding is + instruction-specific, but the effective address and any post-update behavior + MUST match the selected instruction form. +- `%mask` is the predicate operand for predicated memory families. For memory + families, + inactive lanes or inactive blocks MUST NOT issue memory requests unless the + instruction explicitly documents a different behavior. +- `%result` is the destination vector register value in SSA form. +- `!pto.align` is the SSA carrier for alignment-buffer state used by unaligned + load/store families. The PTO micro Instruction representation makes that state explicit rather than implicit. + +--- + +#### Latency and throughput (A5) + +**Cycle-accurate simulator (CA model)** issue→retire timings for vector-side instructions behind this chapter. Values are **simulator** results, **not** guaranteed for silicon. + +**SOC:** Tables below are from **Ascend910_9599** CA sim (the pto-isa ST default when **Ascend950PR_9599** is not selected). + +**Log `dist:` tokens:** PTO load/store modes lower to **`RV_VLD` / `RV_VLDI` / `RV_VST` / `RV_VSTI`** with a **`dist:`** field on the vector pipes (`RVECLD` / `RVECST`). Some simulator logs typo contiguous load as `dist:NORAML`; treat as **`NORMAL`**. + +##### Reference op latencies (A5 mnemonics) + +| A5 mnemonic | Mode / note | Typical issue→retire (cycles) | +|-------------|-------------|------------------------------| +| `RV_VLD` | `dist:NORMAL` / `NORAML` | **9** | +| `RV_VLDI` | `dist:DINTLV` (dual vreg) | **9** | +| `RV_VST` / `RV_VSTI` | `dist:NORM` | **9** | +| `RV_VGATHER2` | `Dtype: B32` | **27–28** | +| `RV_VGATHERB` | indexed byte gather | **~21** | +| `RV_VSCATTER` | `Dtype: B16` | **~17** | +| `RV_VADD` | F32 between UB-backed ops | **7** | + +##### `dist:` tokens (issue→retire) + +Most **`dist:`** tokens are **9** issue→retire cycles. **`INTLV`** on **`RV_VSTI`** is **12** cycles. + +| `dist:` (as in log) | RV op | issue→retire (cycles) | +|---------------------|-------|----------------------| +| `DINTLV` | `RV_VLDI` | **9** | +| `BRC` | `RV_VLD` | **9** | +| `BRC_BLK` | `RV_VLD` | **9** | +| `INTLV` | `RV_VSTI` | **12** | +| `UNPK` | `RV_VLD` | **9** | +| `NORM` | `RV_VSTI` | **9** | +| `PK` | `RV_VSTI` | **9** | +| `NORMAL` / `NORAML` | `RV_VLD` | **9** | + +**Note:** PTO intrinsic **`BRC_BLK`** matches the **`BRC_BLK`** `dist:` string on **`RV_VLD`** in simulator logs (block-replicate path; not a plain contiguous copy in the usual tiling use). + +**Issue (vector load/store):** `pto.vlds` (**`RV_VLD`**) is **dual-issue capable**: two independent `pto.vlds` can issue **in the same cycle**. **Alternatively**, the hardware can issue **one** `pto.vlds` **and** **one** `pto.vsts` together (**1+1**) in the same cycle. Each cycle is **either** dual **`vlds`** **or** **`vlds` + `vsts` (1+1)**—those two issue modes are mutually exclusive. Sustained throughput still depends on RAW hazards and loop structure. + +**Throughput (simulator, pattern-dependent):** + +- **`RV_VLD` / `pto.vlds`:** Dual-issue **or** half of a **1+1** with `vsts`, per the rule above. +- **`RV_VST` / `pto.vsts`:** In a **1+1** cycle, pairs with one `vlds`; otherwise typically **one** store per cycle in tight loops. +- **`RV_VGATHER2`:** Much lower than contiguous `RV_VLD` (on the order of **~0.1** ops/cycle in steady-state alongside 27–28-cycle latency). + +##### PTO `dist` summary (loads) + +| PTO `dist` (load) | Latency | +|-------------------|-------------------| +| `NORM` | **9** cycles | +| `UNPK` | **9** cycles | +| `DINTLV` | **9** cycles (`RV_VLDI`) | +| `BRC` | **9** cycles (`RV_VLD`) | +| `BRC_BLK` | **9** cycles as **`dist:BRC_BLK`** on `RV_VLD` | +| `BDINTLV` | **9** cycles | +| `US`, `DS`, `SPLT4CHN`, `SPLT2CHN` | **9** cycles | + +##### PTO `dist` summary (stores) + +| PTO `dist` (store) | Latency | +|--------------------|-------------------| +| `NORM` | **9** cycles (`RV_VSTI`) | +| `PK` | **9** cycles | +| `INTLV` (`pto.vstx2`) | **12** cycles | +| `MRG4CHN`, `MRG2CHN` | **9** cycles | + +##### Gather, scatter, and special addressing + +| PTO op | A5-level | Latency | +|--------|----------|-------------------| +| `pto.vgather2` | `RV_VGATHER2` | **27–28** cycles (pattern-dependent) | +| `pto.vgatherb` | `RV_VGATHERB` | **~21** cycles issue→retire | +| `pto.vgather2_bc` | (broadcast gather) | **27–28** cycles (same as **`pto.vgather2`**) | +| `pto.vscatter` | `RV_VSCATTER` | **~17** cycles for **`Dtype: B16`** | + +##### Strided loads/stores, unaligned ops, alignment state + +Ops such as **`pto.vldas`**, **`pto.vldus`**, **`pto.vsld`**, **`pto.vsldb`**, **`pto.vsst`**, **`pto.vsstb`**, **`pto.vsta`**, **`pto.vstas`**, **`pto.vstar`**, **`pto.vstu`**, **`pto.vstus`**, **`pto.vstur`**: **9** cycles (same vector load/store pipe family as contiguous `RV_VLD` / `RV_VST` unless listed otherwise above). + +##### Dual-issue vs DMA + +DMA **`TLOAD` / `TSTORE`** (global memory ↔ UB) use **MTE** pipes, not `RV_VLD`/`RV_VST`. **MTE2** `MOV_*` latency is not the same as vector `RV_VLD` latency (see `02-dma-copy.md` for GM↔UB movement). + +--- + +#### Contiguous Loads + +##### `pto.vlds` + +- **syntax:** `%result = pto.vlds %source[%offset] {dist = "DIST"} : !pto.ptr -> !pto.vreg` +- **semantics:** Vector load with distribution mode. +- **inputs:** + `%source` is the UB base address, `%offset` is the load displacement, and + `DIST` selects the distribution mode. +- **outputs:** + `%result` is the loaded vector register value. +- **constraints and limitations:** + The effective address MUST satisfy the alignment rule of the selected + distribution mode. `NORM` reads one full vector footprint. Broadcast, + upsample, downsample, unpack, split-channel, and deinterleave modes change + how memory bytes are mapped into destination lanes, but they do not change the + fact that the source is UB memory. PTO surface exposes load `dist` as family + tokens, and each family only supports the element widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | width-agnostic | `dst[i] = UB[base + i * sizeof(T)]` | **9** cycles | +| `BRC` | `b8`, `b16`, `b32` | `dst[i] = UB[base]` for all `i` | **9** cycles | +| `US` | `b8`, `b16` | `dst[2*i] = dst[2*i+1] = UB[base + i]` | **9** cycles | +| `DS` | `b8`, `b16` | `dst[i] = UB[base + 2*i]` | **9** cycles | +| `UNPK` | `b8`, `b16`, `b32` | Expand packed source data into wider lanes | **9** cycles | +| `BRC_BLK` | width-agnostic | Block-replicate load path; simulator logs may print `dist:BRC_BLK` | **9** cycles | +| `E2B` | `b16`, `b32` | Load element groups and expand them into byte-oriented lane layout | **9** cycles | +| `UNPK4` | `b8` | Unpack 4-way packed `b8` source groups into destination lanes | **9** cycles | +| `SPLT4CHN` | `b8` | Split 4-channel interleaved source into one channel plane | **9** cycles | +| `SPLT2CHN` | `b8`, `b16` | Split 2-channel interleaved source into one channel plane | **9** cycles | + +`pto.vlds` currently covers only single-result load families. Dual-result +deinterleave forms are modeled separately in PTO surface as +[`pto.vldsx2`](#ptovldsx2): `BDINTLV` is the block-deinterleave family, while +`DINTLV` is the element-width-sensitive deinterleave family. + +**Example — Contiguous load:** +```mlir +%v = pto.vlds %ub[%offset] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +**Example — Broadcast scalar to all lanes:** +```mlir +%v = pto.vlds %ub[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vldas` + +- **syntax:** `%result = pto.vldas %source : !pto.ptr -> !pto.align` +- **semantics:** Prime alignment buffer for subsequent unaligned load. +- **inputs:** + `%source` is the UB address whose surrounding aligned block seeds the load + alignment state. +- **outputs:** + `%result` is the initialized load-alignment state. +- **constraints and limitations:** + This op is the required leading operation for a `pto.vldus` stream using the + same alignment state. The source address itself need not be 32-byte aligned; + hardware truncates it to the aligned block boundary for the priming load. +- **Latency:** **9** cycles. + +--- + +##### `pto.vldus` + +- **syntax:** `%result, %align_out = pto.vldus %source, %align : !pto.ptr, !pto.align -> !pto.vreg, !pto.align` +- **semantics:** Unaligned load using primed align state. +- **inputs:** + `%source` is the current UB address and `%align` is the incoming load + alignment state primed by `pto.vldas` or a prior `pto.vldus`. +- **outputs:** + `%result` is the assembled vector value and `%align_out` is the updated + alignment state. +- **constraints and limitations:** + A matching `pto.vldas` MUST appear before the first dependent `pto.vldus` + stream in the same vector loop. The installed no-post A5 interface keeps a + struct-shaped internal return for lowering convenience, but its no-post + `base` field is not meaningful user-visible state. VPTO therefore hides that + value and only exposes the updated align carrier. Reusing the original + `%source` starts a new explicit access point; if the caller wants another + no-post access, it should compute the next source pointer explicitly and pair + it with the required align setup. +- **Latency:** **9** cycles. + +**Unaligned load pattern:** +```mlir +%align = pto.vldas %ub : !pto.ptr -> !pto.align +%vec, %align2 = pto.vldus %ub, %align : !pto.ptr, !pto.align -> !pto.vreg<64xf32>, !pto.align +``` + +--- + +##### `pto.init_align` + +- **syntax:** `%result = pto.init_align : !pto.align` +- **semantics:** Initialize store-side align carrier state. +- **outputs:** + `%result` is a fresh zero-initialized align carrier for store-side unaligned + streams such as `pto.vstus`, `pto.vstur`, `pto.vstar`, `pto.vstas`, and + `pto.pstu`. +- **constraints and limitations:** + This op is for store-family initialization only. Unaligned load streams still + start from `pto.vldas`. + +--- + +#### Dual Loads (Deinterleave) + +##### `pto.vldsx2` + +- **syntax:** `%low, %high = pto.vldsx2 %source[%offset], "DIST" : !pto.ptr, index -> !pto.vreg, !pto.vreg` +- **semantics:** Dual load with deinterleave (AoS → SoA conversion). +- **inputs:** + `%source` is the UB base pointer, `%offset` is the displacement, and `DIST` + selects a dual-load/deinterleave layout. +- **outputs:** + `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** + This family is only legal for interleave/deinterleave style distributions. + The two outputs form an ordered pair, and that pairing MUST be preserved. + PTO surface accepts deinterleave families. `BDINTLV` is element-width + agnostic, while `DINTLV` supports only the element widths listed in the + table. +- **latency:** `BDINTLV` / `DINTLV` are both **9** cycles. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `BDINTLV` | width-agnostic | Block deinterleave into two destination vectors | **9** cycles | +| `DINTLV` | `b8`, `b16`, `b32` | Deinterleave alternating elements into `%low` / `%high` | **9** cycles | + +```c +// DINTLV family on 32-bit elements: deinterleave 32-bit elements +for (int i = 0; i < 64; i++) { + low[i] = UB[base + 8*i]; // even elements + high[i] = UB[base + 8*i + 4]; // odd elements +} +``` + +**Example — Load interleaved XY pairs into separate X/Y vectors:** +```mlir +%x, %y = pto.vldsx2 %ub[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +##### `pto.vsldb` + +- **syntax:** `%result = pto.vsldb %source, %block_stride, %repeat_stride, %mask : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg` +- **semantics:** Block-strided load for 2D tile access. +- **inputs:** + `%source` is the UB base pointer. `%block_stride` and `%repeat_stride` are + the two 16-bit fields of the hardware control word, and `%mask` controls + which blocks participate. +- **outputs:** + `%result` is the loaded vector. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. If a block is + masked off, the corresponding destination block is zeroed and MUST NOT raise + an address overflow exception for that block. +- **Latency:** **9** cycles. + +```c +// Block-strided load on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + repeat_stride + blk * block_stride]; + else + dst_block[blk] = 0; +} +``` + +--- + +#### Gather (Indexed) Loads + +##### `pto.vgather2` + +- **syntax:** `%result = pto.vgather2 %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Indexed gather from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` provides per-lane element + offsets, and `%mask` selects the active requests. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + Only masked-on indices participate. The index element width + and interpretation MUST match the selected gather form, and each effective + address must satisfy that form's alignment rules. +- **Latency:** **27–28** cycles per `RV_VGATHER2`; throughput much lower than contiguous `RV_VLD` (see **Latency and throughput (A5)** at the start of this chapter). + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = UB[base + offsets[i] * sizeof(T)]; +``` + +--- + +##### `pto.vgatherb` + +- **syntax:** `%result = pto.vgatherb %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Block gather load from UB. +- **inputs:** + `%source` is the UB base pointer, `%offsets` is a `ui32` offset vector, and + `%mask` is a `b32` predicate over the block-index lanes. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a 32-byte block gather, not an element gather. `%source` MUST be + 32-byte aligned. Each participating `offsets[i]` is interpreted as a byte + offset and MUST itself be 32-byte aligned. Only the low `VL/8` bytes of the + offset vector are semantically valid; the effective block address is + `block_addr[i] = offsets_u32[i] + base`. If a `b32` predicate position is + false, the corresponding block does not participate in address coalescing, + does not raise overflow on that block address, and the destination block is + zero-filled. +- **Latency:** **~21** cycles issue→retire. + +```c +for (int blk = 0; blk < VL / 32; ++blk) { + if (pg_b32[blk]) + dst_block[blk] = UB_block[base + offsets_u32[blk]]; + else + dst_block[blk] = 0; +} +``` + +--- + +##### `pto.vgather2_bc` + +- **syntax:** `%result = pto.vgather2_bc %source, %offsets, %mask : !pto.ptr, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Gather with broadcast, conditioned by mask. +- **inputs:** + `%source` is the UB base pointer, `%offsets` contains gather indices, and + `%mask` gates which lanes participate. +- **outputs:** + `%result` is the gathered vector. +- **constraints and limitations:** + This is a backward-compatible family. Masked-off lanes do not participate in + address coalescing and do not trigger address overflow exceptions; their + destination lanes are zero-filled. On the current PTO surface, `%offsets` + uses 32-bit integer elements. +- **Latency:** **27–28** cycles (same as **`pto.vgather2`**). + +--- + +#### Contiguous Stores + +##### `pto.vsts` + +- **syntax:** `pto.vsts %value, %dest[%offset], %mask {dist = "DIST"} : !pto.vreg, !pto.ptr, !pto.mask` +- **semantics:** Vector store with distribution mode. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offset` is + the displacement, `%mask` selects the active lanes or sub-elements, and + `DIST` selects the store distribution. +- **outputs:** + This op has no SSA result; it writes to UB memory. +- **constraints and limitations:** + The effective destination address MUST satisfy the alignment rule of the + selected store mode. The single-input `pto.vsts` family covers contiguous + store, first-element-only store, packed store, and channel-merge store. + Dual-input interleave store remains in `pto.vstsx2`. PTO surface exposes + store `dist` as family tokens, and each family only supports the element + widths listed below. + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `NORM` | `b8`, `b16`, `b32` | `UB[base + i] = src[i]` | **9** cycles | +| `1PT` | `b8`, `b16`, `b32` | Only element 0 is written to the destination footprint | **9** cycles | +| `PK` | `b16`, `b32`, `b64` | Pack low half bits of each source element before store | **9** cycles | +| `PK4` | `b32` | Pack low 8 bits of each `b32` element before store | **9** cycles | +| `MRG4CHN` | `b8` | Merge 4 channel planes into an interleaved 4-channel layout | **9** cycles | +| `MRG2CHN` | `b8`, `b16` | Merge 2 channel planes into an interleaved 2-channel layout | **9** cycles | + +**Example — Contiguous store:** +```mlir +pto.vsts %v, %ub[%offset], %mask {dist = "NORM"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +--- + +#### Dual Stores (Interleave) + +##### `pto.vstsx2` + +- **syntax:** `pto.vstsx2 %low, %high, %dest[%offset], "DIST", %mask : !pto.vreg, !pto.vreg, !pto.ptr, index, !pto.mask` +- **semantics:** Dual interleaved store (SoA → AoS conversion). +- **inputs:** + `%low` and `%high` are the two source vectors, `%dest` is the UB base pointer, + `%offset` is the displacement, `DIST` selects the interleave layout, and + `%mask` gates the participating elements. +- **outputs:** + This op has no SSA result; it writes an interleaved stream to UB. +- **constraints and limitations:** + This family is only legal for interleave distributions. The two source + vectors form an ordered pair, and the interleave semantics of that pair MUST + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. + be preserved. PTO surface accepts the `INTLV` family, which only supports the + element widths listed below. +- **latency:** `INTLV` is **12** cycles。 + +**Distribution families:** + +| Family | Allowed element widths | C semantics | Latency | +|------|-------------|-------------|-------------| +| `INTLV` | `b8`, `b16`, `b32` | Interleave `%low` / `%high` into one destination stream | **12** cycles | +| `INTLV` | `b8`, `b16`, `b32` | + +```c +// INTLV family on 32-bit elements: +for (int i = 0; i < 64; i++) { + UB[base + 8*i] = low[i]; + UB[base + 8*i + 4] = high[i]; +} +``` + +##### `pto.vsstb` + +- **syntax:** `pto.vsstb %value, %dest, %block_stride, %repeat_stride, %mask : !pto.vreg, !pto.ptr, i16, i16, !pto.mask` +- **semantics:** Block-strided store for 2D tile access. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, + `%block_stride` and `%repeat_stride` are the two 16-bit fields of the + hardware control word, and `%mask` controls block participation. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + PTO surface does not expose the packed control word directly. Masked-off + blocks MUST NOT issue memory writes. +- **Latency:** **9** cycles. + +```c +// Block-strided store on 32-bit elements: one 32B block = 8 lanes. +for (int blk = 0; blk < 8; ++blk) { + if (pg_b32[blk]) + UB_block[base + repeat_stride + blk * block_stride] = src_block[blk]; +} +``` + +--- + +#### Scatter (Indexed) Stores + +##### `pto.vscatter` + +- **syntax:** `pto.vscatter %value, %dest, %offsets, %mask : !pto.vreg, !pto.ptr, !pto.vreg, !pto.mask` +- **semantics:** Indexed scatter to UB. +- **inputs:** + `%value` is the source vector, `%dest` is the UB base pointer, `%offsets` + provides per-lane or per-block indices, and `%mask` selects the active + requests. +- **outputs:** + This op writes UB memory and returns no SSA value. +- **constraints and limitations:** + Only `b8`, `b16`, and `b32` element sizes are supported. The index vector + must use a supported integer element type and layout for this family. + Each computed address MUST be element-aligned. If two or more indices alias, + only one write is guaranteed and the winning lane is implementation-defined. +- **Latency:** **~17** cycles for **`Dtype: B16`**. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + UB[base + offsets[i] * sizeof(T)] = src[i]; +``` + +--- + +#### Alignment State Stores + +##### `pto.vstas` +- **syntax:** `pto.vstas %value, %dest, %offset : !pto.align, !pto.ptr, i32` +- **semantics:** Scalar-register-offset form of alignment-state flush. +- **inputs:** + `%value` is the pending store-alignment state, `%dest` is the UB base + pointer, and `%offset` is the scalar-register style displacement. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + This family flushes pending store-alignment state using an explicit scalar + offset and keeps the scalar-offset form explicit. The incoming `%value` + should come from `pto.init_align` or from a prior state-producing unaligned + store op in the same stream. `%dest` and `%offset` together must identify the + same logical flush point produced by the immediately preceding stateful + unaligned-store step on that stream; using an unrelated base/offset pair is + invalid even if `%value` itself came from the same stream. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush alignment state using the register-update form. +- **inputs:** + `%value` is the pending store-alignment state and `%dest` is the UB base + pointer. +- **outputs:** + This op writes buffered tail bytes to UB and returns no SSA value. +- **constraints and limitations:** + The implicit update state consumed by this flush MUST correspond to the same + store stream that produced `%value`. The first store-side state in a stream + should be created by `pto.init_align`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstar` + +- **syntax:** `pto.vstar %value, %dest : !pto.align, !pto.ptr` +- **semantics:** Flush remaining alignment state. +- **inputs:** + `%value` is the pending alignment/buffer state that still needs to be emitted, + and `%dest` is the UB destination base pointer. +- **outputs:** + No SSA result. The effect is a memory-side flush that writes the remaining + buffered bytes to memory. +- **constraints and limitations:** + This op terminates an unaligned-store sequence. It MUST be paired with a + compatible prior state-producing store sequence so that the pending tail state + is well-defined. +- **Latency:** **9** cycles. + +--- + +#### Stateful Store Ops + +These ops make reference-updated state explicit as SSA results. + +##### `pto.vstus` + +- **syntax:** `%align_out = pto.vstus %align_in, %offset, %value, %base : !pto.align, i32, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** No-post unaligned store with scalar offset. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%offset` is the scalar + displacement, `%value` is the vector being stored, and `%base` is the UB base + pointer. +- **outputs:** + `%align_out` is the updated buffered-tail state. +- **constraints and limitations:** + This is the scalar-offset stateful form of the unaligned store family. The + scalar offset width MUST match the selected form, and a later flush op is + still required. The first `%align_in` in the stream should come from + `pto.init_align`. This op does not mean "store a full vector starting at + `%base + %offset`". Instead, `%offset` describes how far the store stream + advances at this step, and `%align_out` carries any residual tail that could + not be committed yet. The no-post surface does not expose an updated base + pointer. A later flush op must therefore use an explicit destination/offset + pair that identifies the same logical flush point as this `pto.vstus`. +- **Latency:** **9** cycles. + +--- + +##### `pto.vstur` + +- **syntax:** `%align_out = pto.vstur %align_in, %value, %base, "MODE" : !pto.align, !pto.vreg, !pto.ptr -> !pto.align` +- **semantics:** Unaligned store with residual flush and SPR-AR-driven state update. +- **inputs:** + `%align_in` is the incoming store-alignment state, `%value` is the vector to + store, `%base` is the UB base pointer, and `MODE` selects whether the + hardware updates `SPR AR` after the store. +- **outputs:** + `%align_out` is the updated residual state after the current partial store. +- **constraints and limitations:** + The effective address is `base + AR`, where `AR` is the hardware SPR state + carried outside SSA. `POST_UPDATE` means hardware may advance `SPR AR` + according to the fixed `SPR SQZN` configuration; `NO_POST_UPDATE` preserves + the current `SPR AR` value. This form exposes only the evolving residual + align-state in SSA; it does not by itself guarantee that all buffered bytes + have reached memory. A compatible final flush is still required unless the + surrounding sequence is known to be complete. Independent sequences typically + begin from `AR = 0`; if the surrounding program does not already guarantee + that, the hardware sequence should clear `SPR AR` before the first dependent + `pto.vstur`. The first `%align_in` in the stream should come from + `pto.init_align`. `pto.vstur` also consumes the fixed `SPR SQZN` state, so a + preceding squeeze producer such as `pto.vsqz` / `pto.vusqz` MUST establish + the byte count before the store. `MODE` MUST be one of `POST_UPDATE` or + `NO_POST_UPDATE`. +- **Latency:** **9** cycles. + + + +### 4. Predicate Load/Store + +> **Category:** UB ↔ Predicate Register data movement +> **Pipeline:** PIPE_V (Vector Core) + +Predicate registers (`!pto.mask`) are 256-bit registers that enable per-lane conditional execution. These ops move predicate values between UB and predicate registers. + +In concrete examples, `G` should be chosen to match the consumer family. The +examples below use `b32` when the loaded/stored mask is used with `f32` +vector compares or selects. + +The predicate load/store ops documented on this page always use explicit +`base[offset]` addressing. The immediate forms (`pldi`, `psti`) and dynamic +forms (`plds`, `psts`) differ only in how `%offset` is supplied. + +--- + +#### Predicate Loads + +##### `pto.plds` + +- **syntax:** `%result = pto.plds %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **semantics:** Load predicate register with runtime offset. This is the + dynamic-offset form of `pto.pldi`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +The loaded payload is a packed predicate image in UB. Consumer ops interpret +the resulting `!pto.mask` according to the mask granularity `G`. +`pto.plds` only +models the explicit `base[offset]` form. + +**Example:** +```mlir +%mask = pto.plds %ub[%c0], "NORM" : !pto.ptr, index -> !pto.mask +``` + +--- + +##### `pto.pldi` + +- **syntax:** `%result = pto.pldi %source[%offset], "DIST" : !pto.ptr, index -> !pto.mask` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Load predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `US`, `DS`. + - `NORM`: load a normal packed predicate payload of size `VL/8`. + - `US`: load a packed predicate payload of size `VL/16`, then duplicate each + loaded bit once. + - `DS`: load a packed predicate payload of size `2 * VL/8`, then keep one + bit out of every two bits. + +Like `pto.plds`, this op reads a packed predicate payload from UB and +materializes it as `!pto.mask`. + +--- + +#### Predicate Stores + +##### `pto.psts` + +- **syntax:** `pto.psts %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **semantics:** Store predicate register with runtime offset. This is the + dynamic-offset form of `pto.psti`: the predicate payload interpretation is + the same, but `%offset` is supplied as an SSA `index` instead of a constant + `index` immediate. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psts` stores the packed predicate payload represented by `!pto.mask`. +It only models the explicit `base[offset]` form. + +**Example:** +```mlir +pto.psts %mask, %ub[%c0], "NORM" : !pto.mask, !pto.ptr, index +``` + +--- + +##### `pto.psti` + +- **syntax:** `pto.psti %value, %dest[%offset], "DIST" : !pto.mask, !pto.ptr, index` +- **offset:** must be a constant `index` immediate in PTO surface form. +- **semantics:** Store predicate register with immediate offset. +- **DIST:** mandatory string token, one of `NORM`, `PK`. + - `NORM`: store the packed predicate payload into a normal destination space + of size `VL/8`. + - `PK`: store the packed predicate payload into a destination space of size + `VL/16`, keeping one bit out of every two bits. + +`pto.psti` and `pto.psts` store the packed predicate payload represented by +`!pto.mask`. The surface distinction is only immediate-offset versus +dynamic-offset. + +--- + +##### `pto.pstu` + +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **syntax:** `%align_out, %base_out = pto.pstu %align_in, %value, %base : !pto.align, !pto.mask, !pto.ptr -> !pto.align, !pto.ptr` +- **semantics:** Predicate unaligned store with align/base state update. The base type is fixed by mask granularity: `b16 <-> ui16`, `b32 <-> ui32`. +- **outputs:** + `%align_out` and `%base_out` are the updated unaligned-store state and are + intended to be used by a later `pto.pstu` call. +- **constraints and limitations:** + The first `%align_in` in a predicate unaligned-store stream should come from + `pto.init_align`. + +--- + +#### Typical Usage Pattern + +```mlir +// Generate comparison mask +%mask = pto.vcmp %v0, %v1, %seed, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Store mask to UB for later use +pto.psts %mask, %ub_mask[%c0], "NORM" : !pto.mask, !pto.ptr, index + +// ... later in another kernel ... + +// Load mask from UB +%saved_mask = pto.plds %ub_mask[%c0], "NORM" : !pto.ptr, index -> !pto.mask + +// Use for predicated select +%result = pto.vsel %v_true, %v_false, %saved_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 5. Materialization & Predicate Ops + +> **Category:** Scalar broadcast, predicate generation and manipulation +> **Pipeline:** PIPE_V (Vector Core) + +These ops create vectors from scalar values and manipulate predicate registers. + +#### Common Operand Model + +- `%value` is the scalar source value in SSA form. +- `%input` is either a source scalar or a source vector depending on the op. +- `%result` is the destination vector register value. +- For 32-bit scalar inputs, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. + +--- + +#### Scalar Materialization + +##### `pto.vbr` + +- **syntax:** `%result = pto.vbr %value : T -> !pto.vreg` +- **semantics:** Broadcast scalar to all vector lanes. +- **inputs:** + `%value` is the scalar source. +- **outputs:** + `%result` is a vector whose active lanes all carry `%value`. +- **constraints and limitations:** + Supported forms are `b8`, `b16`, and `b32`. For `b8`, only the low 8 bits of + the scalar source are consumed. + +```c +for (int i = 0; i < N; i++) + dst[i] = value; +``` + +**Example:** +```mlir +%one = pto.vbr %c1_f32 : f32 -> !pto.vreg<64xf32> +``` + +--- + +##### `pto.vdup` + +- **syntax:** `%result = pto.vdup %input, %mask {position = "LOWEST|HIGHEST"} : T|!pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Duplicate scalar or vector element to all lanes. +- **inputs:** + `%input` supplies the scalar or source-lane value selected by `position`, + and `%mask` controls the active lanes. +- **outputs:** + `%result` is the duplicated vector. +- **constraints and limitations:** + `position` selects which source vector element is duplicated and is only valid + for vector input. `position` defaults to `LOWEST`. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? input_scalar_or_element : 0; +``` + +--- + +#### Predicate Generation + +##### `pto.pset_b8` / `pto.pset_b16` / `pto.pset_b32` + +- **syntax:** `%result = pto.pset_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pset_b32 "PATTERN" : !pto.mask` +- **semantics:** Materialize a predicate register from a named pattern token. + +**Supported pattern tokens:** + +| Pattern | Description | +|---------|-------------| +| `PAT_ALL` | All lanes active | +| `PAT_ALLF` | All lanes inactive | +| `PAT_H` | High half active | +| `PAT_Q` | Upper quarter active | +| `PAT_VL1`...`PAT_VL128` | First N logical lanes active | +| `PAT_M3`, `PAT_M4` | Modular patterns | + +`PAT_ALL` is the PTO spelling of the VISA-style all-true predicate pattern. +The other tokens listed above are also concrete installed-toolchain pattern +objects, not PTO-only aliases. + +**Example — All 64 f32 lanes active:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +``` + +**Example — First 16 lanes active:** +```mlir +%first_16 = pto.pset_b32 "PAT_VL16" : !pto.mask +``` + +--- + +##### `pto.pge_b8` / `pto.pge_b16` / `pto.pge_b32` + +- **syntax:** `%result = pto.pge_b8 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b16 "PATTERN" : !pto.mask` +- **syntax:** `%result = pto.pge_b32 "PATTERN" : !pto.mask` +- **semantics:** Generate a predicate from a lane-count pattern token. In the + common tail-mask form, `PAT_VL` marks the first `N` logical lanes active. +- **supported pattern tokens:** `PAT_ALL`, `PAT_ALLF`, `PAT_H`, `PAT_Q`, + `PAT_VL1`, `PAT_VL2`, `PAT_VL3`, `PAT_VL4`, `PAT_VL8`, `PAT_VL16`, + `PAT_VL32`, `PAT_VL64`, `PAT_VL128`, `PAT_M3`, `PAT_M4` + +```c +for (int i = 0; i < TOTAL_LANES; i++) + mask[i] = (i < len); +``` + +**Example — Tail mask for remainder loop:** +```mlir +%tail_mask = pto.pge_b32 "PAT_VL8" : !pto.mask +``` + +--- + +##### `pto.plt_b8` / `pto.plt_b16` / `pto.plt_b32` + +- **syntax:** `%mask, %scalar_out = pto.plt_b8 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b16 %scalar : i32 -> !pto.mask, i32` +- **syntax:** `%mask, %scalar_out = pto.plt_b32 %scalar : i32 -> !pto.mask, i32` +- **semantics:** Generate a tail-style predicate from an SSA lane-count value. + On A5/V300-style toolchains, this family is exposed as a post-update wrapper: + the predicate result becomes `%mask`, and the wrapper's carry-out scalar state + is surfaced as `%scalar_out`. +- **inputs:** + `%scalar` is the incoming lane-count / remaining-count state. +- **outputs:** + `%mask` is the generated predicate. + `%scalar_out` is the post-update scalar carry-out from the same `plt` call + and can be threaded into a subsequent `pto.plt_b*` call in the same chain. + +```c +for (int i = 0; i < VL_t; ++i) + mask[i] = (i < scalar_in); + +scalar_out = (scalar_in < VL_t) ? 0 : (scalar_in - VL_t); +``` + +Where `VL_t` is the logical lane count of the concrete op variant: + +- `pto.plt_b8`: `VL_t = 256` +- `pto.plt_b16`: `VL_t = 128` +- `pto.plt_b32`: `VL_t = 64` + +--- + +#### Predicate Pack/Unpack + +##### `pto.ppack` + +- **syntax:** `%result = pto.ppack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Narrowing pack of predicate register. +- **part tokens:** + - `LOWER`: pack into the lower half of `%result`; the upper half is zeroed. + - `HIGHER`: pack into the higher half of `%result`; the lower half is zeroed. + +Conceptually, `pto.ppack` keeps one bit out of each adjacent 2-bit group from +`%input`, packs those kept bits into the selected half of `%result`, and fills +the other half with zeros. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) + result[i] = input[2 * i]; +for (int i = VL / 2; i < VL; ++i) + result[i] = 0; + +// HIGHER +for (int i = 0; i < VL / 2; ++i) + result[VL / 2 + i] = input[2 * i]; +for (int i = 0; i < VL / 2; ++i) + result[i] = 0; +``` + +--- + +##### `pto.punpack` + +- **syntax:** `%result = pto.punpack %input, "PART" : !pto.mask -> !pto.mask` +- **semantics:** Widening unpack of predicate register. +- **part tokens:** + - `LOWER`: unpack from the lower half of `%input`. + - `HIGHER`: unpack from the higher half of `%input`. + +Conceptually, `pto.punpack` reads the selected half of `%input`, zero-extends +each 1-bit predicate element into a 2-bit group in `%result`, and leaves the +expanded image in the full destination predicate register. + +```c +// Let VL be the logical lane count of the destination predicate. +// LOWER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[i]; + result[2 * i + 1] = 0; +} + +// HIGHER +for (int i = 0; i < VL / 2; ++i) { + result[2 * i] = input[VL / 2 + i]; + result[2 * i + 1] = 0; +} +``` + +--- + +#### Predicate Logical Ops + +##### `pto.pand` + +- **syntax:** `%result = pto.pand %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise AND gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] & src1[i]) : 0; +``` + +--- + +##### `pto.por` + +- **syntax:** `%result = pto.por %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise OR gated by a governing predicate. + +Inactive lanes selected out by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] | src1[i]) : 0; +``` + +--- + +##### `pto.pxor` + +- **syntax:** `%result = pto.pxor %src0, %src1, %mask : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise XOR gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (src0[i] ^ src1[i]) : 0; +``` + +--- + +##### `pto.pnot` + +- **syntax:** `%result = pto.pnot %input, %mask : !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate bitwise NOT gated by a governing predicate. + +Inactive lanes selected by `%mask` are zeroed. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? (~src[i]) : 0; +``` + +--- + +##### `pto.psel` + +- **syntax:** `%result = pto.psel %src0, %src1, %sel : !pto.mask, !pto.mask, !pto.mask -> !pto.mask` +- **semantics:** Predicate select (mux). `%sel` is the governing predicate that + chooses lanes from `%src0` or `%src1`. + +```c +for (int i = 0; i < N; i++) + dst[i] = sel[i] ? src0[i] : src1[i]; +``` + +--- + +##### `pto.pdintlv_b8` / `pto.pdintlv_b16` / `pto.pdintlv_b32` + +- **syntax:** `%low, %high = pto.pdintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pdintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** De-interleave two predicate sources and return the two + de-interleaved predicate images in the same predicate element family. + +--- + +##### `pto.pintlv_b8` / `pto.pintlv_b16` / `pto.pintlv_b32` + +- **syntax:** `%low, %high = pto.pintlv_b8 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b16 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **syntax:** `%low, %high = pto.pintlv_b32 %src0, %src1 : !pto.mask, !pto.mask -> !pto.mask, !pto.mask` +- **semantics:** Interleave two predicate sources and return the two + resulting predicate images in the same predicate element family. + +--- + +#### Typical Usage + +```mlir +// Generate all-active mask for f32 (64 lanes) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// Generate tail mask for remainder (last 12 elements) +%tail = pto.pge_b32 "PAT_VL12" : !pto.mask + +// Compare and generate mask +%cmp_mask = pto.vcmp %a, %b, %all, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask + +// Combine masks: only process tail elements that passed comparison +%combined = pto.pand %cmp_mask, %tail, %all : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + +// Use for predicated operation +%result = pto.vsel %true_vals, %false_vals, %combined : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 6. Unary Vector Ops + +> **Category:** Single-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take one vector input and produce one vector output. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand. For this family, inactive lanes follow the + predication behavior of the selected instruction form: zeroing forms + zero-fill inactive lanes, while merging forms preserve the destination value. +- `%result` is the destination vector register value. Unless stated otherwise, + `%result` has the same lane count and element type as `%input`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** values use **aclFloat16** in traces where measured. **bf16:** no simple-tile ST coverage on this surface; treat as **—**. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vabs` | `RV_VABS_FP` | **5** | **5** | — | +| `pto.vneg` | `RV_VMULS` | **8** | **8** | — | +| `pto.vexp` | `RV_VEXP` | **16** | **21** | — | +| `pto.vln` | `RV_VLN` | **18** | **23** | — | +| `pto.vsqrt` | `RV_VSQRT` | **17** | **22** | — | +| `pto.vrelu` | `RV_VRELU` | **5** | **5** | — | +| `pto.vnot` | `RV_VNOT` | — | int-only paths | — | +| `pto.vmov` | `RV_VLD` proxy | **9** | **9** | — | + +--- + +#### Arithmetic + +##### `pto.vabs` + +- **syntax:** `%result = pto.vabs %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < 0) ? -src[i] : src[i]; +``` + +- **inputs:** `%input` supplies the source lanes and `%mask` selects which lanes + participate. +- **outputs:** `%result` receives the lane-wise absolute values. +- **constraints and limitations:** Source and result types MUST match. On A5, + integer overflow follows the ISA default truncation behavior for this family; + `pto.vabs` is not an explicit saturating op. + +--- + +##### `pto.vneg` + +- **syntax:** `%result = pto.vneg %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = -src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise arithmetic negation. +- **constraints and limitations:** Source and result types MUST match. + +--- + +#### Transcendental + +##### `pto.vexp` + +- **syntax:** `%result = pto.vexp %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `exp(input[i])` per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + +--- + +##### `pto.vln` + +- **syntax:** `%result = pto.vln %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = logf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the natural logarithm per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + For real-number semantics, active inputs SHOULD be strictly positive; non- + positive inputs follow the target's exception/NaN rules. + +--- + +##### `pto.vsqrt` + +- **syntax:** `%result = pto.vsqrt %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = sqrtf(src[i]); +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the square root per active lane. +- **constraints and limitations:** Only floating-point element types are legal. + Negative active inputs follow the target's exception/NaN rules. + +--- + +#### Activation + +##### `pto.vrelu` + +- **syntax:** `%result = pto.vrelu %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > 0) ? src[i] : 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds `max(input[i], 0)` per active lane. +- **constraints and limitations:** Only floating-point element types are legal + on the current A5 surface described here. + +--- + +#### Bitwise + +##### `pto.vnot` + +- **syntax:** `%result = pto.vnot %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = ~src[i]; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects active lanes. +- **outputs:** `%result` holds the lane-wise bitwise inversion. +- **constraints and limitations:** Integer element types only. + +--- + +#### Movement + +#### Typical Usage + +```mlir +// Softmax numerator: exp(x - max) +%sub = pto.vsub %x, %max_broadcast, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%exp = pto.vexp %sub, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// ReLU activation +%activated = pto.vrelu %linear_out, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 7. Binary Vector Ops + +> **Category:** Two-input vector operations +> **Pipeline:** PIPE_V (Vector Core) + +Element-wise operations that take two vector inputs and produce one vector output. + +#### Common Operand Model + +- `%lhs` and `%rhs` are the two source vector register values. +- `%mask` is the predicate operand `Pg` that gates which lanes participate. +- `%result` is the destination vector register value. Unless explicitly noted, + it has the same lane count and element type as the inputs. +- Unless explicitly documented otherwise, `%lhs`, `%rhs`, and `%result` MUST + have matching vector shapes and element types. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** — (no dedicated vec tile ST on this surface). + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadd` | `RV_VADD` | **7** | **7** | — | +| `pto.vsub` | `RV_VSUB` | **7** | **7** | — | +| `pto.vmul` | `RV_VMUL` | **8** | **8** | — | +| `pto.vdiv` | `RV_VDIV` | **17** | **22** | — | + +--- + +#### Arithmetic + +##### `pto.vadd` + +- **syntax:** `%result = pto.vadd %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] + src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise; `%mask` selects active + lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vsub` + +- **syntax:** `%result = pto.vsub %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i64, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] - src1[i]; +``` + +- **inputs:** `%lhs` is the minuend, `%rhs` is the subtrahend, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise difference. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmul` + +- **syntax:** `%result = pto.vmul %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, bf16, f32 (**NOT** i8/ui8) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] * src1[i]; +``` + +- **inputs:** `%lhs` and `%rhs` are multiplied lane-wise; `%mask` selects + active lanes. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** The current A5 profile excludes `i8/ui8` + forms from this surface. + +--- + +##### `pto.vdiv` + +- **syntax:** `%result = pto.vdiv %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 only (no integer division) + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] / src1[i]; +``` + +- **inputs:** `%lhs` is the numerator, `%rhs` is the denominator, and `%mask` + selects active lanes. +- **outputs:** `%result` is the lane-wise quotient. +- **constraints and limitations:** Floating-point element types only. Active + denominators containing `+0` or `-0` follow the target's exceptional + behavior. + +--- + +##### `pto.vmax` + +- **syntax:** `%result = pto.vmax %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] > src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +##### `pto.vmin` + +- **syntax:** `%result = pto.vmin %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i8-i32, f16, bf16, f32 + +```c +for (int i = 0; i < N; i++) + dst[i] = (src0[i] < src1[i]) ? src0[i] : src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` holds the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. + +--- + +#### Bitwise + +##### `pto.vand` + +- **syntax:** `%result = pto.vand %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] & src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise AND. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vor` + +- **syntax:** `%result = pto.vor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] | src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise OR. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vxor` + +- **syntax:** `%result = pto.vxor %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] ^ src1[i]; +``` + +- **inputs:** `%lhs`, `%rhs`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise bitwise XOR. +- **constraints and limitations:** Integer element types only. + +--- + +#### Shift + +##### `pto.vshl` + +- **syntax:** `%result = pto.vshl %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] << src1[i]; +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Shift counts + SHOULD stay within `[0, bitwidth(T) - 1]`; out-of-range behavior is target- + defined unless the verifier narrows it further. + +--- + +##### `pto.vshr` + +- **syntax:** `%result = pto.vshr %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** all integer types + +```c +for (int i = 0; i < N; i++) + dst[i] = src0[i] >> src1[i]; // arithmetic for signed, logical for unsigned +``` + +- **inputs:** `%lhs` supplies the shifted value, `%rhs` supplies the per-lane + shift amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. Signedness of the + element type determines arithmetic vs logical behavior. + +--- + +#### Carry Operations + +##### `pto.vaddc` + +- **syntax:** `%result, %carry = pto.vaddc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry output. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i]; + dst[i] = (T)r; + carry[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are added lane-wise and `%mask` selects active + lanes. +- **outputs:** `%result` is the truncated arithmetic result and `%carry` is the + carry/overflow predicate per lane. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is a carry-chain integer add family. On + the current A5 surface, only 32-bit integer element types are supported. + `%mask` and `%carry` therefore use the same typed-mask granularity as the + data vector family, which on the current documented A5 surface means + `!pto.mask`. + +--- + +##### `pto.vsubc` + +- **syntax:** `%result, %carry = pto.vsubc %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with per-lane carry output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i]; + carry[i] = (src0[i] >= src1[i]); +} +``` + +- **inputs:** `%lhs` and `%rhs` are subtracted lane-wise and `%mask` selects + active lanes. +- **outputs:** `%result` is the arithmetic difference and `%carry` is the + per-lane carry predicate. For this subtraction family, active lanes set + `%carry[i] = 1` when the subtraction completes without borrow, and + `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This operation is currently restricted to + the 32-bit integer carry/borrow-chain family. `%mask` and `%carry` + therefore use the same typed-mask granularity as the data vector family, + which on the current documented A5 surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Vector addition +%sum = pto.vadd %a, %b, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise multiply +%prod = pto.vmul %x, %y, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to range [min, max] +%clamped_low = pto.vmax %input, %min_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmin %clamped_low, %max_vec, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Bit manipulation +%masked = pto.vand %data, %bitmask, %mask : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 8. Vec-Scalar Ops + +> **Category:** Vector-scalar operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that combine a vector with a scalar value, applying the scalar to every lane. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%scalar` is the scalar operand in SSA form. +- `%mask` is the predicate operand. +- `%result` is the destination vector register value. +- For 32-bit scalar forms, the scalar source MUST satisfy the backend's legal + scalar-source constraints for this family. +- For elementwise vec-scalar families whose scalar conceptually matches the + vector element type (`pto.vadds`, `pto.vmuls`, `pto.vmaxs`, + `pto.vmins`, `pto.vlrelu`): + - signed integer vectors accept signed integer scalars with the same width, + and also accept signless `i` + - unsigned integer vectors accept unsigned integer scalars with the same + width, and also accept signless `i` + - signless integer vectors accept signless `i` +- `pto.vshls` and `pto.vshrs` are not part of that rule; their scalar operand + is the shift amount and remains fixed to `i16`. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). **fp16** uses **aclFloat16** in measured traces. **bf16:** —. + +| PTO op | RV (CA) | fp32 | fp16 | bf16 | +|--------|---------|------|------|------| +| `pto.vadds` | `RV_VADDS` | **7** | **7** | — | +| `pto.vmuls` | `RV_VMULS` | **8** | **8** | — | + +--- + +#### Arithmetic + +##### `pto.vadds` + +- **syntax:** `%result = pto.vadds %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** `si8`, `si16`, `si32`, `ui8`, `ui16`, `ui32`, `f16`, `bf16`, `f32` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] + scalar; +``` + +- **inputs:** `%input` is the source vector, `%scalar` is broadcast logically to + each lane, and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise sum. +- **constraints and limitations:** Input vector element type, scalar type, and + result vector element type MUST match. For integer vector forms, `%scalar` + may also use matching-signedness integer or signless `i` with the same + bit width as the vector element type, so it can be fed directly from `arith` + constants. + +--- + +##### `pto.vmuls` + +- **syntax:** `%result = pto.vmuls %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] * scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise product. +- **constraints and limitations:** Supported element types are hardware-family + specific; the current PTO micro Instruction documentation covers the common + numeric cases. For integer vector forms, `%scalar` may use matching-signedness + integer or signless `i` with the same bit width as the vector element + type. + +--- + +##### `pto.vmaxs` + +- **syntax:** `%result = pto.vmaxs %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] > scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise maximum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +##### `pto.vmins` + +- **syntax:** `%result = pto.vmins %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] < scalar) ? src[i] : scalar; +``` + +- **inputs:** `%input`, `%scalar`, and `%mask` as above. +- **outputs:** `%result` is the lane-wise minimum. +- **constraints and limitations:** Input and result types MUST match. For + integer vector forms, `%scalar` may use matching-signedness integer or + signless `i` with the same bit width as the vector element type. + +--- + +#### Shift + +##### `pto.vshls` + +- **syntax:** `%result = pto.vshls %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] << scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. The shift amount + SHOULD stay within the source element width. + +--- + +##### `pto.vshrs` + +- **syntax:** `%result = pto.vshrs %input, %scalar, %mask : !pto.vreg, i16, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = src[i] >> scalar; +``` + +- **inputs:** `%input` is the value vector, `%scalar` is the uniform `i16` shift + amount, and `%mask` selects active lanes. +- **outputs:** `%result` is the shifted vector. +- **constraints and limitations:** Integer element types only. + +--- + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %scalar, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : scalar * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%scalar` is the leaky slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the lane-wise leaky-ReLU result. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +#### Carry Operations + +##### `pto.vaddcs` + +- **syntax:** `%result, %carry = pto.vaddcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Add with carry-in and carry-out. + +```c +for (int i = 0; i < N; i++) { + uint64_t r = (uint64_t)src0[i] + src1[i] + carry_in[i]; + dst[i] = (T)r; + carry_out[i] = (r >> bitwidth); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the carry-out + predicate. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended carry-chain + family. On the current A5 surface, only 32-bit integer element types are + supported. `%carry_in`, `%mask`, and `%carry` therefore all use the same + typed-mask granularity as the data vector family, which on the current + documented A5 surface means `!pto.mask`. + +--- + +##### `pto.vsubcs` + +- **syntax:** `%result, %carry = pto.vsubcs %lhs, %rhs, %carry_in, %mask : !pto.vreg, !pto.vreg, !pto.mask, !pto.mask -> !pto.vreg, !pto.mask` +- **semantics:** Subtract with carry input and output. + +```c +for (int i = 0; i < N; i++) { + dst[i] = src0[i] - src1[i] - (1 - carry_in[i]); + carry_out[i] = (src0[i] >= src1[i] + (1 - carry_in[i])); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the value vectors, `%carry_in` is the + incoming carry predicate, and `%mask` selects active lanes. +- **outputs:** `%result` is the arithmetic result and `%carry` is the + carry predicate after the lane-wise subtraction. For this subtraction family, + active lanes set `%carry[i] = 1` when the subtraction completes without + borrow, and `%carry[i] = 0` when a borrow occurs. +- **A5 types:** `i32`, `si32`, `ui32` +- **constraints and limitations:** This is the scalar-extended borrow-chain + family and is currently restricted to 32-bit integer element types. + `%carry_in`, `%mask`, and `%carry` therefore all use the same typed-mask + granularity as the data vector family, which on the current documented A5 + surface means `!pto.mask`. + +--- + +#### Typical Usage + +```mlir +// Add bias to all elements +%biased = pto.vadds %activation, %bias_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Scale by constant +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Clamp to [0, 255] for uint8 quantization +%clamped_low = pto.vmaxs %input, %c0, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%clamped = pto.vmins %clamped_low, %c255, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Shift right by fixed amount +%shifted = pto.vshrs %data, %c4, %mask : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 9. Conversion Ops + +> **Category:** Type conversion operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that convert between data types (float/int, narrowing/widening). + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate mask that selects active conversion lanes. +- `%result` is the destination vector register value. +- `rnd`, `sat`, and `part` are optional attributes that refine + conversion behavior when the selected source/destination type pair needs + rounding, saturation, or lane placement control. +- The single `pto.vcvt` surface covers float-int, float-float, int-float, and + int-int conversion families. + +#### CA latency (A5, Ascend910_9599 CA) + +Cycle-accurate simulator **popped→retire** latency (cycles). Only representative traces below; other `pto.vcvt` conversion pairs depend on the RV lowering in the trace. + +| PTO op | RV (CA) | Note | Latency | +|--------|---------|------|---------| +| `pto.vcvt` | `RV_VCVT_F2F` | f32→f16 | **7** | +| `pto.vci` | — | no vector `RV_*` in sampled `veccore0` trace | — | + +--- + +#### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate a lane-index vector from a scalar seed/index value. +- **inputs:** + `%index` is the scalar seed or base index. +- **outputs:** + `%result` is the generated index vector. +- **constraints and limitations:** + This is an index-generation family, not a numeric conversion. `ORDER` and the + result element type together determine how indices are generated. `%result` + uses an integer element type, and the scalar `%index` type matches that + result element type. + +--- + +#### `pto.vcvt` + +- **syntax:** `%result = pto.vcvt %input, %mask {rnd = "RND", sat = "SAT", part = "PART"} : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Type conversion between float/int types with rounding control. + +```c +for (int i = 0; i < min(N, M); i++) + if (mask[i]) + dst[i] = convert(src[i], T0, T1, rnd); +``` + +- **inputs:** + `%input` is the source vector, `%mask` selects active lanes, and attributes + select rounding, saturation, and output placement when the conversion changes + width or packs into sub-lane positions. +- **outputs:** + `%result` is the converted vector. +- **constraints and limitations:** + Only documented source/destination type pairs are legal. All three + attributes are optional at the surface level, but only the subset meaningful + to the selected conversion kind should be provided. The execution mask must + use the typed-mask granularity that matches the source vector family on the + current surface; there is no `!pto.mask` form in VPTO. + +--- + +##### Rounding Modes + +| Mode | Description | +|------|-------------| +| `R` | Round to nearest, ties to even (default) | +| `A` | Round away from zero | +| `F` | Round toward negative infinity (floor) | +| `C` | Round toward positive infinity (ceil) | +| `Z` | Round toward zero (truncate) | +| `O` | Round to odd | + +--- + +##### Saturation Modes + +| Mode | Description | +|------|-------------| +| `SAT` | Saturate on overflow | +| `NOSAT` | No saturation (wrap/undefined on overflow) | + +--- + +##### Part Modes + +Use `part` when a width-changing conversion writes only one half of each wider +destination lane group. This is typically used in even/odd placement forms such +as `32 -> 16` or `16 -> 32` style conversions. + +| Mode | Description | +|------|-------------| +| `EVEN` | Output to even-indexed lanes | +| `ODD` | Output to odd-indexed lanes | + +--- + +##### Attribute Guidance + +- `rnd` + - Use when the conversion needs an explicit rounding rule, especially for + float-to-int, float-to-float narrowing, or integer-to-float forms that do + not map exactly. +- `mask` + - Use to select which source lanes participate in the conversion. In + width-changing conversions, `mask` works together with `part` / `pp` to + determine which logical lane positions are produced. +- `sat` + - Use when the conversion may overflow the destination range and hardware + exposes a saturating form. +- `part` + - Use for width-changing conversions that select the even or odd half of the + destination packing layout. + +###### Float To Int + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<32xsi64>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {rnd, sat} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xsi8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xsi32>` + +###### Float To Float + +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd, sat, part} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xbf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Float + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<128xf16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xf32>` +- `%dst = pto.vcvt %src, %mask {rnd} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<64xf32>` + +###### Int To Int + +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<256xsi8>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xui32>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<128xsi16>, !pto.mask -> !pto.vreg<64xsi32>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<256xui8>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xui16>` +- `%dst = pto.vcvt %src, %mask {sat, part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<128xsi16>` +- `%dst = pto.vcvt %src, %mask {part} : !pto.vreg<64xsi32>, !pto.mask -> !pto.vreg<32xsi64>` + +##### A5 Supported Type Matrix + +The table below is only a summary. For exact attribute combinations, use the +per-form entries above as the source of truth. + +| `src \ dst` | `ui8` | `si8` | `ui16` | `si16` | `ui32` | `si32` | `si64` | `f16` | `f32` | `bf16` | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| `ui8` | | | Y | | Y | | | Y | | | +| `si8` | | | | Y | | Y | | Y | | | +| `ui16` | Y | | | | Y | | | | | | +| `si16` | Y | | | | Y | Y | | Y | Y | | +| `ui32` | Y | | Y | Y | | | | | | | +| `si32` | Y | | Y | Y | | | Y | | Y | | +| `si64` | | | | | | | | | | | +| `f16` | Y | Y | | Y | | Y | | | Y | | +| `f32` | | | | Y | | Y | Y | Y | | Y | +| `bf16` | | | | | | Y | | | Y | | + +--- + +##### Width-Changing Conversion Pattern + +For conversions that change width (e.g., f32→f16), use even/odd parts and combine: + +```mlir +// Convert two f32 vectors to one f16 vector +%even = pto.vcvt %in0, %mask {rnd = "R", sat = "SAT", part = "EVEN"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%odd = pto.vcvt %in1, %mask {rnd = "R", sat = "SAT", part = "ODD"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%result = pto.vor %even, %odd, %mask : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +``` + +--- + +#### `pto.vtrc` + +- **syntax:** `%result = pto.vtrc %input, %mask, "RND" : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Truncate/round float to integer-valued float (stays in float type). + +```c +for (int i = 0; i < N; i++) + dst[i] = round_to_int_valued_float(src[i], rnd); +``` + +- **inputs:** + `%input` is the floating-point source vector, `%mask` selects active lanes, + and `RND` selects the truncation/rounding rule. +- **outputs:** + `%result` is still a floating-point vector, but each active lane now carries + an integer-valued floating-point result. +- **constraints and limitations:** + This op does not change the element type. `T` must be `f16`, `f32`, or + `bf16`. `RND` must be one of `R`, `A`, `F`, `C`, or `Z`. `BW` must match the + element width: `b16` for `f16`/`bf16`, `b32` for `f32`. + +**Example:** +```mlir +// Round to nearest integer, keep as float +%rounded = pto.vtrc %input, %mask, "R" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// input: [1.4, 2.6, -1.5, 3.0] +// output: [1.0, 3.0, -2.0, 3.0] +``` + +--- + +#### Typical Usage + +```mlir +// Quantization: f32 → i8 with saturation +%scaled = pto.vmuls %input, %scale, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> +%quantized = pto.vcvt %scaled, %mask {rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// Then narrow i32 → i8 via pack ops + +// Mixed precision: bf16 → f32 for accumulation +%f32_vec = pto.vcvt %bf16_input, %mask {part = "EVEN"} + : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> + +// Floor for integer division +%floored = pto.vtrc %ratio, %mask, "F" : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%int_div = pto.vcvt %floored, %mask {rnd = "Z"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +``` + + + +### 10. Reduction Ops + +> **Category:** Vector reduction operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that reduce a vector to a scalar or per-group result. + +#### Common Operand Model + +- `%input` is the source vector register value. +- `%mask` is the predicate operand `Pg`; inactive lanes do not participate. +- `%result` is the destination vector register value. +- Reduction results are written into the low-significance portion of the + destination vector and the remaining destination bits are zero-filled. + +--- + +#### Full Vector Reductions + +##### `pto.vcadd` + +- **syntax:** `%result = pto.vcadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i64, f16, f32 +- **semantics:** Sum all elements. Result in lane 0, others zeroed. + +```c +T sum = 0; +for (int i = 0; i < N; i++) + sum += src[i]; +dst[0] = sum; +for (int i = 1; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains the reduction result in its low element(s). +- **constraints and limitations:** Some narrow integer forms may widen the + internal accumulation or result placement. If all predicate bits are zero, the + result is zero. + +--- + +##### `pto.vcmax` + +- **syntax:** `%result = pto.vcmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find max element with argmax. The lowest destination element + stores the maximum value, the second-lowest destination element stores the + index of the first maximum, and all remaining elements are zero-filled. + +```c +T mx = -INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] > mx) { mx = src[i]; idx = i; } +dst[0] = mx; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple maxima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `-INF`; if all lanes are inactive, `%result[0]` becomes `-INF`. For integer + types, inactive lanes are treated as the literal minimum value; if all lanes + are inactive, `%result[0]` becomes that literal minimum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +##### `pto.vcmin` + +- **syntax:** `%result = pto.vcmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Find min element with argmin. The lowest destination element + stores the minimum value, the second-lowest destination element stores the + index of the first minimum, and all remaining elements are zero-filled. + +```c +T mn = INF; int idx = 0; +for (int i = 0; i < N; i++) + if (src[i] < mn) { mn = src[i]; idx = i; } +dst[0] = mn; +dst[1] = idx; +for (int i = 2; i < N; i++) + dst[i] = 0; +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result[0]` holds the extremum value and `%result[1]` holds the + index. Other destination elements are zero-filled. +- **constraints and limitations:** If there are multiple minima, the minimum + index is written. For floating-point types, inactive lanes are treated as + `+INF`; if all lanes are inactive, `%result[0]` becomes `+INF`. For integer + types, inactive lanes are treated as the literal maximum value; if all lanes + are inactive, `%result[0]` becomes that literal maximum value. The index is + written into the second destination element slot of the same destination + vector register. + +--- + +#### Per-VLane (Group) Reductions + +The vector register is organized as **8 VLanes** of 32 bytes each. Group reductions operate within each VLane independently. + +``` +vreg layout (f32 example, 64 elements total): +VLane 0: [0..7] VLane 1: [8..15] VLane 2: [16..23] VLane 3: [24..31] +VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] +``` + +##### `pto.vcgadd` + +- **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). + +```c +int K = N / 8; // elements per VLane +for (int g = 0; g < 8; g++) { + T sum = 0; + for (int i = 0; i < K; i++) + sum += src[g*K + i]; + dst[g*K] = sum; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one sum per 32-byte VLane group, written + contiguously into the low slot of each group. +- **constraints and limitations:** This is a per-32-byte VLane-group reduction. + Inactive lanes are treated as zero. + +--- + +##### `pto.vcgmax` + +- **syntax:** `%result = pto.vcgmax %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Max within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mx = -INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] > mx) mx = src[g*K + i]; + dst[g*K] = mx; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one maximum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +##### `pto.vcgmin` + +- **syntax:** `%result = pto.vcgmin %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** i16-i32, f16, f32 +- **semantics:** Min within each VLane. + +```c +int K = N / 8; +for (int g = 0; g < 8; g++) { + T mn = INF; + for (int i = 0; i < K; i++) + if (src[g*K + i] < mn) mn = src[g*K + i]; + dst[g*K] = mn; + for (int i = 1; i < K; i++) + dst[g*K + i] = 0; +} +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` contains one minimum per 32-byte VLane group. +- **constraints and limitations:** Grouping is by hardware 32-byte VLane, not by + arbitrary software subvector. + +--- + +#### Prefix Operations + +##### `pto.vcpadd` + +- **syntax:** `%result = pto.vcpadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Inclusive prefix sum (scan). + +```c +dst[0] = src[0]; +for (int i = 1; i < N; i++) + dst[i] = dst[i-1] + src[i]; +``` + +**Example:** +```c +// input: [1, 2, 3, 4, 5, ...] +// output: [1, 3, 6, 10, 15, ...] +``` + +- **inputs:** `%input` is the source vector and `%mask` selects participating + lanes. +- **outputs:** `%result` is the inclusive prefix-sum vector. +- **constraints and limitations:** Only floating-point element types are + documented on the current A5 surface here. + +--- + +#### Typical Usage + +```mlir +// Softmax: find max for numerical stability +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// max is in lane 0, broadcast it +%max_broadcast = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// Row-wise sum using vcgadd (for 8-row tile) +%row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 + +// Full vector sum for normalization +%total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// total[0] contains the sum + +// Prefix sum for cumulative distribution +%cdf = pto.vcpadd %pdf, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 11. Compare & Select + +> **Category:** Comparison and conditional selection operations +> **Pipeline:** PIPE_V (Vector Core) + +Operations that compare vectors and conditionally select elements. + +#### Common Operand Model + +- `%src0` and `%src1` are source vector operands. +- `%scalar` is the scalar operand for scalar-comparison families. +- `%seed` is the incoming predicate that limits which lanes participate in the + compare. +- `%result` is either a predicate mask (`vcmp`, `vcmps`) or a vector register + (`vsel`, `vselr`, `vselrv2`). + +--- + +#### Comparison Operations + +##### `pto.vcmp` + +- **syntax:** `%result = pto.vcmp %src0, %src1, %seed, "CMP_MODE" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.mask` +- **semantics:** Element-wise comparison, output predicate mask. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src0[i] CMP src1[i]) ? 1 : 0; +``` + +**Compare modes:** + +| Mode | Operation | +|------|-----------| +| `eq` | Equal (==) | +| `ne` | Not equal (!=) | +| `lt` | Less than (<) | +| `le` | Less than or equal (<=) | +| `gt` | Greater than (>) | +| `ge` | Greater than or equal (>=) | + +**Example:** +```mlir +%all_active = pto.pset_b32 "PAT_ALL" : !pto.mask +%lt_mask = pto.vcmp %a, %b, %all_active, "lt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// lt_mask[i] = 1 if a[i] < b[i] +``` + +- **inputs:** `%src0`, `%src1`, and `%seed`; `CMP_MODE` selects the comparison + predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** Only lanes enabled by `%seed` participate. + Integer and floating-point comparisons follow their own element-type-specific + comparison rules. `%seed` and `%result` keep the typed-mask granularity that + matches `%src0` / `%src1`. + +--- + +##### `pto.vcmps` + +- **syntax:** `%result = pto.vcmps %src, %scalar, %seed, "CMP_MODE" : !pto.vreg, T, !pto.mask -> !pto.mask` +- **semantics:** Compare vector against scalar. + +```c +for (int i = 0; i < N; i++) + if (seed[i]) + dst[i] = (src[i] CMP scalar) ? 1 : 0; +``` + +**Example:** +```mlir +%positive_mask = pto.vcmps %values, %c0_f32, %all_active, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +// positive_mask[i] = 1 if values[i] > 0 +``` + +- **inputs:** `%src` is the vector source, `%scalar` is the scalar comparison + value, and `%seed` is the incoming predicate. +- **outputs:** `%result` is the generated predicate mask. +- **constraints and limitations:** For 32-bit scalar forms, the scalar source + MUST satisfy the backend's legal scalar-source constraints for this family. + `%seed` and `%result` keep the typed-mask granularity that matches `%src`. + +--- + +#### Selection Operations + +##### `pto.vsel` + +- **syntax:** `%result = pto.vsel %src0, %src1, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Per-lane select based on mask. + +```c +for (int i = 0; i < N; i++) + dst[i] = mask[i] ? src0[i] : src1[i]; +``` + +**Example — Conditional assignment:** +```mlir +// dst = mask ? true_vals : false_vals +%result = pto.vsel %true_vals, %false_vals, %condition + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +- **inputs:** `%src0` is the true-path vector, `%src1` is the false-path vector, + and `%mask` selects between them. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** Source vectors and result MUST have matching + vector shapes and element types. `%mask` keeps the typed-mask granularity + that matches the selected vector family. + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Lane-select by index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** `%idx` must use integer elements. `%idx` + must have the same lane count as `%src`, and its integer element width must + match the bit width of `%src` element type. + +--- + +##### `pto.vselrv2` + +- **syntax:** `%result = pto.vselrv2 %src0, %src1 : !pto.vreg, !pto.vreg -> !pto.vreg` +- **semantics:** Variant select form with the same current two-vector operand shape. +- **inputs:** `%src0` and `%src1` are the source vectors. +- **outputs:** `%result` is the selected vector. +- **constraints and limitations:** This page records the surface shape only. + Lowering MUST preserve the exact A5 variant semantics selected for this form. + +--- + +#### Typical Usage + +```mlir +// Clamp negative values to zero (manual ReLU) +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%zero = pto.vbr %c0_f32 : f32 -> !pto.vreg<64xf32> +%neg_mask = pto.vcmps %input, %c0_f32, %all, "lt" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%clamped = pto.vsel %zero, %input, %neg_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Element-wise max via compare+select +%gt_mask = pto.vcmp %a, %b, %all, "gt" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +%max_ab = pto.vsel %a, %b, %gt_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Threshold filter +%above_thresh = pto.vcmps %scores, %threshold, %all, "ge" : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%filtered = pto.vsel %scores, %zero, %above_thresh : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +--- + +#### Compare + Select Pattern + +```mlir +// Softmax safe exp: exp(x - max) where x < max returns exp of negative +// but we want to clamp to avoid underflow + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask + +// 1. Compare against threshold +%too_small = pto.vcmps %x_minus_max, %min_exp_arg, %all, "lt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask + +// 2. Clamp values below threshold +%clamped = pto.vsel %min_exp_arg_vec, %x_minus_max, %too_small + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Safe exp +%exp_result = pto.vexp %clamped, %all : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + + + +### 12. Data Rearrangement + +> **Category:** In-register data movement and permutation +> **Pipeline:** PIPE_V (Vector Core) + +Operations that rearrange data within or between vector registers without memory access. + +#### Common Operand Model + +- `%lhs` / `%rhs` are source vector register values. +- `%src` is a single source vector register value. +- `%result` is the destination vector register value unless an op explicitly + returns multiple vectors. +- These families do not access UB directly; they only rearrange register + contents. + +--- + +#### Interleave / Deinterleave + +##### `pto.vintlv` + +- **syntax:** `%low, %high = pto.vintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Interleave elements from two sources. + +```c +// Interleave: merge even/odd elements from two sources +// low = {src0[0], src1[0], src0[1], src1[1], ...} +// high = {src0[N/2], src1[N/2], src0[N/2+1], src1[N/2+1], ...} +``` + +- **inputs:** `%lhs` and `%rhs` are the two source vectors. +- **outputs:** `%low` and `%high` are the two destination vectors. +- **constraints and limitations:** The two outputs form a paired interleave + result. The PTO micro Instruction representation exposes that pair as two SSA results, and the pair ordering MUST + be preserved. + +--- + +##### `pto.vdintlv` + +- **syntax:** `%low, %high = pto.vdintlv %lhs, %rhs : !pto.vreg, !pto.vreg -> !pto.vreg, !pto.vreg` +- **semantics:** Deinterleave elements into even/odd. + +```c +// Deinterleave: separate even/odd elements +// low = {src0[0], src0[2], src0[4], ...} // even +// high = {src0[1], src0[3], src0[5], ...} // odd +``` + +- **inputs:** `%lhs` and `%rhs` represent the interleaved source stream in the + current PTO micro Instruction representation. +- **outputs:** `%low` and `%high` are the separated destination vectors. +- **constraints and limitations:** The two outputs form the even/odd + deinterleave result pair, and their ordering MUST be preserved. + +--- + +#### Compress / Expand + +##### `pto.vsqz` + +- **syntax:** `%result = pto.vsqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Compress — pack active lanes to front. + +```c +int j = 0; +for (int i = 0; i < N; i++) + if (mask[i]) dst[j++] = src[i]; +while (j < N) dst[j++] = 0; +``` + +**Use case:** Sparse data compaction, filtering. + +- **inputs:** `%src` is the source vector and `%mask` selects which elements are + kept. +- **outputs:** `%result` is the compacted vector. +- **constraints and limitations:** This is a reduction-style compaction family. + Preserved element order MUST match source lane order. + +--- + +##### `pto.vusqz` + +- **syntax:** `%result = pto.vusqz %src, %mask : !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Generate per-lane prefix counts from the governing predicate. + +```c +dst[0] = 0; +for (int i = 1; i < N; i++) + dst[i] = mask[i - 1] ? (dst[i - 1] + 1) : dst[i - 1]; +``` + +- **inputs:** `%mask` is the governing predicate. The current PTO surface keeps + `%src` in the operand list for interface compatibility, but the observable + result semantics are determined by `%mask`. +- **outputs:** `%result[i]` equals the number of active lanes in `%mask[0:i)`, + with `%result[0] = 0`. +- **constraints and limitations:** `T` is currently limited to `si8`, `si16`, + or `si32`. This operation is a predicate-derived counting/rearrangement + primitive rather than a value-placement primitive. The final predicate lane + does not contribute to a later output lane because there is no `dst[N]`. + +--- + +--- + +##### `pto.vselr` + +- **syntax:** `%result = pto.vselr %src, %idx : !pto.vreg, !pto.vreg> -> !pto.vreg` +- **semantics:** Register lane-select with an explicit index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = src[idx[i]]; +``` + +- **inputs:** `%src` is the source vector. `%idx` is the lane-index vector. +- **outputs:** `%result` is the reordered vector. +- **constraints and limitations:** This page records the rearrangement use of + the family; the compare/select page documents the same name from the predicate + selection perspective. + +--- + +#### Pack / Unpack + +##### `pto.vpack` + +- **syntax:** `%result = pto.vpack %src, "PART" : !pto.vreg -> !pto.vreg<2NxT_narrow>` +- **semantics:** Narrow one wide vector and place the narrowed payload into the + selected half of the result. The other half is filled with zero. + +```c +// e.g., vreg<64xi32> → vreg<128xui16> +for (int i = 0; i < N; i++) + dst[i] = 0; + +if (part == LOWER) { + for (int i = 0; i < N; i++) + dst[i] = truncate(src[i]); +} else { // HIGHER + for (int i = 0; i < N; i++) + dst[N + i] = truncate(src[i]); +} +``` + +- **inputs:** `%src` is the wide source vector. `"LOWER"` and `"HIGHER"` + select whether the narrowed payload lands in the lower or upper half. +- **outputs:** `%result` is the packed narrow vector. +- **constraints and limitations:** Packing is a narrowing conversion with + truncation semantics. Current VPTO surface supports `i32/ui32 -> ui16` and + `i16/ui16 -> ui8`. + +--- + +##### `pto.vsunpack` + +- **syntax:** `%result = pto.vsunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Sign-extending unpack — narrow to wide (half). + +```c +// e.g., vreg<128xi16> → vreg<64xi32> (one half) +for (int i = 0; i < N/2; i++) + dst[i] = sign_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the sign-extending unpack family. + +--- + +##### `pto.vzunpack` + +- **syntax:** `%result = pto.vzunpack %src, %part : !pto.vreg, index -> !pto.vreg` +- **semantics:** Zero-extending unpack — narrow to wide (half). + +```c +for (int i = 0; i < N/2; i++) + dst[i] = zero_extend(src[part_offset + i]); +``` + +- **inputs:** `%src` is the packed narrow vector and `%part` selects which half + is unpacked. +- **outputs:** `%result` is the widened vector. +- **constraints and limitations:** This is the zero-extending unpack family. + +--- + +#### Typical Usage + +```mlir +// AoS → SoA conversion using deinterleave +%even, %odd = pto.vdintlv %interleaved0, %interleaved1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// Filter: keep only elements passing condition +%pass_mask = pto.vcmps %values, %threshold, %all, "gt" + : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.mask +%compacted = pto.vsqz %values, %pass_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Type narrowing via pack +%packed_i16 = pto.vpack %wide_i32, "LOWER" + : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +``` + +--- + +#### V2 Interleave Forms + +##### `pto.vintlvv2` + +- **syntax:** `%result = pto.vintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 interleave result. +- **outputs:** `%result` is the selected interleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + +##### `pto.vdintlvv2` + +- **syntax:** `%result = pto.vdintlvv2 %lhs, %rhs, "PART" : !pto.vreg, !pto.vreg -> !pto.vreg` +- **inputs:** `%lhs` and `%rhs` are source vectors and `PART` selects the + returned half of the V2 deinterleave result. +- **outputs:** `%result` is the selected deinterleave half. +- **constraints and limitations:** This op exposes only one half of the V2 + result in SSA form. + + + +### 13. DSA/SFU Ops + +> **Category:** Domain-specific accelerator and special function unit operations +> **Pipeline:** PIPE_V (Vector Core) / SFU + +Fused operations, special functions, and UB-to-UB operations that leverage hardware acceleration. + +#### Common Operand Model + +- `%input`, `%lhs`, `%rhs`, `%acc`, and `%alpha` are source SSA values whose + roles are called out per instruction. +- `%mask` is the predicate operand `Pg` when present. +- `%result` is the destination SSA value. +- This page mixes three different backend shapes: pure `vreg -> vreg` ops, + conversion/fusion ops, and UB-to-UB helpers. Each instruction section calls + out which storage model it uses. + +--- + +#### Fused Activation Ops (vreg→vreg) + +##### `pto.vlrelu` + +- **syntax:** `%result = pto.vlrelu %input, %alpha, %mask : !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Leaky ReLU with scalar alpha. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha * src[i]; +``` + +- **inputs:** `%input` is the activation vector, `%alpha` is the scalar slope, + and `%mask` selects active lanes. +- **outputs:** `%result` is the leaky-ReLU vector. +- **constraints and limitations:** Only `f16` and `f32` forms are currently + documented for `pto.vlrelu`. + +--- + +##### `pto.vprelu` + +- **syntax:** `%result = pto.vprelu %input, %alpha : !pto.vreg, !pto.vreg -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** Parametric ReLU with per-element alpha vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = (src[i] >= 0) ? src[i] : alpha[i] * src[i]; +``` + +- **inputs:** `%input` is the activation vector and `%alpha` is the per-element + slope vector. +- **outputs:** `%result` is the parametric-ReLU vector. +- **constraints and limitations:** Floating-point element types only on the + current A5 surface. + +--- + +##### `pto.vexpdif` + +- **syntax:** `%result = pto.vexpdif %input, %max, %mask, "EVEN|ODD" : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **A5 types:** input `f16` or `f32`, output `f32` +- **semantics:** Fused exp(x - max) for numerically stable softmax. + +```c +for (int i = 0; i < N; i++) + dst[i] = expf(src[i] - max[i]); +``` + +**Use case:** Softmax numerator computation with numerical stability. + +- **inputs:** `%input` is the source vector and `%max` is the broadcasted + subtraction term. `%part` selects `EVEN` or `ODD` for the + underlying hardware contract. +- **outputs:** `%result` is the fused `exp(input - max)` vector with `f32` + elements. +- **constraints and limitations:** Source vectors must be `f16` or `f32`, the + result vector must be `f32`, and source/result storage width must match. + +--- + +#### Fused Compute+Convert Ops + +##### `pto.vaxpy` + +- **syntax:** `%result = pto.vaxpy %src0, %src1, %alpha, %mask : !pto.vreg, !pto.vreg, T, !pto.mask -> !pto.vreg` +- **A5 types:** f16, f32 +- **semantics:** AXPY — scalar-vector multiply-add. + +```c +for (int i = 0; i < N; i++) + dst[i] = alpha * src0[i] + src1[i]; +``` + +- **inputs:** `%src0` is the scaled vector, `%src1` is the addend vector, + `%alpha` is the scalar multiplier, and `%mask` selects active lanes. +- **outputs:** `%result` is the fused AXPY result. +- **constraints and limitations:** Floating-point element types only on the + current documented surface. + +--- + +#### Extended Arithmetic + +##### `pto.vmull` + +- **syntax:** `%low, %high = pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- **A5 types:** i32/ui32 (native 32×32→64 widening multiply) +- **semantics:** Widening multiply with high/low results. + +```c +for (int i = 0; i < 64; i++) { + int64_t r = (int64_t)src0_i32[i] * (int64_t)src1_i32[i]; + dst_lo[i] = (int32_t)(r & 0xFFFFFFFF); + dst_hi[i] = (int32_t)(r >> 32); +} +``` + +- **inputs:** `%lhs` and `%rhs` are the source vectors and `%mask` selects + active lanes. +- **outputs:** `%low` and `%high` expose the widened-product low/high parts. +- **constraints and limitations:** The current documented A5 form is the native + widening 32x32->64 integer multiply family. + +--- + +##### `pto.vmula` + +- **syntax:** `%result = pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- **semantics:** Multiply-accumulate. + +```c +for (int i = 0; i < N; i++) + if (mask[i]) + dst[i] = acc[i] + lhs[i] * rhs[i]; +``` + +- **inputs:** `%acc` is the accumulator input, `%lhs` and `%rhs` are the + multiplicands, and `%mask` selects active lanes. +- **outputs:** `%result` is the multiply-accumulate result. +- **constraints and limitations:** `pto.vmula` is a fused multiply-accumulate + operation and is not always interchangeable with separate `vmul` plus `vadd`. + +--- + +#### Index Generation + +##### `pto.vci` + +- **syntax:** `%result = pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- **semantics:** Generate lane index vector. + +```c +for (int i = 0; i < N; i++) + dst[i] = base_index + i; +``` + +**Use case:** Generate indices for gather/scatter, argsort, etc. + +- **inputs:** `%index` is the scalar seed/base index. +- **outputs:** `%result` is the generated index vector. +- **constraints and limitations:** This page documents the arithmetic/indexing + use of the family; the conversion page also records the same opcode for + completeness. + +--- + +#### Sorting Operations + +##### `pto.vbitsort` + +- **syntax:** `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- **semantics:** Sort 32 region proposals by score and materialize sorted + proposal records into `%dest`. +- **inputs:** `%dest` is the UB destination buffer. `%src` is the UB score + buffer. `%indices` is the UB index buffer. `%repeat_times` is the repeat + count; each repeat processes the next adjacent group of 32 scores and 32 + indices. +- **outputs:** This op writes UB memory and returns no SSA value. Each output + record occupies 8 bytes: the upper 4 bytes hold the index and the lower + 4 bytes hold the score. For `f16` score forms, the score uses the lower + 2 bytes of that 4-byte score field and the upper 2 bytes are reserved. +- **constraints and limitations:** `%dest`, `%src`, and `%indices` MUST be + UB-backed pointers and SHOULD satisfy the backend alignment contract expected + by the A5 `VBS32` instruction. Scores are sorted in descending order, so the + highest score is written to the lowest destination address. Equal-score ties + preserve the earlier input proposal first. This is a UB helper, not a pure + `vreg -> vreg` op. + +--- + +##### `pto.vmrgsort4` + +- **syntax:** `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, i64, i64` +- **semantics:** Merge-sort 4 pre-sorted input vectors. +- **inputs:** `%dest` is the UB destination, `%src0..%src3` are the four + pre-sorted UB inputs, `%count` is the number of valid elements, and `%config` + is the operation control word. +- **outputs:** This op writes UB memory and returns no SSA value. +- **constraints and limitations:** Inputs MUST already be sorted according to + the sort order encoded by `%config`. + +--- + +#### Current Implementation Surface Summary + +- `pto.vmull %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg, !pto.vreg` +- `pto.vmula %acc, %lhs, %rhs, %mask : !pto.vreg, !pto.vreg, !pto.vreg, !pto.mask -> !pto.vreg` +- `pto.vci %index {order = "ORDER"} : integer -> !pto.vreg` +- `pto.vbitsort %dest, %src, %indices, %repeat_times : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, index` +- `pto.vmrgsort4 %dest, %src0, %src1, %src2, %src3, %count, %config : !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, !pto.ptr<...>, i64, i64` + +--- + +#### Typical Usage + +```mlir +// Softmax with fused expdiff +%max_broadcast = pto.vlds %ub_max[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> +%exp_stable = pto.vexpdif %logits, %max_broadcast, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU activation +%activated = pto.vlrelu %linear_out, %alpha_scalar, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Generate indices for argsort +%indices = pto.vci %c0 {order = "ASC"} : i32 -> !pto.vreg<64xi32> +``` + + + +### 14. Arith (Shared MLIR Dialect) + +> **Category:** Shared full scalar `arith` surface used around PTO ops +> **Dialect:** `arith` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/ArithOps/ + +The upstream MLIR `arith` dialect defines primitive arithmetic, comparison, select, and cast operations over signless integer, index, floating-point, and boolean-compatible scalar values. Within PTO micro Instruction code, the full scalar operation surface of `arith` is supported. These ops are used around PTO instructions to build constants, compute offsets and loop bounds, perform general scalar math, derive valid-shape metadata, and form predicates for `scf` control flow. + +These ops are part of the documented PTO micro Instruction surface, but they are not PTO ISA instructions. + +--- + +#### Role in PTO micro Instruction Code + +- materialize scalar constants used by PTO scalar operands and loop bounds +- compute scalar/index offsets for tensor views, partitioning, and dynamic shapes +- perform general scalar integer and floating-point math outside PTO vector/tile payload operations +- derive scalar predicates that guard `scf.if` or `scf.while` +- apply scalar casts, width changes, bitwise ops, and selects without introducing PTO-specific control ops + +Prefer PTO ops for vector or tile payload math. Use `arith` for scalar computation and bookkeeping that surrounds PTO regions. + +--- + +#### Supported Surface + +The documented PTO micro Instruction surface supports the full scalar operation surface of upstream `arith`. The upstream `arith` dialect reference remains authoritative for the exhaustive op-by-op syntax and semantics. The categories below summarize how that support is used in PTO micro Instruction code. + +| Category | Representative Ops | Typical Use in PTO micro Instruction Code | +|----------|--------------------|------------------| +| Constants | `arith.constant` | integer, floating-point, boolean, and `index` constants | +| Integer / Index Arithmetic | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.divui`, `arith.ceildivsi`, `arith.ceildivui`, `arith.floordivsi`, `arith.remsi`, `arith.remui` | offsets, bounds, chunk sizes, scalar math | +| Floating-Point Arithmetic | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.negf`, `arith.maximumf`, `arith.minimumf`, `arith.maxnumf`, `arith.minnumf` | scalar math around PTO regions | +| Bitwise / Shift Ops | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui` | flags, masks, packed scalar fields | +| Comparisons / Select | `arith.cmpi`, `arith.cmpf`, `arith.select`, `arith.maxsi`, `arith.minui` | predicates, clamps, scalar muxes | +| Casts / Width Changes | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, `arith.uitofp`, `arith.fptosi`, `arith.fptoui`, `arith.extf`, `arith.truncf`, `arith.bitcast` | ABI glue, dynamic-shape plumbing, scalar type adaptation | + +--- + +#### Current PTOAS Coverage + +- the current repository examples are still dominated by constants, casts, integer/index arithmetic, compares, and selects because those are the most common surrounding-scalar patterns in existing kernels +- backend-specific tests such as the PTO shared-dialect fixture visibly exercise only a representative subset of `arith` ops in a single path +- the documented PTO micro Instruction source-level contract is nevertheless the full scalar `arith` surface, not just the index-heavy subset that appears most often in current samples + +This section therefore uses representative categories and examples instead of pretending that the supported `arith` surface is limited to the currently most common sample patterns. + +--- + +#### Typical Patterns + +##### Scalar Setup + +```mlir +%c0 = arith.constant 0 : index +%c1 = arith.constant 1 : index +%scale = arith.constant 2.0 : f32 +``` + +##### Dynamic Offset Computation + +```mlir +%vrow = arith.index_cast %valid_row : i32 to index +%chunk = arith.muli %row, %c32 : index +%tail = arith.subi %limit, %chunk : index +``` + +##### General Scalar Arithmetic + +```mlir +%sum_i = arith.addi %lhs_i, %rhs_i : i32 +%sum_f = arith.addf %lhs_f, %rhs_f : f32 +%prod_f = arith.mulf %sum_f, %scale : f32 +``` + +##### Scalar Predicate and Selection + +```mlir +%is_first = arith.cmpi eq, %i, %c0 : index +%active = arith.select %is_first, %first_count, %steady_count : index +``` + +##### Bitwise / Width Adaptation + +```mlir +%flags = arith.andi %flags0, %flags1 : i32 +%wide = arith.extui %flags : i32 to i64 +%shrunk = arith.trunci %wide : i64 to i16 +``` + +--- + +#### Authoring Guidance + +- treat upstream `arith` scalar semantics as the source of truth for supported scalar ops +- keep `arith` values scalar or `index` typed; do not use `arith` as a substitute for PTO vector/tile compute +- use `arith` for general scalar math, scalar comparisons, bitwise operations, and casts around PTO regions, not just for `index` arithmetic +- use `arith.cmpi` / `arith.cmpf` plus `scf.if` / `scf.while` for control flow, not ad hoc control intrinsics +- prefer `arith.index_cast` / `arith.index_castui` at ABI or shape boundaries where `index` is required, but do not read that as a restriction on the rest of scalar `arith` + + + +### 15. SCF (Shared MLIR Dialect) + +> **Category:** Shared structured control flow around PTO regions +> **Dialect:** `scf` +> **Upstream Reference:** https://mlir.llvm.org/docs/Dialects/SCFDialect/ + +The upstream MLIR `scf` dialect defines structured control flow operations with regions, including counted loops, conditional regions, and while-style loops. In PTO micro Instruction code, `scf` is the control shell around PTO ops: it sequences DMA, vector, and tile operations; carries scalar or tile state across iterations; and preserves analyzable control flow for PTO-specific analyses and lowerings. + +These ops are part of the documented PTO micro Instruction surface, but they are shared MLIR control-flow constructs rather than PTO ISA instructions. + +--- + +#### Supported Ops + +| Op | Role in PTO micro Instruction Code | Notes | +|----|------------------------|-------| +| `scf.for` | counted loops and loop-carried values | common structured counted loop form | +| `scf.if` | structured conditional execution | may yield values or act as side-effect-only branch | +| `scf.yield` | region terminator for `for` / `if` / `while` bodies | carries loop or branch results | +| `scf.while` | break-like or stateful loops | useful for source-level structured control | +| `scf.condition` | loop-continue / loop-exit decision for `scf.while` | placed in the "before" region | + +Ops such as `scf.execute_region`, `scf.forall`, or `scf.index_switch` are not part of the documented shared-dialect portion of the PTO micro Instruction surface here. + +--- + +#### Current PTOAS Coverage + +- `scf.for`, `scf.if`, and `scf.yield` are directly exercised in the shared-dialect PTO fixture and appear widely across PTO samples +- PTO synchronization and memory analyses explicitly reason about `scf.for`, `scf.if`, `scf.yield`, and `scf.while` +- `scf.while` and `scf.condition` appear in control-flow samples and are handled in PTO-to-EmitC control-flow lowering, but they are less broadly exercised than `for` / `if` on all backend paths + +--- + +#### Typical Patterns + +##### Counted Loop + +```mlir +scf.for %i = %c0 to %c4 step %c1 { + %offset = arith.muli %i, %c32 : index + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %v = pto.vlds %ub[%offset] : !pto.ptr -> !pto.vreg<64xf32> + %abs = pto.vabs %v, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %abs, %ub_out[%offset], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +##### Counted Loop with Loop-Carried State + +```mlir +%final_alive = scf.for %i = %c0 to %c4 step %c1 + iter_args(%alive = %true) -> (i1) { + %break_now = arith.cmpi eq, %i, %c2 : index + %next_alive = scf.if %break_now -> (i1) { + scf.yield %false : i1 + } else { + scf.yield %alive : i1 + } + scf.yield %next_alive : i1 +} +``` + +##### Structured Conditional Region + +```mlir +%is_mode_a = arith.cmpi eq, %mode, %c0_i32 : i32 +scf.if %is_mode_a { + pto.tmuls ins(%data, %scale_a : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} else { + pto.tadds ins(%data, %bias_b : !pto.tile_buf<...>, f32) outs(%data : !pto.tile_buf<...>) +} +``` + +##### While-Style Break Loop + +```mlir +%final:2 = scf.while (%i = %c0, %alive = %true) : (index, i1) -> (index, i1) { + %lt = arith.cmpi slt, %i, %c4 : index + %go = arith.andi %lt, %alive : i1 + scf.condition(%go) %i, %alive : index, i1 +} do { +^bb0(%i2: index, %alive2: i1): + %next_i = arith.addi %i2, %c1 : index + scf.yield %next_i, %alive2 : index, i1 +} +``` + +--- + +#### Authoring Guidance + +- use `scf.for` for regular counted loops and loop-carried scalar/tile state +- use `scf.if` for structured branching around PTO regions instead of inventing PTO-specific branch ops +- keep region results explicit with `scf.yield`; this is important for PTO analyses that track carried buffers and aliasing +- use `scf.while` only when a counted loop cannot express the control cleanly; `scf.for` remains the more common and better-exercised form in the current repository +- build branch predicates and loop conditions with `arith` ops, not PTO vector masks, unless the control decision truly comes from a scalarized value + +## Supported Data Types + +| Type | Bits | vreg Lanes | Description | +|------|------|-----------|-------------| +| `i8` / `si8` / `ui8` | 8 | 256 | Signless/signed/unsigned 8-bit integer | +| `i16` / `si16` / `ui16` | 16 | 128 | Signless/signed/unsigned 16-bit integer | +| `f16` | 16 | 128 | IEEE 754 half precision | +| `bf16` | 16 | 128 | Brain floating point | +| `i32` / `si32` / `ui32` | 32 | 64 | Signless/signed/unsigned 32-bit integer | +| `f32` | 32 | 64 | IEEE 754 single precision | +| `i64` / `si64` / `ui64` | 64 | 32 | Signless/signed/unsigned 64-bit integer | + +--- + +## Common Patterns + +### Softmax (Numerically Stable) + +```mlir +// 1. Find max +%max_vec = pto.vcmax %logits, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %max_vec, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%max_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 2. exp(x - max) using fused op +%exp = pto.vexpdif %logits, %max_bc, %mask, "ODD" : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// 3. Sum +%sum = pto.vcadd %exp, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum, %ub_tmp[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +%sum_bc = pto.vlds %ub_tmp[%c0] {dist = "BRC"} : !pto.ptr -> !pto.vreg<64xf32> + +// 4. Divide +%softmax = pto.vdiv %exp, %sum_bc, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +``` + +### ReLU Variants + +```mlir +// Standard ReLU +%relu = pto.vrelu %input, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Leaky ReLU (scalar alpha) +%lrelu = pto.vlrelu %input, %alpha, %mask : !pto.vreg<64xf32>, f32, !pto.mask -> !pto.vreg<64xf32> + +// Parametric ReLU (per-element alpha) +%prelu = pto.vprelu %input, %alpha_vec : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + +``` + +### Data Layout Conversion + +```mlir +// AoS → SoA (deinterleave) +%x, %y = pto.vldsx2 %ub_xy[%offset], "DINTLV" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// SoA → AoS (interleave) +pto.vstsx2 %x, %y, %ub_xy[%offset], "INTLV", %all_mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, !pto.mask +``` + +--- + +--- + +## Quick Reference by Category + +### Memory Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| GM→UB DMA | 2 | `pto.copy_gm_to_ubuf` | +| UB→GM DMA | 2 | `pto.copy_ubuf_to_gm` | +| UB→UB Copy | 2 | `pto.copy_ubuf_to_ubuf` | +| Contiguous Load | 3 | `pto.vlds` with `NORM` dist | +| Broadcast Load | 3 | `pto.vlds` with `BRC` family dist | +| Gather | 3 | `pto.vgather2`, `pto.vgatherb` | +| Contiguous Store | 3 | `pto.vsts` with `NORM` dist | +| Scatter | 3 | `pto.vscatter` | + +### Compute Operations + +| Operation | Group | Description | +|-----------|-------|-------------| +| Element-wise Arithmetic | 6, 7 | `pto.vadd`, `pto.vmul`, `pto.vabs`, etc. | +| Scalar Operations | 8 | `pto.vadds`, `pto.vmuls`, etc. | +| Transcendental | 6 | `pto.vexp`, `pto.vln`, `pto.vsqrt`, etc. | +| Reduction | 10 | `pto.vcadd`, `pto.vcmax`, `pto.vcmin` | +| Comparison | 11 | `pto.vcmp`, `pto.vcmps` | +| Selection | 11 | `pto.vsel`, `pto.vselr` | + +### Type & Data Manipulation + +| Operation | Group | Description | +|-----------|-------|-------------| +| Type Conversion | 9 | `pto.vcvt` | +| Interleave/Deinterleave | 12 | `pto.vintlv`, `pto.vdintlv` | +| Interleave/Deinterleave (not A5) | 12 | `pto.vintlvv2`, `pto.vdintlvv2` | + +### Synchronization + +| Operation | Group | Description | +|-----------|-------|-------------| +| Intra-core Sync | 1 | `pto.set_flag`, `pto.wait_flag` | +| Pipeline Buffer Sync | 1 | `pto.get_buf`, `pto.rls_buf` | + +### Scalar & Control Operations + +Group 14 covers the full scalar `arith` surface. The rows below list common PTO micro Instruction patterns rather than an exhaustive partition of `arith` ops. + +| Operation | Group | Description | +|-----------|-------|-------------| +| Scalar Constants | 14 | `arith.constant` | +| Scalar Integer / Index Arithmetic | 14 | `arith.addi`, `arith.subi`, `arith.muli`, `arith.divsi`, `arith.remui`, `arith.ceildivsi`, etc. | +| Scalar Floating-Point Arithmetic | 14 | `arith.addf`, `arith.subf`, `arith.mulf`, `arith.divf`, `arith.maximumf`, etc. | +| Scalar Compare & Select | 14 | `arith.cmpi`, `arith.cmpf`, `arith.select` | +| Scalar Casts / Width Changes | 14 | `arith.index_cast`, `arith.index_castui`, `arith.extsi`, `arith.extui`, `arith.trunci`, `arith.sitofp`, etc. | +| Scalar Bitwise / Shift Ops | 14 | `arith.andi`, `arith.ori`, `arith.xori`, `arith.shli`, `arith.shrsi`, `arith.shrui`, etc. | +| Counted Loops | 15 | `scf.for` | +| Conditional Regions | 15 | `scf.if`, `scf.yield` | +| Break-like Structured Loops | 15 | `scf.while`, `scf.condition`, `scf.yield` | diff --git a/tilelang-dsl/examples/README.md b/tilelang-dsl/examples/README.md new file mode 100644 index 000000000..970b382d5 --- /dev/null +++ b/tilelang-dsl/examples/README.md @@ -0,0 +1,38 @@ +TileLang DSL examples live here. + +Examples in this subtree should import `tilelang_dsl` as their package +entrypoint once the package wiring is added. + +Current examples: +- `v1_emit_mlir_demo.py`: minimal descriptor/materialization demo +- `v1_elementwise_tail_demo.py`: guide-aligned elementwise authoring demo that + covers DMA, explicit `strict_vecscope`, dynamic loop bound, and typed tail + mask lowering +- `v1_template_slot_multiop_demo.py`: shared kernel-body demo for + `tadd`/`tsub`/`tmul`/`tdiv` using `ops=[...]`, `templates={...}`, and + `pto.tpl("core", ...)` +- `v1_tadd_implicit_vecscope_demo.py`: advanced-mode flattened `TADD` example + with implicit `pto.vecscope` inference, dynamic Tile `valid_shape`, generic + dtype selection, partial-dynamic `valid_shape` modes, and `vlds`/`vsts` + tile indexing sugar +- `v1_tbinop_2d_nopostupdate_demo.py`: a representative TileLang DSL v1 + expansion of `pto::TBinOps_2D_NoPostUpdate` using `vadd` +- `v1_verify_smoke.py`: minimal verify smoke that is expected to pass the repo + `ptoas --pto-backend=vpto` legality path + +Typical usage from the repository root: + +```bash +python3 tilelang-dsl/examples/v1_emit_mlir_demo.py +python3 tilelang-dsl/examples/v1_emit_mlir_demo.py /tmp/tilelang_demo.mlir +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_elementwise_tail_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_template_slot_multiop_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_template_slot_multiop_demo.py tsub +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_template_slot_multiop_demo.py tmul f16 +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py f16 +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py f16 rows +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py f16 cols +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py +PYTHONPATH=$PWD/tilelang-dsl/python python3 tilelang-dsl/examples/v1_verify_smoke.py +``` diff --git a/tilelang-dsl/examples/tadd_demo.py b/tilelang-dsl/examples/tadd_demo.py new file mode 100644 index 000000000..db6c32d86 --- /dev/null +++ b/tilelang-dsl/examples/tadd_demo.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL v1 demo: pto.tadd (element-wise add) using Tile parameters. + +This example intentionally enables `advanced=True` because it demonstrates +explicit `strict_vecscope`. Stable kernels can rely on inferred `pto.vecscope` +and `tile[row, col:]` indexing sugar without opting into advanced mode. +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + try: + import tilelang_dsl as pto + return pto + except ModuleNotFoundError: + repo_root = Path(__file__).resolve().parents[2] + sys.path.insert(0, str(repo_root / "python")) + import tilelang_dsl as pto + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="pto.tadd", + dtypes=[(pto.f32, pto.f32, pto.f32)], + name="template_tadd", + advanced=True, +) +def template_tadd(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + # v1 strict_vecscope: all referenced values must be passed in explicitly, + # and only scalar offsets (not 2D subscripts) are supported for vlds/vsts. + with pto.strict_vecscope(src0, src1, dst, 0, 256, 64) as ( + a, b, c, lb, ub, step + ): + for j in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec_a = pto.vlds(a, j) + vec_b = pto.vlds(b, j) + result = pto.vadd(vec_a, vec_b, mask) + pto.vsts(result, c, j, mask) + + +def main(argv: list[str]) -> int: + specialized = template_tadd.specialize( + src0=pto.TileSpecialization( + shape=(16, 64), + memory_space=pto.MemorySpace.UB, + ), + src1=pto.TileSpecialization( + shape=(16, 64), + memory_space=pto.MemorySpace.UB, + ), + dst=pto.TileSpecialization( + shape=(16, 64), + memory_space=pto.MemorySpace.UB, + ), + ) + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_elementwise_tail_demo.py b/tilelang-dsl/examples/v1_elementwise_tail_demo.py new file mode 100644 index 000000000..eb15b0db2 --- /dev/null +++ b/tilelang-dsl/examples/v1_elementwise_tail_demo.py @@ -0,0 +1,84 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Guide-aligned TileLang DSL v1 elementwise authoring demo.""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.i32)], + name="tilelang_v1_elementwise_tail_demo", + advanced=True, +) +def kernel(inp: pto.TensorView, out: pto.TensorView, tile: pto.Tile, remaining: pto.i32): + rows = inp.shape[0] + pto.dma_load(inp[0:rows, 0:16], tile) + with pto.strict_vecscope(tile, tile, remaining, 0, rows, 64) as ( + src, + dst, + rem, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask, rem = pto.make_mask(pto.f32, rem) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + pto.dma_store(tile, out[0:rows, 0:16]) + return None + + +def build_specialized_kernel(): + return kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + +def main(argv) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_emit_mlir_demo.py b/tilelang-dsl/examples/v1_emit_mlir_demo.py new file mode 100644 index 000000000..78c43578b --- /dev/null +++ b/tilelang-dsl/examples/v1_emit_mlir_demo.py @@ -0,0 +1,69 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Minimal TileLang DSL v1 demo that materializes a kernel into MLIR.""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise_with_tile", + dtypes=[(pto.f32, pto.f16, pto.i32)], + name="tilelang_v1_demo_kernel", +) +def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + +def build_specialized_kernel(): + return kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping({"layout": "row_major"}), + ) + ) + + +def main(argv: list[str]) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py b/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py new file mode 100644 index 000000000..ae81135ac --- /dev/null +++ b/tilelang-dsl/examples/v1_tadd_implicit_vecscope_demo.py @@ -0,0 +1,161 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Flattened TileLang DSL advanced-mode version of A5 `TADD_IMPL`. + +This example mirrors the user-facing `TADD_IMPL -> TAdd -> BinaryInstr -> +TBinOps_2D_NoPostUpdate` flow from `pto/npu/a5/TAdd.hpp`, but spells the final +2D row-major vector body directly in Python: + +- top-level interface uses `dst, src0, src1` Tile parameters like `TADD` +- Tile specializations keep a static physical tile shape while exposing a + dynamic `valid_shape` input at materialization time; the demo can model + fully dynamic or partially dynamic `(valid_rows, valid_cols)` profiles +- the kernel surface is dtype-polymorphic and can be selected for any supported + vector dtype with `pto.select_kernel(...)` +- implicit `pto.vecscope` inference and tile indexing sugar cover the base + vector authoring path; this demo also keeps `advanced=True` enabled because it + lives alongside the matcher/advanced-surface examples +- `pto.vlds(tile[row, col:])` / `pto.vsts(vec, tile[row, col:], mask)` use + tile indexing sugar instead of manual offset arithmetic +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() +T = pto.TypeVar("T") +SUPPORTED_DTYPES = { + "i8": pto.i8, + "i16": pto.i16, + "i32": pto.i32, + "f16": pto.f16, + "bf16": pto.bf16, + "f32": pto.f32, +} +VALID_SHAPE_MODES = ("both", "rows", "cols", "static") +TILE_SHAPE = (8, 64) + + +@pto.vkernel( + op="tadd", + dtypes=[(T, T, T)], + advanced=True, + name="tilelang_advanced_tadd_demo", +) +def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + # Flattened equivalent of the TAddCheck/TADD_IMPL parameter plumbing. + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + +def _resolve_valid_shape_profile(mode: str) -> tuple[object, object]: + rows, cols = TILE_SHAPE + if mode == "both": + return ("valid_rows", "valid_cols") + if mode == "rows": + return ("valid_rows", cols) + if mode == "cols": + return (rows, "valid_cols") + if mode == "static": + return TILE_SHAPE + raise ValueError(f"unsupported valid_shape mode '{mode}'") + + +def build_specialized_kernel(dtype=pto.f32, valid_shape_mode="both"): + selected = pto.select_kernel("a5", "tadd", (dtype, dtype, dtype)) + valid_shape = _resolve_valid_shape_profile(valid_shape_mode) + return selected.specialize( + src0=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=valid_shape, + ), + src1=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=valid_shape, + ), + dst=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=valid_shape, + ), + ) + + +def _parse_cli(argv): + if len(argv) > 4: + return None, None, None + + dtype = pto.f32 + valid_shape_mode = "both" + output_path = None + args = list(argv[1:]) + for arg in args: + if arg in SUPPORTED_DTYPES: + dtype = SUPPORTED_DTYPES[arg] + continue + if arg in VALID_SHAPE_MODES: + valid_shape_mode = arg + continue + if output_path is None: + output_path = Path(arg) + continue + return None, None, None + return dtype, valid_shape_mode, output_path + + +def main(argv) -> int: + dtype, valid_shape_mode, output_path = _parse_cli(argv) + if dtype is None: + supported = ", ".join(SUPPORTED_DTYPES) + valid_shape_modes = ", ".join(VALID_SHAPE_MODES) + print( + f"usage: {Path(argv[0]).name} [{supported}] [{valid_shape_modes}] [output.mlir]", + file=sys.stderr, + ) + return 2 + specialized = build_specialized_kernel(dtype=dtype, valid_shape_mode=valid_shape_mode) + + if output_path is not None: + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py b/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py new file mode 100644 index 000000000..6b2356d28 --- /dev/null +++ b/tilelang-dsl/examples/v1_tbinop_2d_nopostupdate_demo.py @@ -0,0 +1,133 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Representative TileLang DSL v1 form of `TBinOps_2D_NoPostUpdate`. + +This example mirrors the key structure from `pto::TBinOps_2D_NoPostUpdate`: +- two source UB tiles and one destination UB tile +- row-major 2D traversal +- explicit non-post-update absolute offsets: `row * row_stride + lane` +- binary vector op lowered as `pto.vadd` + +The TileLang DSL surface does not expose the C++ helper template directly, so +this example spells out the row/repeat loops and tail mask construction in the +authored Python kernel. +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32, pto.f32, pto.f32, pto.f32, pto.f32)], + name="tilelang_v1_tbinop_2d_nopostupdate_demo", + advanced=True, +) +def kernel( + lhs_gm: pto.TensorView, + rhs_gm: pto.TensorView, + out_gm: pto.TensorView, + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, +): + rows = lhs_gm.shape[0] + cols = lhs_gm.shape[1] + row_stride = lhs_tile.shape[1] + + pto.dma_load(lhs_gm[0:rows, 0:cols], lhs_tile) + pto.dma_load(rhs_gm[0:rows, 0:cols], rhs_tile) + + with pto.strict_vecscope( + lhs_tile, + rhs_tile, + dst_tile, + rows, + cols, + row_stride, + 0, + rows, + 1, + ) as ( + lhs, + rhs, + dst, + valid_rows, + valid_cols, + stride, + row_lb, + row_ub, + row_step, + ): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + offset = row * stride + lane + mask, next_remaining = pto.make_mask(pto.f32, valid_cols - lane) + lhs_vec = pto.vlds(lhs, offset) + rhs_vec = pto.vlds(rhs, offset) + summed = pto.vadd(lhs_vec, rhs_vec, mask) + pto.vsts(summed, dst, offset, mask) + + pto.dma_store(dst_tile, out_gm[0:rows, 0:cols]) + return None + + +def build_specialized_kernel(): + return kernel.specialize( + lhs_tile=pto.TileSpecialization( + shape=(8, 64), + memory_space=pto.MemorySpace.UB, + ), + rhs_tile=pto.TileSpecialization( + shape=(8, 64), + memory_space=pto.MemorySpace.UB, + ), + dst_tile=pto.TileSpecialization( + shape=(8, 64), + memory_space=pto.MemorySpace.UB, + ), + ) + + +def main(argv) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_template_slot_multiop_demo.py b/tilelang-dsl/examples/v1_template_slot_multiop_demo.py new file mode 100644 index 000000000..ab2a555fb --- /dev/null +++ b/tilelang-dsl/examples/v1_template_slot_multiop_demo.py @@ -0,0 +1,154 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Shared-kernel-body TileLang DSL v1 demo using template slots. + +This example shows the recommended authoring pattern for a small family of +binary elementwise ops that share the same traversal, mask, load, and store +structure: + +- one `@pto.vkernel` descriptor matches multiple concrete ops via `ops=[...]` +- `templates={"core": ...}` maps each concrete op to its real `pto.*` vector op +- the kernel body uses a single `pto.tpl("core", ...)` placeholder call +- `pto.select_kernel(...)` binds the concrete op before materialization +""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() +T = pto.TypeVar("T") +SUPPORTED_DTYPES = { + "i8": pto.i8, + "i16": pto.i16, + "i32": pto.i32, + "f16": pto.f16, + "bf16": pto.bf16, + "f32": pto.f32, +} +SUPPORTED_OPS = ( + "tadd", + "tsub", + "tmul", + "tdiv", +) +TILE_SHAPE = (8, 64) + + +@pto.vkernel( + ops=list(SUPPORTED_OPS), + dtypes=[(T, T, T)], + advanced=True, + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + "tmul": "vmul", + "tdiv": "vdiv", + } + }, + name="tilelang_template_slot_multiop_demo", +) +def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, dst[row, col:], mask) + return None + + +def build_specialized_kernel(op_name="tadd", dtype=pto.f32): + if op_name not in SUPPORTED_OPS: + raise ValueError(f"unsupported op '{op_name}'") + selected = pto.select_kernel("a5", op_name, (dtype, dtype, dtype)) + return selected.specialize( + src0=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + src1=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + dst=pto.TileSpecialization( + shape=TILE_SHAPE, + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + ) + + +def _parse_cli(argv): + if len(argv) > 4: + return None, None, None + + op_name = "tadd" + dtype = pto.f32 + output_path = None + for arg in argv[1:]: + if arg in SUPPORTED_OPS: + op_name = arg + continue + if arg in SUPPORTED_DTYPES: + dtype = SUPPORTED_DTYPES[arg] + continue + if output_path is None: + output_path = Path(arg) + continue + return None, None, None + return op_name, dtype, output_path + + +def main(argv) -> int: + op_name, dtype, output_path = _parse_cli(argv) + if op_name is None: + supported_ops = ", ".join(SUPPORTED_OPS) + supported_dtypes = ", ".join(SUPPORTED_DTYPES) + print( + f"usage: {Path(argv[0]).name} [{supported_ops}] [{supported_dtypes}] [output.mlir]", + file=sys.stderr, + ) + return 2 + + specialized = build_specialized_kernel(op_name=op_name, dtype=dtype) + + if output_path is not None: + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + print(specialized.mlir_text()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/examples/v1_verify_smoke.py b/tilelang-dsl/examples/v1_verify_smoke.py new file mode 100644 index 000000000..bf770bce1 --- /dev/null +++ b/tilelang-dsl/examples/v1_verify_smoke.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Minimal TileLang DSL v1 verify smoke for the repo PTOAS legality path.""" + +import sys +from pathlib import Path + + +def _import_tilelang_dsl(): + repo_root = Path(__file__).resolve().parents[2] + candidates = ( + repo_root / "tilelang-dsl" / "python", + repo_root / "build" / "python", + ) + for candidate in reversed(candidates): + if candidate.is_dir(): + sys.path.insert(0, str(candidate)) + import tilelang_dsl as pto + + return pto + + +pto = _import_tilelang_dsl() + + +@pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32)], + name="tilelang_v1_verify_smoke", +) +def kernel(inp: pto.TensorView, tile: pto.Tile): + return None + + +def build_specialized_kernel(): + return kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + +def main(argv) -> int: + specialized = build_specialized_kernel() + + if len(argv) > 2: + print(f"usage: {Path(argv[0]).name} [output.mlir]", file=sys.stderr) + return 2 + + if len(argv) == 2: + output_path = Path(argv[1]) + specialized.emit(output_path) + print(f"wrote MLIR to {output_path}") + return 0 + + result = specialized.verify() + print(f"status={result.status}") + print(f"available={result.available}") + print(f"passed={result.passed}") + if result.command is not None: + print("command=" + " ".join(result.command)) + if result.message: + print(f"message={result.message}") + return 0 if result else 1 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/tilelang-dsl/python/README.md b/tilelang-dsl/python/README.md new file mode 100644 index 000000000..39272be3d --- /dev/null +++ b/tilelang-dsl/python/README.md @@ -0,0 +1,3 @@ +This directory hosts the TileLang DSL Python package sources. + +The package root is `tilelang_dsl/`. diff --git a/tilelang-dsl/python/tilelang_dsl/__init__.py b/tilelang-dsl/python/tilelang_dsl/__init__.py new file mode 100644 index 000000000..05510dcf4 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/__init__.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""TileLang DSL v1 package.""" + +from .kernel import ( + BoundKernelParameter, + InlineProcDescriptor, + KernelRegistry, + KernelSelectionCandidateMetadata, + KernelSelectionReport, + MaterializedMLIRModule, + TileLangFrontendError, + VKernelDescriptor, + ckernel, + inline_proc, + select_kernel, + vkernel, +) +from .types import ( + AlignType, + AnyFloat, + AnyInt, + AnyMask, + AnyType, + BarrierType, + BLayout, + DeinterleaveDist, + EVENT, + InterleaveDist, + PIPE, + Event, + MaskType, + MemorySpace, + MaskPattern, + CmpMode, + FractalMode, + PAT, + PadMode, + PredicatePart, + PositionMode, + OrderMode, + PadValue, + VcvtPartMode, + VcvtRoundMode, + VcvtSatMode, + VLoadDist, + PointerType, + PostUpdateMode, + Pipe, + PredicateDist, + VStoreDist, + ScalarType, + SLayout, + TensorView, + PartitionTensorView, + Tile, + TileConfig, + TileSpecialization, + TypeVar, + TypeVariable, + VectorType, + VRegType, + WildcardType, + bf16, + constexpr, + bytewidth, + elements_per_vreg, + f16, + f32, + get_op_attr, + get_lanes, + i1, + i8, + si8, + ui8, + i16, + si16, + ui16, + i32, + si32, + ui32, + i64, + si64, + ui64, + mask_b8, + mask_b16, + mask_b32, + ptr, + align, + vector, + vreg, +) + +__all__ = [ + "BoundKernelParameter", + "InlineProcDescriptor", + "KernelRegistry", + "KernelSelectionCandidateMetadata", + "KernelSelectionReport", + "MaterializedMLIRModule", + "TileLangFrontendError", + "VKernelDescriptor", + "ckernel", + "inline_proc", + "select_kernel", + "vkernel", + "ScalarType", + "WildcardType", + "TypeVariable", + "TypeVar", + "TensorView", + "PartitionTensorView", + "Tile", + "PointerType", + "VectorType", + "VRegType", + "MaskType", + "AlignType", + "ptr", + "vector", + "vreg", + "align", + "MemorySpace", + "Pipe", + "Event", + "PIPE", + "EVENT", + "MaskPattern", + "PredicateDist", + "VLoadDist", + "VStoreDist", + "PredicatePart", + "CmpMode", + "PAT", + "BarrierType", + "BLayout", + "DeinterleaveDist", + "InterleaveDist", + "PadMode", + "FractalMode", + "PadValue", + "PositionMode", + "OrderMode", + "SLayout", + "VcvtRoundMode", + "VcvtSatMode", + "VcvtPartMode", + "PostUpdateMode", + "TileConfig", + "TileSpecialization", + "i1", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + "f16", + "bf16", + "f32", + "AnyFloat", + "AnyInt", + "AnyType", + "AnyMask", + "mask_b8", + "mask_b16", + "mask_b32", + "constexpr", + "get_op_attr", + "bytewidth", + "get_lanes", + "elements_per_vreg", +] diff --git a/tilelang-dsl/python/tilelang_dsl/expand_helper.py b/tilelang-dsl/python/tilelang_dsl/expand_helper.py new file mode 100644 index 000000000..a43f2f3a3 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/expand_helper.py @@ -0,0 +1,553 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""CLI helper invoked by ExpandTileOp to instantiate a tilelang DSL template. + +Usage: + python3 -m tilelang_dsl.expand_helper \ + --template-dir /path/to/templates \ + --target a5 \ + --op pto.tadd \ + --dtype f32 \ + --shape 16,64 \ + --memory-space ub + +Scans --template-dir for .py files, finds a @vkernel whose `op` matches, +specializes every Tile parameter with the given shape/memory_space, and +prints the materialized MLIR module to stdout. +""" + +from __future__ import annotations + +import argparse +from contextlib import contextmanager +import importlib +import importlib.util +import json +import sys +from pathlib import Path + +from .kernel import ( + KernelRegistry, + VKernelDescriptor, + _match_descriptor_dtype_signature, + select_kernel, +) +from .types import MemorySpace, ScalarType, TileConfig, TileSpecialization + + +_DTYPE_MAP: dict[str, ScalarType] = {} + + +def _populate_dtype_map() -> None: + from . import types as _t + + for name in ( + "f16", + "bf16", + "f32", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + ): + obj = getattr(_t, name, None) + if isinstance(obj, ScalarType): + _DTYPE_MAP[name] = obj + + +_populate_dtype_map() + +_MEMSPACE_MAP = { + "ub": MemorySpace.UB, + "gm": MemorySpace.GM, + "mat": MemorySpace.MAT, + "left": MemorySpace.LEFT, + "right": MemorySpace.RIGHT, + "acc": MemorySpace.ACC, + "bias": MemorySpace.BIAS, +} + + +def _find_descriptors(module) -> list[VKernelDescriptor]: + """Return all VKernelDescriptor instances found as module-level attributes.""" + result = [] + for attr_name in dir(module): + obj = getattr(module, attr_name, None) + if isinstance(obj, VKernelDescriptor): + result.append(obj) + return result + + +@contextmanager +def _template_import_context(template_dir: Path): + """Temporarily expose a template directory and package parent to imports.""" + import_roots: list[str] = [] + for root in (template_dir, template_dir.parent): + root_text = str(root) + if root_text not in import_roots: + import_roots.append(root_text) + + added_roots: list[str] = [] + for root_text in reversed(import_roots): + if root_text in sys.path: + continue + sys.path.insert(0, root_text) + added_roots.append(root_text) + + try: + yield + finally: + for root_text in added_roots: + try: + sys.path.remove(root_text) + except ValueError: + pass + + +def _import_py_file(path: Path): + """Import a .py file as a module and return it.""" + importlib.invalidate_caches() + module_name = f"_tl_template_{path.stem}" + spec = importlib.util.spec_from_file_location(module_name, str(path)) + if spec is None or spec.loader is None: + return None + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + try: + spec.loader.exec_module(mod) + except Exception as exc: + sys.modules.pop(module_name, None) + print(f"expand_helper: warning: failed to import {path}: {exc}", file=sys.stderr) + return None + return mod + + +def _bind_descriptor_for_query( + descriptor: VKernelDescriptor, + target: str, + op_name: str, + operand_types: tuple[ScalarType, ...], +) -> VKernelDescriptor | None: + if descriptor.target != target or op_name not in descriptor.match_ops: + return None + op_bound = descriptor._bind_selected_op(op_name) + matched_signature = _match_descriptor_dtype_signature(op_bound, operand_types) + if matched_signature is None: + return None + if op_bound._selected_dtype_signature == matched_signature: + return op_bound + return op_bound._bind_selected_dtype_signature(matched_signature) + + +def _match_descriptor( + descriptors: list[VKernelDescriptor], + op_name: str, + operand_types: tuple[ScalarType, ...], +) -> VKernelDescriptor | None: + """Legacy helper: find and bind the first descriptor matching (op, dtype).""" + for desc in descriptors: + bound = _bind_descriptor_for_query(desc, "a5", op_name, operand_types) + if bound is not None: + return bound + return None + + +def _parse_optional_int_sequence( + values: list[object], + *, + field_name: str, + index: int, +) -> tuple[int | None, ...]: + parsed: list[int | None] = [] + for dim in values: + if dim is None: + parsed.append(None) + continue + try: + parsed.append(int(dim)) + except (TypeError, ValueError) as exc: + raise ValueError( + f"operand-specs[{index}] {field_name} entries must be integers or null" + ) from exc + return tuple(parsed) + + +def _parse_operand_specs(spec_text: str) -> list[dict]: + try: + raw_specs = json.loads(spec_text) + except json.JSONDecodeError as exc: + raise ValueError(f"invalid operand-specs JSON: {exc}") from exc + + if not isinstance(raw_specs, list) or not raw_specs: + raise ValueError("operand-specs must be a non-empty JSON array") + + specs: list[dict] = [] + for index, raw in enumerate(raw_specs): + if not isinstance(raw, dict): + raise ValueError(f"operand-specs[{index}] must be an object") + kind = raw.get("kind") + dtype_name = raw.get("dtype") + dtype = _DTYPE_MAP.get(dtype_name) + if dtype is None: + raise ValueError(f"operand-specs[{index}] has unsupported dtype {dtype_name!r}") + if kind == "scalar": + specs.append({"kind": "scalar", "dtype": dtype}) + continue + if kind == "tile": + shape = raw.get("shape") + if not isinstance(shape, list) or not shape: + raise ValueError(f"operand-specs[{index}] tile shape must be a non-empty list") + valid_shape = raw.get("valid_shape") + if valid_shape is not None and (not isinstance(valid_shape, list) or not valid_shape): + raise ValueError(f"operand-specs[{index}] tile valid_shape must be a non-empty list") + memory_space = _MEMSPACE_MAP.get(raw.get("memory_space")) + if memory_space is None: + raise ValueError( + f"operand-specs[{index}] has unknown memory-space {raw.get('memory_space')!r}" + ) + config_raw = raw.get("config") + config = None + if config_raw is not None: + if not isinstance(config_raw, dict): + raise ValueError(f"operand-specs[{index}] tile config must be an object") + try: + config = TileConfig.from_mapping(config_raw) + except (TypeError, ValueError) as exc: + raise ValueError( + f"operand-specs[{index}] has invalid tile config: {exc}" + ) from exc + specs.append( + { + "kind": "tile", + "dtype": dtype, + "shape": tuple(int(dim) for dim in shape), + "valid_shape": None + if valid_shape is None + else _parse_optional_int_sequence( + valid_shape, + field_name="tile valid_shape", + index=index, + ), + "config": config, + "memory_space": memory_space, + } + ) + continue + if kind == "view": + shape = raw.get("shape") + if not isinstance(shape, list) or not shape: + raise ValueError(f"operand-specs[{index}] view shape must be a non-empty list") + memory_space = _MEMSPACE_MAP.get(raw.get("memory_space", "gm")) + if memory_space is None: + raise ValueError( + f"operand-specs[{index}] has unknown memory-space {raw.get('memory_space')!r}" + ) + view_spec: dict = { + "kind": "view", + "dtype": dtype, + "shape": _parse_optional_int_sequence( + shape, + field_name="view shape", + index=index, + ), + "memory_space": memory_space, + } + raw_strides = raw.get("strides") + if isinstance(raw_strides, list) and raw_strides: + # null entries represent dynamic strides — keep as None. + view_spec["strides"] = tuple( + None if s is None else int(s) for s in raw_strides + ) + specs.append(view_spec) + continue + if kind == "vector": + shape = raw.get("shape") + if not isinstance(shape, list) or not shape: + raise ValueError(f"operand-specs[{index}] vector shape must be a non-empty list") + specs.append( + { + "kind": "vector", + "dtype": dtype, + "shape": _parse_optional_int_sequence( + shape, + field_name="vector shape", + index=index, + ), + } + ) + continue + raise ValueError(f"operand-specs[{index}] has unknown kind {kind!r}") + return specs + + +def _operand_spec_matches_param_kind(param_kind: str, operand_kind: str) -> bool: + if operand_kind == "tile": + return param_kind == "tile" + if operand_kind == "view": + return param_kind in ("tensorview", "partition_tensor_view") + if operand_kind == "vector": + # Prefer an explicit builtin vector annotation, but keep scalar fallback + # for older templates that still model the auxiliary vector slot as a + # scalar-ish placeholder. + return param_kind in ("vector", "scalar") + if operand_kind == "scalar": + return param_kind == "scalar" + return False + + +def _filter_descriptors_by_operand_schema( + descriptors: list[VKernelDescriptor], + *, + target: str, + op_name: str, + operand_specs: list[dict], +) -> list[VKernelDescriptor]: + operand_types = tuple(spec["dtype"] for spec in operand_specs) + filtered: list[VKernelDescriptor] = [] + for descriptor in descriptors: + bound = _bind_descriptor_for_query(descriptor, target, op_name, operand_types) + if bound is None: + continue + parameters = bound.parameters + if len(parameters) != len(operand_specs): + continue + if all( + _operand_spec_matches_param_kind(param.kind, operand_spec["kind"]) + for param, operand_spec in zip(parameters, operand_specs) + ): + filtered.append(bound) + return filtered + + +def _build_positional_context_attrs(operand_specs: list[dict]) -> dict[str, object]: + attrs: dict[str, object] = {} + for index, operand_spec in enumerate(operand_specs): + prefix = f"arg{index}" + attrs[f"{prefix}_kind"] = operand_spec["kind"] + attrs[f"{prefix}_dtype"] = operand_spec["dtype"] + if operand_spec["kind"] == "scalar": + continue + shape = tuple(operand_spec["shape"]) + attrs[f"{prefix}_shape"] = shape + attrs[f"{prefix}_rank"] = len(shape) + if operand_spec["kind"] == "vector": + continue + memory_space = operand_spec.get("memory_space") + if isinstance(memory_space, MemorySpace): + attrs[f"{prefix}_memory_space"] = memory_space.value + elif memory_space is not None: + attrs[f"{prefix}_memory_space"] = memory_space + if operand_spec["kind"] == "tile": + valid_shape = operand_spec.get("valid_shape") + effective_valid_shape = shape if valid_shape is None else tuple(valid_shape) + attrs[f"{prefix}_valid_shape"] = effective_valid_shape + if operand_spec.get("config") is not None: + attrs[f"{prefix}_config"] = operand_spec["config"] + continue + if "strides" in operand_spec: + attrs[f"{prefix}_strides"] = tuple(operand_spec["strides"]) + return attrs + + +def _select_descriptor( + descriptors: list[VKernelDescriptor], + *, + target: str, + op_name: str, + operand_specs: list[dict], + extra_context_attrs: dict[str, object] | None = None, +) -> VKernelDescriptor: + filtered_descriptors = _filter_descriptors_by_operand_schema( + descriptors, + target=target, + op_name=op_name, + operand_specs=operand_specs, + ) + operand_types = tuple(spec["dtype"] for spec in operand_specs) + if not filtered_descriptors: + raise LookupError( + "expand_helper found no registered kernel after operand schema filtering for " + f"target={target!r}, op={op_name!r}, operand_types={operand_types!r}" + ) + registry = KernelRegistry(tuple(filtered_descriptors)) + context_attrs = _build_positional_context_attrs(operand_specs) + if extra_context_attrs: + context_attrs.update(extra_context_attrs) + return select_kernel( + target, + op_name, + operand_types, + context_attrs=context_attrs, + registry=registry, + return_metadata=False, + ) + + +def _parse_context_attrs(spec_text: str) -> dict[str, object]: + try: + raw = json.loads(spec_text) + except json.JSONDecodeError as exc: + raise ValueError(f"invalid context-attrs JSON: {exc}") from exc + + if not isinstance(raw, dict): + raise ValueError("context-attrs must be a JSON object") + return dict(raw) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="TileLang DSL expand helper") + parser.add_argument("--template-dir", required=True, help="Directory of .py templates") + parser.add_argument("--target", default="a5", help="Target architecture, e.g. a5") + parser.add_argument("--op", required=True, help="Tile op name, e.g. pto.tadd") + parser.add_argument("--dtype", help="Element dtype, e.g. f32") + parser.add_argument("--shape", help="Tile shape, e.g. 16,64") + parser.add_argument( + "--memory-space", + default="ub", + help="Memory space (ub/gm/mat/left/right/acc/bias)", + ) + parser.add_argument( + "--operand-specs", + help="JSON array describing each operand (tile/scalar/vector schema)", + ) + parser.add_argument( + "--context-attrs", + help="JSON object describing static op/context attrs visible to the template", + ) + args = parser.parse_args(argv) + + template_dir = Path(args.template_dir) + if not template_dir.is_dir(): + print(f"expand_helper: error: {template_dir} is not a directory", file=sys.stderr) + return 1 + + operand_specs: list[dict] | None = None + extra_context_attrs: dict[str, object] = {} + if args.operand_specs: + try: + operand_specs = _parse_operand_specs(args.operand_specs) + except ValueError as exc: + print(f"expand_helper: error: {exc}", file=sys.stderr) + return 1 + else: + if args.dtype is None or args.shape is None: + print( + "expand_helper: error: either --operand-specs or both --dtype/--shape are required", + file=sys.stderr, + ) + return 1 + shape = tuple(int(d) for d in args.shape.split(",")) + mem_space = _MEMSPACE_MAP.get(args.memory_space) + if mem_space is None: + print(f"expand_helper: error: unknown memory-space '{args.memory_space}'", file=sys.stderr) + return 1 + target_dtype = _DTYPE_MAP.get(args.dtype) + if target_dtype is None: + print(f"expand_helper: error: unknown dtype '{args.dtype}'", file=sys.stderr) + return 1 + operand_specs = [ + {"kind": "tile", "dtype": target_dtype, "shape": shape, "memory_space": mem_space} + ] + + if args.context_attrs: + try: + extra_context_attrs = _parse_context_attrs(args.context_attrs) + except ValueError as exc: + print(f"expand_helper: error: {exc}", file=sys.stderr) + return 1 + + # Scan all .py files for descriptors. + all_descriptors: list[VKernelDescriptor] = [] + with _template_import_context(template_dir): + for py_path in sorted(template_dir.glob("*.py")): + mod = _import_py_file(py_path) + if mod is None: + continue + all_descriptors.extend(_find_descriptors(mod)) + + if not all_descriptors: + print(f"expand_helper: error: no @vkernel descriptors found in {template_dir}", file=sys.stderr) + return 1 + + try: + desc = _select_descriptor( + all_descriptors, + target=args.target, + op_name=args.op, + operand_specs=operand_specs, + extra_context_attrs=extra_context_attrs, + ) + except Exception as exc: + print(f"expand_helper: error: {exc}", file=sys.stderr) + return 1 + + # Specialize Tile parameters positionally from operand-specs. + tile_specs = {} + for param, operand_spec in zip(desc.parameters, operand_specs): + if param.kind == "tile": + if operand_spec["kind"] != "tile": + print( + "expand_helper: error: descriptor tile parameter does not match operand-specs", + file=sys.stderr, + ) + return 1 + tile_specs[param.name] = TileSpecialization( + shape=operand_spec["shape"], + memory_space=operand_spec["memory_space"], + config=operand_spec.get("config"), + valid_shape=operand_spec.get("valid_shape"), + ) + continue + if param.kind in ("tensorview", "partition_tensor_view"): + if operand_spec["kind"] != "view": + print( + f"expand_helper: error: descriptor {param.kind} parameter " + f"does not match operand-specs kind {operand_spec['kind']!r}", + file=sys.stderr, + ) + return 1 + continue + if param.kind == "vector": + if operand_spec["kind"] != "vector": + print( + "expand_helper: error: descriptor builtin vector parameter does not match operand-specs", + file=sys.stderr, + ) + return 1 + continue + if param.kind == "scalar" and operand_spec["kind"] not in ("scalar", "vector"): + print( + "expand_helper: error: descriptor scalar parameter does not match operand-specs", + file=sys.stderr, + ) + return 1 + + specialized = desc.specialize(**tile_specs) + + # Emit MLIR to stdout. + try: + mlir_text = specialized.mlir_text() + except Exception as exc: + print(f"expand_helper: error: materialization failed: {exc}", file=sys.stderr) + return 1 + + sys.stdout.write(mlir_text) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tilelang-dsl/python/tilelang_dsl/frontend_ast.py b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py new file mode 100644 index 000000000..3211967fd --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/frontend_ast.py @@ -0,0 +1,1740 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Frontend AST nodes for TileLang DSL descriptor materialization.""" + +from __future__ import annotations + +import ast +import inspect +from dataclasses import dataclass +from typing import Any + +from .support_matrix import ( + ADVANCED_EXPR_PTO_CALLS, + ADVANCED_TOPLEVEL_PTO_CALLS, + ADVANCED_VECSCOPE_PTO_CALLS, + CUBE_ONLY_PTO_CALLS, + DEFERRED_PTO_SURFACES, + SUPPORTED_TOPLEVEL_PTO_CALLS, + SUPPORTED_VECSCOPE_PTO_CALLS, + advanced_mode_message, + deferred_surface_message, +) + + +@dataclass(frozen=True) +class FrontendParameterNode: + name: str + kind: str + annotation: Any + dtype: Any + + +@dataclass(frozen=True) +class FrontendTileSpecializationNode: + name: str + shape: tuple[int, ...] + memory_space: str + config: Any + valid_shape: tuple[int | None, ...] | None + + +@dataclass(frozen=True) +class FrontendSourceLocation: + path: str + line: int + column: int + + +class FrontendExprNode: + """Base class for lowered frontend expressions.""" + + +@dataclass(frozen=True) +class FrontendNameExpr(FrontendExprNode): + name: str + + +@dataclass(frozen=True) +class FrontendConstantExpr(FrontendExprNode): + value: Any + + +@dataclass(frozen=True) +class FrontendSymbolExpr(FrontendExprNode): + namespace: str + name: str + + +@dataclass(frozen=True) +class FrontendSliceExpr(FrontendExprNode): + start: FrontendExprNode | None + stop: FrontendExprNode | None + step: FrontendExprNode | None + + +@dataclass(frozen=True) +class FrontendTupleExpr(FrontendExprNode): + elements: tuple[FrontendExprNode, ...] + + +@dataclass(frozen=True) +class FrontendAttributeExpr(FrontendExprNode): + base: FrontendExprNode + attr: str + + +@dataclass(frozen=True) +class FrontendSubscriptExpr(FrontendExprNode): + base: FrontendExprNode + index: FrontendExprNode + + +@dataclass(frozen=True) +class FrontendBinaryExpr(FrontendExprNode): + lhs: FrontendExprNode + op: str + rhs: FrontendExprNode + + +@dataclass(frozen=True) +class FrontendCallExpr(FrontendExprNode): + namespace: str | None + name: str + args: tuple[FrontendExprNode, ...] + keywords: tuple[tuple[str, FrontendExprNode], ...] = () + + +class FrontendTargetNode: + """Base class for assignment targets.""" + + +@dataclass(frozen=True) +class FrontendNameTarget(FrontendTargetNode): + name: str + + +@dataclass(frozen=True) +class FrontendTupleTarget(FrontendTargetNode): + elements: tuple[FrontendNameTarget, ...] + + +class FrontendStmtNode: + """Base class for lowered frontend statements.""" + + +@dataclass(frozen=True) +class FrontendNoOpStmt(FrontendStmtNode): + pass + + +@dataclass(frozen=True) +class FrontendAssignStmt(FrontendStmtNode): + target: FrontendTargetNode + value: FrontendExprNode + annotation: Any | None = None + + +@dataclass(frozen=True) +class FrontendExprStmt(FrontendStmtNode): + expr: FrontendExprNode + + +@dataclass(frozen=True) +class FrontendReturnStmt(FrontendStmtNode): + value: FrontendExprNode | None + + +@dataclass(frozen=True) +class FrontendForStmt(FrontendStmtNode): + target: str + lower_bound: FrontendExprNode + upper_bound: FrontendExprNode + step: FrontendExprNode + body: tuple[FrontendStmtNode, ...] + + +@dataclass(frozen=True) +class FrontendIfStmt(FrontendStmtNode): + condition: FrontendExprNode + then_body: tuple[FrontendStmtNode, ...] + else_body: tuple[FrontendStmtNode, ...] + is_constexpr: bool = False + + +@dataclass(frozen=True) +class FrontendVecscopeStmt(FrontendStmtNode): + body: tuple[FrontendStmtNode, ...] + + +@dataclass(frozen=True) +class FrontendStrictVecscopeStmt(FrontendStmtNode): + captures: tuple[FrontendExprNode, ...] + block_arguments: tuple[str, ...] + body: tuple[FrontendStmtNode, ...] + + +@dataclass(frozen=True) +class FrontendInlineProcParameterNode: + name: str + annotation: Any + default: FrontendExprNode | None + + +@dataclass(frozen=True) +class FrontendInlineProcNode: + name: str + parameters: tuple[FrontendInlineProcParameterNode, ...] + body: tuple[FrontendStmtNode, ...] + + +@dataclass(frozen=True) +class FrontendKernelNode: + target: str + op: str + name: str + kernel_family: str + verify_enabled: bool + advanced_enabled: bool + dtype_signature: tuple[Any, ...] | None + parameters: tuple[FrontendParameterNode, ...] + tile_specializations: tuple[FrontendTileSpecializationNode, ...] + body: tuple[FrontendStmtNode, ...] + context_attrs: tuple[tuple[str, Any], ...] = () + inline_procs: tuple[FrontendInlineProcNode, ...] = () + internal_inline_procs: tuple[FrontendInlineProcNode, ...] = () + + +@dataclass(frozen=True) +class _FrontendInlineProc: + name: str + source_info: Any + signature: inspect.Signature + + +@dataclass(frozen=True) +class _FrontendBuildContext: + source_info: Any + module_globals: dict[str, Any] | None + templates: dict[str, dict[str, str]] + selected_op: str | None + advanced_enabled: bool + kernel_family: str + inline_procs: dict[str, _FrontendInlineProc] + global_literal_constants: dict[str, Any] + local_bindings: frozenset[str] + active_inline_proc_stack: tuple[str, ...] = () + vecscope_depth: int = 0 + + def error(self, node: ast.AST, message: str) -> Exception: + if self.source_info is not None: + return self.source_info.error(node, message) + return ValueError(message) + + def nested_vecscope(self) -> "_FrontendBuildContext": + return _FrontendBuildContext( + source_info=self.source_info, + module_globals=self.module_globals, + templates=self.templates, + selected_op=self.selected_op, + advanced_enabled=self.advanced_enabled, + kernel_family=self.kernel_family, + inline_procs=self.inline_procs, + global_literal_constants=self.global_literal_constants, + local_bindings=self.local_bindings, + active_inline_proc_stack=self.active_inline_proc_stack, + vecscope_depth=self.vecscope_depth + 1, + ) + + def enter_inline_proc(self, name: str, source_info: Any) -> "_FrontendBuildContext": + local_bindings = _collect_source_local_bindings(source_info) + global_literal_constants = _collect_module_literal_constants( + source_info, + module_globals=self.module_globals, + local_bindings=local_bindings, + ) + return _FrontendBuildContext( + source_info=source_info, + module_globals=self.module_globals, + templates=self.templates, + selected_op=self.selected_op, + advanced_enabled=self.advanced_enabled, + kernel_family=self.kernel_family, + inline_procs=self.inline_procs, + global_literal_constants=global_literal_constants, + local_bindings=local_bindings, + active_inline_proc_stack=(*self.active_inline_proc_stack, name), + vecscope_depth=self.vecscope_depth, + ) + + +_UNSUPPORTED_GLOBAL_LITERAL = object() +_LOCAL_BINDINGS_CACHE: dict[tuple[str, int, str], frozenset[str]] = {} +_GLOBAL_NAME_READS_CACHE: dict[tuple[str, int, str], frozenset[str]] = {} + + +def _reject_cube_vector_surface(context: _FrontendBuildContext, node: ast.AST, surface_name: str) -> None: + raise context.error( + node, + f"vector-only surface `{surface_name}` is not part of the @pto.ckernel contract", + ) + + +def _reject_vector_cube_surface(context: _FrontendBuildContext, node: ast.AST, surface_name: str) -> None: + raise context.error( + node, + f"cube-only surface `{surface_name}` is not part of the @pto.vkernel contract", + ) + + +def _iter_target_names(node: ast.AST) -> tuple[str, ...]: + if isinstance(node, ast.Name): + return (node.id,) + if isinstance(node, (ast.Tuple, ast.List)): + names: list[str] = [] + for elt in node.elts: + names.extend(_iter_target_names(elt)) + return tuple(names) + return () + + +def _collect_source_global_name_reads( + source_info: Any, + local_bindings: frozenset[str], +) -> frozenset[str]: + if source_info is None: + return frozenset() + function_def = source_info.function_def + cache_key = ( + source_info.path, + source_info.start_line, + function_def.name, + ) + cached = _GLOBAL_NAME_READS_CACHE.get(cache_key) + if cached is not None: + return cached + + global_reads: set[str] = set() + for node in ast.walk(function_def): + if not isinstance(node, ast.Name) or not isinstance(node.ctx, ast.Load): + continue + if node.id in local_bindings: + continue + if node.id.startswith("__"): + continue + global_reads.add(node.id) + + frozen = frozenset(global_reads) + _GLOBAL_NAME_READS_CACHE[cache_key] = frozen + return frozen + + +def _collect_function_local_bindings(function_def: ast.FunctionDef) -> set[str]: + bindings: set[str] = set() + for arg in function_def.args.posonlyargs: + bindings.add(arg.arg) + for arg in function_def.args.args: + bindings.add(arg.arg) + for arg in function_def.args.kwonlyargs: + bindings.add(arg.arg) + if function_def.args.vararg is not None: + bindings.add(function_def.args.vararg.arg) + if function_def.args.kwarg is not None: + bindings.add(function_def.args.kwarg.arg) + + for node in ast.walk(function_def): + if isinstance(node, ast.Assign): + for target in node.targets: + bindings.update(_iter_target_names(target)) + continue + if isinstance(node, ast.AnnAssign): + bindings.update(_iter_target_names(node.target)) + continue + if isinstance(node, ast.For): + bindings.update(_iter_target_names(node.target)) + continue + if isinstance(node, ast.With): + for item in node.items: + if item.optional_vars is not None: + bindings.update(_iter_target_names(item.optional_vars)) + continue + return bindings + + +def _collect_source_local_bindings(source_info: Any) -> frozenset[str]: + if source_info is None: + return frozenset() + function_def = source_info.function_def + cache_key = ( + source_info.path, + source_info.start_line, + function_def.name, + ) + cached = _LOCAL_BINDINGS_CACHE.get(cache_key) + if cached is not None: + return cached + collected = frozenset(_collect_function_local_bindings(function_def)) + _LOCAL_BINDINGS_CACHE[cache_key] = collected + return collected + + +def _collect_module_literal_constants( + source_info: Any, + *, + module_globals: dict[str, Any] | None, + local_bindings: frozenset[str], +) -> dict[str, Any]: + if source_info is None or module_globals is None: + return {} + literal_constants: dict[str, Any] = {} + for name in _collect_source_global_name_reads(source_info, local_bindings): + value = module_globals.get(name, _UNSUPPORTED_GLOBAL_LITERAL) + if isinstance(value, (bool, int, float, str)): + literal_constants[name] = value + return literal_constants + + +def _attach_source_location( + frontend_node: FrontendExprNode | FrontendStmtNode, + ast_node: ast.AST, + context: _FrontendBuildContext, +) -> FrontendExprNode | FrontendStmtNode: + if context.source_info is None: + return frontend_node + line, column = context.source_info.location(ast_node) + object.__setattr__( + frontend_node, + "source_location", + FrontendSourceLocation( + path=context.source_info.path, + line=line, + column=column, + ), + ) + return frontend_node + + +def _inline_proc_param_specs(inline_proc: _FrontendInlineProc) -> tuple[tuple[str, ast.expr | None], ...]: + function_def = inline_proc.source_info.function_def + params = function_def.args.args + defaults = function_def.args.defaults + first_default = len(params) - len(defaults) + specs: list[tuple[str, ast.expr | None]] = [] + for index, param in enumerate(params): + default_node: ast.expr | None = None + if index >= first_default: + default_node = defaults[index - first_default] + specs.append((param.arg, default_node)) + return tuple(specs) + + +def _bind_inline_proc_call( + node: ast.Call, + inline_proc: _FrontendInlineProc, + context: _FrontendBuildContext, +) -> tuple[FrontendExprNode, ...]: + if any(keyword.arg is None for keyword in node.keywords): + raise context.error( + node, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + + param_specs = _inline_proc_param_specs(inline_proc) + param_names = tuple(param_name for param_name, _ in param_specs) + bound: dict[str, FrontendExprNode] = {} + + if len(node.args) > len(param_specs): + raise context.error( + node, + f"inline_proc `{inline_proc.name}` accepts at most {len(param_specs)} positional arguments in TileLang DSL v1", + ) + + for index, arg_node in enumerate(node.args): + param_name = param_names[index] + bound[param_name] = _build_expr(arg_node, context) + + seen_keywords: set[str] = set() + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen_keywords: + raise context.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for inline_proc `{inline_proc.name}` in TileLang DSL v1", + ) + if keyword.arg not in param_names: + raise context.error( + keyword.value, + f"inline_proc `{inline_proc.name}` does not define keyword `{keyword.arg}` in TileLang DSL v1", + ) + if keyword.arg in bound: + raise context.error( + keyword.value, + f"inline_proc `{inline_proc.name}` got multiple values for argument `{keyword.arg}` in TileLang DSL v1", + ) + seen_keywords.add(keyword.arg) + bound[keyword.arg] = _build_expr(keyword.value, context) + + ordered_args: list[FrontendExprNode] = [] + for param_name, default_node in param_specs: + value = bound.get(param_name) + if value is None: + if default_node is None: + raise context.error( + node, + f"inline_proc `{inline_proc.name}` is missing required argument `{param_name}` in TileLang DSL v1", + ) + value = _build_expr(default_node, context) + ordered_args.append(value) + return tuple(ordered_args) + + +def _collect_name_reads(expr: FrontendExprNode) -> set[str]: + if isinstance(expr, FrontendNameExpr): + return {expr.name} + if isinstance(expr, (FrontendConstantExpr, FrontendSymbolExpr)): + return set() + if isinstance(expr, FrontendSliceExpr): + names: set[str] = set() + if expr.start is not None: + names |= _collect_name_reads(expr.start) + if expr.stop is not None: + names |= _collect_name_reads(expr.stop) + if expr.step is not None: + names |= _collect_name_reads(expr.step) + return names + if isinstance(expr, FrontendTupleExpr): + names: set[str] = set() + for element in expr.elements: + names |= _collect_name_reads(element) + return names + if isinstance(expr, FrontendAttributeExpr): + return _collect_name_reads(expr.base) + if isinstance(expr, FrontendSubscriptExpr): + return _collect_name_reads(expr.base) | _collect_name_reads(expr.index) + if isinstance(expr, FrontendBinaryExpr): + return _collect_name_reads(expr.lhs) | _collect_name_reads(expr.rhs) + if isinstance(expr, FrontendCallExpr): + names: set[str] = set() + for arg in expr.args: + names |= _collect_name_reads(arg) + for _, keyword_value in expr.keywords: + names |= _collect_name_reads(keyword_value) + return names + return set() + + +def _extract_target_names(target: FrontendTargetNode) -> set[str]: + if isinstance(target, FrontendNameTarget): + return {target.name} + if isinstance(target, FrontendTupleTarget): + return {element.name for element in target.elements} + return set() + +def _validate_inline_capture( + stmt: FrontendStmtNode, + param_names: set[str], + assigned_names: set[str], + *, + context: _FrontendBuildContext, +) -> None: + allowed = param_names | assigned_names | set(context.global_literal_constants) + if isinstance(stmt, FrontendNoOpStmt): + return + if isinstance(stmt, FrontendAssignStmt): + missing = _collect_name_reads(stmt.value) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + assigned_names |= _extract_target_names(stmt.target) + return + if isinstance(stmt, FrontendExprStmt): + missing = _collect_name_reads(stmt.expr) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + return + if isinstance(stmt, FrontendReturnStmt): + if stmt.value is None: + return + missing = _collect_name_reads(stmt.value) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + return + if isinstance(stmt, FrontendForStmt): + header_reads = ( + _collect_name_reads(stmt.lower_bound) + | _collect_name_reads(stmt.upper_bound) + | _collect_name_reads(stmt.step) + ) + missing = header_reads - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + + loop_assigned = set(assigned_names) + loop_assigned.add(stmt.target) + for child in stmt.body: + _validate_inline_capture(child, param_names, loop_assigned, context=context) + assigned_names.add(stmt.target) + return + if isinstance(stmt, FrontendIfStmt): + missing = _collect_name_reads(stmt.condition) - allowed + if missing: + name = sorted(missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + then_assigned = set(assigned_names) + else_assigned = set(assigned_names) + for child in stmt.then_body: + _validate_inline_capture(child, param_names, then_assigned, context=context) + for child in stmt.else_body: + _validate_inline_capture(child, param_names, else_assigned, context=context) + assigned_names |= then_assigned | else_assigned + return + if isinstance(stmt, FrontendVecscopeStmt): + scope_assigned = set(assigned_names) + for child in stmt.body: + _validate_inline_capture(child, param_names, scope_assigned, context=context) + assigned_names |= scope_assigned + return + if isinstance(stmt, FrontendStrictVecscopeStmt): + captures_missing = set().union(*(_collect_name_reads(capture) for capture in stmt.captures)) - allowed + if captures_missing: + name = sorted(captures_missing)[0] + raise context.error( + context.source_info.function_def, + f"implicit capture of '{name}' is not allowed in inline_proc", + ) + scope_assigned = set(assigned_names) | set(stmt.block_arguments) + for child in stmt.body: + _validate_inline_capture(child, param_names, scope_assigned, context=context) + assigned_names |= scope_assigned + + +def _collect_inline_proc_calls_expr( + expr: FrontendExprNode, + inline_proc_names: set[str], + into: set[str], +) -> None: + if isinstance(expr, FrontendCallExpr): + if expr.namespace is None and expr.name in inline_proc_names: + into.add(expr.name) + for arg in expr.args: + _collect_inline_proc_calls_expr(arg, inline_proc_names, into) + for _, keyword_value in expr.keywords: + _collect_inline_proc_calls_expr(keyword_value, inline_proc_names, into) + return + if isinstance(expr, FrontendBinaryExpr): + _collect_inline_proc_calls_expr(expr.lhs, inline_proc_names, into) + _collect_inline_proc_calls_expr(expr.rhs, inline_proc_names, into) + return + if isinstance(expr, FrontendTupleExpr): + for element in expr.elements: + _collect_inline_proc_calls_expr(element, inline_proc_names, into) + return + if isinstance(expr, FrontendSliceExpr): + if expr.start is not None: + _collect_inline_proc_calls_expr(expr.start, inline_proc_names, into) + if expr.stop is not None: + _collect_inline_proc_calls_expr(expr.stop, inline_proc_names, into) + if expr.step is not None: + _collect_inline_proc_calls_expr(expr.step, inline_proc_names, into) + return + if isinstance(expr, FrontendAttributeExpr): + _collect_inline_proc_calls_expr(expr.base, inline_proc_names, into) + return + if isinstance(expr, FrontendSubscriptExpr): + _collect_inline_proc_calls_expr(expr.base, inline_proc_names, into) + _collect_inline_proc_calls_expr(expr.index, inline_proc_names, into) + + +def _collect_inline_proc_calls_stmt( + stmt: FrontendStmtNode, + inline_proc_names: set[str], + into: set[str], +) -> None: + if isinstance(stmt, FrontendNoOpStmt): + return + if isinstance(stmt, FrontendAssignStmt): + _collect_inline_proc_calls_expr(stmt.value, inline_proc_names, into) + return + if isinstance(stmt, FrontendExprStmt): + _collect_inline_proc_calls_expr(stmt.expr, inline_proc_names, into) + return + if isinstance(stmt, FrontendReturnStmt): + if stmt.value is not None: + _collect_inline_proc_calls_expr(stmt.value, inline_proc_names, into) + return + if isinstance(stmt, FrontendForStmt): + _collect_inline_proc_calls_expr(stmt.lower_bound, inline_proc_names, into) + _collect_inline_proc_calls_expr(stmt.upper_bound, inline_proc_names, into) + _collect_inline_proc_calls_expr(stmt.step, inline_proc_names, into) + for child in stmt.body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + return + if isinstance(stmt, FrontendIfStmt): + _collect_inline_proc_calls_expr(stmt.condition, inline_proc_names, into) + for child in stmt.then_body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + for child in stmt.else_body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + return + if isinstance(stmt, FrontendVecscopeStmt): + for child in stmt.body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + return + if isinstance(stmt, FrontendStrictVecscopeStmt): + for capture in stmt.captures: + _collect_inline_proc_calls_expr(capture, inline_proc_names, into) + for child in stmt.body: + _collect_inline_proc_calls_stmt(child, inline_proc_names, into) + + +def _validate_inline_proc_call_graph( + kernel_body: tuple[FrontendStmtNode, ...], + inline_proc_nodes: tuple[FrontendInlineProcNode, ...], + inline_proc_source_infos: dict[str, Any], +) -> None: + inline_proc_names = {node.name for node in inline_proc_nodes} + if not inline_proc_names: + return + + edges: dict[str, set[str]] = {node.name: set() for node in inline_proc_nodes} + for inline_proc_node in inline_proc_nodes: + callees = edges[inline_proc_node.name] + for stmt in inline_proc_node.body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, callees) + + root_callees: set[str] = set() + for stmt in kernel_body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, root_callees) + + color: dict[str, int] = {} + + def dfs(name: str) -> None: + state = color.get(name, 0) + if state == 1: + source_info = inline_proc_source_infos.get(name) + if source_info is not None: + raise source_info.error( + source_info.function_def, + f"recursive inline_proc call `{name}` is not supported in TileLang DSL v1", + ) + raise ValueError(f"recursive inline_proc call `{name}` is not supported in TileLang DSL v1") + if state == 2: + return + color[name] = 1 + for callee in edges.get(name, ()): + dfs(callee) + color[name] = 2 + + for callee in sorted(root_callees): + dfs(callee) + + +def _collect_reachable_inline_procs( + kernel_body: tuple[FrontendStmtNode, ...], + inline_proc_nodes: tuple[FrontendInlineProcNode, ...], +) -> set[str]: + inline_proc_names = {node.name for node in inline_proc_nodes} + if not inline_proc_names: + return set() + + edges: dict[str, set[str]] = {node.name: set() for node in inline_proc_nodes} + for inline_proc_node in inline_proc_nodes: + for stmt in inline_proc_node.body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, edges[inline_proc_node.name]) + + roots: set[str] = set() + for stmt in kernel_body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, roots) + + reachable: set[str] = set() + stack = list(roots) + while stack: + name = stack.pop() + if name in reachable: + continue + reachable.add(name) + stack.extend(edges.get(name, ())) + return reachable + + +_BINARY_OP_NAMES = { + ast.Add: "add", + ast.Sub: "sub", + ast.Mult: "mul", + ast.Mod: "mod", + ast.FloorDiv: "floordiv", + ast.BitAnd: "bitand", + ast.BitOr: "bitor", + ast.BitXor: "bitxor", + ast.LShift: "lshift", + ast.RShift: "rshift", +} +_COMPARE_OP_NAMES = { + ast.Eq: "eq", + ast.NotEq: "ne", + ast.Gt: "gt", + ast.Lt: "lt", + ast.GtE: "ge", + ast.LtE: "le", +} +_BOOL_OP_NAMES = { + ast.And: "and", + ast.Or: "or", +} + +_DMA_CALL_KEYWORDS: dict[str, frozenset[str]] = { + "set_mov_pad_val": frozenset({"pad_value"}), + "set_loop2_stride_outtoub": frozenset({"src_stride", "dst_stride"}), + "set_loop1_stride_outtoub": frozenset({"src_stride", "dst_stride"}), + "set_loop_size_outtoub": frozenset({"loop1", "loop2"}), + "set_loop2_stride_ubtoout": frozenset({"src_stride", "dst_stride"}), + "set_loop1_stride_ubtoout": frozenset({"src_stride", "dst_stride"}), + "set_loop_size_ubtoout": frozenset({"loop1", "loop2"}), + "copy_gm_to_ubuf": frozenset( + { + "src", + "dst", + "sid", + "n_burst", + "len_burst", + "left_padding_count", + "right_padding_count", + "data_select_bit", + "enable_ub_pad", + "l2_cache_ctl", + "gm_stride", + "ub_stride", + } + ), + "copy_ubuf_to_gm": frozenset( + { + "src", + "dst", + "sid", + "n_burst", + "len_burst", + "reserved", + "burst_dst_stride", + "burst_src_stride", + "gm_stride", + "ub_stride", + } + ), + "vcvt": frozenset({"rnd", "sat", "part"}), + "vtrc": frozenset({"rnd"}), + "vlds": frozenset({"dist"}), + "vsts": frozenset({"dist"}), + "vbitcast": frozenset(), + "pbitcast": frozenset(), + "cube_load": frozenset({"nburst", "loops"}), + "cube_store": frozenset({"nburst", "loops"}), + "cube_load_frac": frozenset({"shape", "src_layout", "dst_group", "ctrl"}), + "bias_load": frozenset({"nburst"}), + "left_load": frozenset(), + "right_load": frozenset(), + "left_load_mx": frozenset(), + "right_load_mx": frozenset(), + "mad": frozenset({"unit_flag_ctrl", "disable_gemv"}), + "mad_acc": frozenset({"unit_flag_ctrl", "disable_gemv"}), + "mad_bias": frozenset({"unit_flag_ctrl", "disable_gemv"}), + "mad_mx": frozenset({"unit_flag_ctrl", "disable_gemv"}), + "mad_mx_acc": frozenset({"unit_flag_ctrl", "disable_gemv"}), + "mad_mx_bias": frozenset({"unit_flag_ctrl", "disable_gemv"}), + "acc_store": frozenset( + {"mode", "loop0_src_stride", "split", "loop3"} + ), + "acc_store_gm": frozenset( + { + "mode", + "loop0_src_stride", + "split", + "loop3", + "sid", + "l2_cache_ctrl", + } + ), + "acc_store_ub": frozenset( + { + "mode", + "loop0_src_stride", + "channel_split_en", + "loop3", + "dual_dst_mode", + "sub_blockid", + } + ), +} + + +def _attribute_path(node: ast.AST) -> tuple[str, ...] | None: + if isinstance(node, ast.Name): + return (node.id,) + if isinstance(node, ast.Attribute): + base_path = _attribute_path(node.value) + if base_path is None: + return None + return base_path + (node.attr,) + return None + + +def _validate_resolved_template_op_surface( + op_name: str, + node: ast.AST, + context: _FrontendBuildContext, +) -> None: + if op_name in CUBE_ONLY_PTO_CALLS: + if context.kernel_family == "cube": + return + _reject_vector_cube_surface(context, node, f"pto.{op_name}") + return + if op_name in SUPPORTED_TOPLEVEL_PTO_CALLS: + return + if op_name in SUPPORTED_VECSCOPE_PTO_CALLS: + return + if op_name in ADVANCED_VECSCOPE_PTO_CALLS: + if context.advanced_enabled: + return + raise context.error( + node, + advanced_mode_message(op_name), + ) + if op_name in ADVANCED_EXPR_PTO_CALLS or op_name in ADVANCED_TOPLEVEL_PTO_CALLS: + if context.advanced_enabled: + return + raise context.error( + node, + advanced_mode_message(op_name), + ) + if op_name in DEFERRED_PTO_SURFACES: + raise context.error( + node, + deferred_surface_message(op_name), + ) + raise context.error( + node, + f"unsupported op surface `pto.{op_name}` in TileLang DSL v1", + ) + + +def _build_call_keywords( + node: ast.Call, + *, + namespace: str | None, + name: str, + context: _FrontendBuildContext, +) -> tuple[tuple[str, FrontendExprNode], ...]: + if not node.keywords: + return () + + for keyword in node.keywords: + if keyword.arg is None: + raise context.error( + keyword.value, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + + allowed_keywords = _DMA_CALL_KEYWORDS.get(name) if namespace == "pto" else None + if allowed_keywords is None: + call_name = f"{namespace + '.' if namespace else ''}{name}" + raise context.error( + node, + f"`{call_name}` does not support keyword arguments in TileLang DSL v1; " + "no public call surface currently accepts them", + ) + + seen: set[str] = set() + built_keywords: list[tuple[str, FrontendExprNode]] = [] + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen: + raise context.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + if keyword.arg not in allowed_keywords: + raise context.error( + keyword.value, + f"unsupported keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + seen.add(keyword.arg) + built_keywords.append((keyword.arg, _build_expr(keyword.value, context))) + return tuple(built_keywords) + + +def _build_expr(node: ast.AST, context: _FrontendBuildContext) -> FrontendExprNode: + if isinstance(node, ast.Name): + if ( + node.id in context.global_literal_constants + and node.id not in context.local_bindings + ): + return _attach_source_location( + FrontendConstantExpr(value=context.global_literal_constants[node.id]), + node, + context, + ) + return _attach_source_location(FrontendNameExpr(name=node.id), node, context) + if isinstance(node, ast.Constant): + return _attach_source_location(FrontendConstantExpr(value=node.value), node, context) + if isinstance(node, ast.UnaryOp): + if isinstance(node.op, ast.UAdd): + sign = 1 + elif isinstance(node.op, ast.USub): + sign = -1 + else: + raise context.error( + node, + f"unsupported unary operator `{type(node.op).__name__}` in TileLang DSL v1", + ) + if not isinstance(node.operand, ast.Constant) or isinstance(node.operand.value, bool): + raise context.error( + node, + "unary +/- currently only supports numeric literals in TileLang DSL v1", + ) + literal = node.operand.value + if not isinstance(literal, (int, float)): + raise context.error( + node, + "unary +/- currently only supports numeric literals in TileLang DSL v1", + ) + return _attach_source_location( + FrontendConstantExpr(value=literal if sign > 0 else -literal), + node, + context, + ) + if isinstance(node, ast.Slice): + start = None if node.lower is None else _build_expr(node.lower, context) + stop = None if node.upper is None else _build_expr(node.upper, context) + step = None if node.step is None else _build_expr(node.step, context) + return _attach_source_location( + FrontendSliceExpr(start=start, stop=stop, step=step), + node, + context, + ) + if isinstance(node, (ast.Tuple, ast.List)): + return _attach_source_location( + FrontendTupleExpr( + elements=tuple(_build_expr(elt, context) for elt in node.elts) + ), + node, + context, + ) + if isinstance(node, ast.Attribute): + path = _attribute_path(node) + if path is not None and path[0] in { + "pto", + "PAT", + "PIPE", + "EVENT", + "MaskPattern", + "PredicateDist", + "VLoadDist", + "VStoreDist", + "PredicatePart", + "CmpMode", + "Pipe", + "Event", + "BarrierType", + "MemorySpace", + "PadMode", + "DeinterleaveDist", + "InterleaveDist", + "PositionMode", + "OrderMode", + "VcvtRoundMode", + "VcvtSatMode", + "VcvtPartMode", + "PostUpdateMode", + "FractalMode", + } and len(path) >= 2: + return _attach_source_location( + FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]), + node, + context, + ) + return _attach_source_location( + FrontendAttributeExpr(base=_build_expr(node.value, context), attr=node.attr), + node, + context, + ) + if isinstance(node, ast.Subscript): + return _attach_source_location( + FrontendSubscriptExpr( + base=_build_expr(node.value, context), + index=_build_expr(node.slice, context), + ), + node, + context, + ) + if isinstance(node, ast.BinOp): + op_name = _BINARY_OP_NAMES.get(type(node.op)) + if op_name is None: + raise context.error( + node, + f"unsupported binary operator `{type(node.op).__name__}` in TileLang DSL v1", + ) + return _attach_source_location( + FrontendBinaryExpr( + lhs=_build_expr(node.left, context), + op=op_name, + rhs=_build_expr(node.right, context), + ), + node, + context, + ) + if isinstance(node, ast.Compare): + if len(node.ops) != 1 or len(node.comparators) != 1: + raise context.error( + node, + "chained comparisons are not supported in TileLang DSL v1", + ) + op_name = _COMPARE_OP_NAMES.get(type(node.ops[0])) + if op_name is None: + raise context.error( + node, + f"unsupported comparison operator `{type(node.ops[0]).__name__}` in TileLang DSL v1", + ) + return _attach_source_location( + FrontendBinaryExpr( + lhs=_build_expr(node.left, context), + op=op_name, + rhs=_build_expr(node.comparators[0], context), + ), + node, + context, + ) + if isinstance(node, ast.BoolOp): + op_name = _BOOL_OP_NAMES.get(type(node.op)) + if op_name is None: + raise context.error( + node, + f"unsupported boolean operator `{type(node.op).__name__}` in TileLang DSL v1", + ) + if len(node.values) < 2: + raise context.error( + node, + "boolean expressions must contain at least two operands in TileLang DSL v1", + ) + expr = _build_expr(node.values[0], context) + for value in node.values[1:]: + expr = FrontendBinaryExpr( + lhs=expr, + op=op_name, + rhs=_build_expr(value, context), + ) + return _attach_source_location(expr, node, context) + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id in context.inline_procs: + inline_proc = context.inline_procs[node.func.id] + if node.func.id in context.active_inline_proc_stack: + raise context.error( + node, + f"recursive inline_proc call `{node.func.id}` is not supported in TileLang DSL v1", + ) + return _attach_source_location( + FrontendCallExpr( + namespace=None, + name=node.func.id, + args=_bind_inline_proc_call(node, inline_proc, context), + keywords=(), + ), + node, + context, + ) + if ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "pto" + and node.func.attr == "tpl" + ): + if not node.args: + raise context.error( + node, + "pto.tpl() requires a non-empty string literal slot name as the first argument", + ) + slot_expr = node.args[0] + if not ( + isinstance(slot_expr, ast.Constant) + and isinstance(slot_expr.value, str) + and slot_expr.value + ): + raise context.error( + slot_expr, + "pto.tpl() requires a non-empty string literal slot name", + ) + slot_name = slot_expr.value + slot_bindings = context.templates.get(slot_name) + if slot_bindings is None: + raise context.error( + slot_expr, + f"unknown template slot {slot_name!r} in TileLang DSL v1", + ) + if context.selected_op is None: + raise context.error( + node, + "pto.tpl() requires pto.select_kernel(...) to bind a concrete op before expansion", + ) + resolved_op = slot_bindings.get(context.selected_op) + if resolved_op is None: + raise context.error( + slot_expr, + f"template slot {slot_name!r} does not define an implementation for " + f"selected op {context.selected_op!r}", + ) + _validate_resolved_template_op_surface(resolved_op, node, context) + return _attach_source_location( + FrontendCallExpr( + namespace="pto", + name=resolved_op, + args=tuple(_build_expr(arg, context) for arg in node.args[1:]), + keywords=_build_call_keywords( + node, + namespace="pto", + name=resolved_op, + context=context, + ), + ), + node, + context, + ) + if isinstance(node.func, ast.Name): + return _attach_source_location( + FrontendCallExpr( + namespace=None, + name=node.func.id, + args=tuple(_build_expr(arg, context) for arg in node.args), + keywords=_build_call_keywords( + node, + namespace=None, + name=node.func.id, + context=context, + ), + ), + node, + context, + ) + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.value.id == "pto" and context.kernel_family == "cube": + if node.func.attr in SUPPORTED_VECSCOPE_PTO_CALLS: + _reject_cube_vector_surface(context, node, f"pto.{node.func.attr}") + if node.func.attr in ADVANCED_VECSCOPE_PTO_CALLS: + _reject_cube_vector_surface(context, node, f"pto.{node.func.attr}") + if node.func.attr in ADVANCED_TOPLEVEL_PTO_CALLS: + _reject_cube_vector_surface(context, node, f"pto.{node.func.attr}") + if node.func.attr in DEFERRED_PTO_SURFACES: + raise context.error(node, deferred_surface_message(node.func.attr)) + if node.func.value.id == "pto" and context.kernel_family != "cube": + if node.func.attr in CUBE_ONLY_PTO_CALLS: + _reject_vector_cube_surface(context, node, f"pto.{node.func.attr}") + return _attach_source_location( + FrontendCallExpr( + namespace=node.func.value.id, + name=node.func.attr, + args=tuple(_build_expr(arg, context) for arg in node.args), + keywords=_build_call_keywords( + node, + namespace=node.func.value.id, + name=node.func.attr, + context=context, + ), + ), + node, + context, + ) + if isinstance(node.func, ast.Attribute): + return _attach_source_location( + FrontendCallExpr( + namespace=None, + name=node.func.attr, + args=( + _build_expr(node.func.value, context), + *(tuple(_build_expr(arg, context) for arg in node.args)), + ), + keywords=_build_call_keywords( + node, + namespace=None, + name=node.func.attr, + context=context, + ), + ), + node, + context, + ) + raise context.error( + node, + f"unsupported expression `{type(node).__name__}` in TileLang DSL v1", + ) + + +def _build_target(node: ast.AST, context: _FrontendBuildContext) -> FrontendTargetNode: + if isinstance(node, ast.Name): + return FrontendNameTarget(name=node.id) + if isinstance(node, ast.Tuple): + elements = [] + for elt in node.elts: + if not isinstance(elt, ast.Name): + raise context.error(elt, "tuple assignment only supports names in TileLang DSL v1") + elements.append(FrontendNameTarget(name=elt.id)) + return FrontendTupleTarget(elements=tuple(elements)) + raise context.error( + node, + f"unsupported assignment target `{type(node).__name__}` in TileLang DSL v1", + ) + + +def _build_stmt_list(nodes: list[ast.stmt] | tuple[ast.stmt, ...], context: _FrontendBuildContext) -> tuple[FrontendStmtNode, ...]: + return tuple(_build_stmt(node, context) for node in nodes) + + +def _build_stmt(node: ast.stmt, context: _FrontendBuildContext) -> FrontendStmtNode: + if isinstance(node, ast.Pass): + return _attach_source_location(FrontendNoOpStmt(), node, context) + if isinstance(node, ast.Assign): + if len(node.targets) != 1: + raise context.error(node, "multiple assignment targets are not supported in TileLang DSL v1") + return _attach_source_location( + FrontendAssignStmt( + target=_build_target(node.targets[0], context), + value=_build_expr(node.value, context), + ), + node, + context, + ) + if isinstance(node, ast.AnnAssign): + if node.value is None: + raise context.error(node, "annotation-only assignments are not supported in TileLang DSL v1") + return _attach_source_location( + FrontendAssignStmt( + target=_build_target(node.target, context), + value=_build_expr(node.value, context), + annotation=node.annotation, + ), + node, + context, + ) + if isinstance(node, ast.Expr): + return _attach_source_location( + FrontendExprStmt(expr=_build_expr(node.value, context)), + node, + context, + ) + if isinstance(node, ast.Return): + value = None + if node.value is not None: + if not (isinstance(node.value, ast.Constant) and node.value.value is None): + value = _build_expr(node.value, context) + return _attach_source_location(FrontendReturnStmt(value=value), node, context) + if isinstance(node, ast.For): + if not isinstance(node.target, ast.Name): + raise context.error(node.target, "for target must be a single name") + if not isinstance(node.iter, ast.Call) or not isinstance(node.iter.func, ast.Name) or node.iter.func.id != "range": + raise context.error(node.iter, "only Python range(lb, ub, step) loops are supported") + if len(node.iter.args) != 3: + raise context.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") + return _attach_source_location( + FrontendForStmt( + target=node.target.id, + lower_bound=_build_expr(node.iter.args[0], context), + upper_bound=_build_expr(node.iter.args[1], context), + step=_build_expr(node.iter.args[2], context), + body=_build_stmt_list(node.body, context), + ), + node, + context, + ) + if isinstance(node, ast.If): + is_constexpr = False + condition_node: ast.AST = node.test + if ( + isinstance(node.test, ast.Call) + and isinstance(node.test.func, ast.Attribute) + and isinstance(node.test.func.value, ast.Name) + and node.test.func.value.id == "pto" + and node.test.func.attr == "constexpr" + ): + if node.test.keywords: + raise context.error( + node.test, + "pto.constexpr() does not support keyword arguments in TileLang DSL v1", + ) + if len(node.test.args) != 1: + raise context.error( + node.test, + "pto.constexpr() expects exactly 1 positional argument in TileLang DSL v1", + ) + is_constexpr = True + condition_node = node.test.args[0] + return _attach_source_location( + FrontendIfStmt( + condition=_build_expr(condition_node, context), + then_body=_build_stmt_list(node.body, context), + else_body=_build_stmt_list(node.orelse, context), + is_constexpr=is_constexpr, + ), + node, + context, + ) + if isinstance(node, ast.With): + if len(node.items) != 1: + raise context.error(node, "only a single with-item is supported in TileLang DSL v1") + item = node.items[0] + if not isinstance(item.context_expr, ast.Call): + raise context.error(item.context_expr, "with context must be a call in TileLang DSL v1") + if not ( + isinstance(item.context_expr.func, ast.Attribute) + and isinstance(item.context_expr.func.value, ast.Name) + and item.context_expr.func.value.id == "pto" + ): + raise context.error( + item.context_expr, + "only pto.vecscope/pto.strict_vecscope are supported in TileLang DSL v1", + ) + with_name = item.context_expr.func.attr + if with_name == "vecscope": + if item.context_expr.args or item.context_expr.keywords: + raise context.error( + item.context_expr, + "pto.vecscope() does not accept positional or keyword arguments in TileLang DSL v1", + ) + if item.optional_vars is not None: + raise context.error(item, "pto.vecscope() does not support `as` bindings in TileLang DSL v1") + return _attach_source_location( + FrontendVecscopeStmt( + body=_build_stmt_list(node.body, context.nested_vecscope()), + ), + node, + context, + ) + if with_name != "strict_vecscope": + raise context.error( + item.context_expr, + "only pto.vecscope/pto.strict_vecscope are supported in TileLang DSL v1", + ) + if not context.advanced_enabled: + raise context.error( + item.context_expr, + advanced_mode_message("strict_vecscope"), + ) + if not isinstance(item.optional_vars, ast.Tuple): + raise context.error(item, "pto.strict_vecscope requires tuple binding in 'as'") + block_arguments = [] + for elt in item.optional_vars.elts: + if not isinstance(elt, ast.Name): + raise context.error(elt, "pto.strict_vecscope bindings must be names") + block_arguments.append(elt.id) + return _attach_source_location( + FrontendStrictVecscopeStmt( + captures=tuple(_build_expr(arg, context) for arg in item.context_expr.args), + block_arguments=tuple(block_arguments), + body=_build_stmt_list(node.body, context.nested_vecscope()), + ), + node, + context, + ) + raise context.error( + node, + f"unsupported statement `{type(node).__name__}` in TileLang DSL v1", + ) + + +def build_frontend_kernel_node(descriptor: Any) -> FrontendKernelNode: + """Project the core-foundation descriptor into a lowering-owned AST.""" + + if getattr(descriptor, "_parameters", None) is not None: + parameters = tuple( + FrontendParameterNode( + name=param.name, + kind=param.kind, + annotation=param.annotation, + dtype=param.dtype, + ) + for param in descriptor.parameters + ) + else: + parameters = tuple( + FrontendParameterNode( + name=param.name, + kind=param.kind, + annotation=param.annotation, + dtype=None, + ) + for param in descriptor._parameter_specs + ) + tile_specializations = tuple( + FrontendTileSpecializationNode( + name=name, + shape=spec.shape, + memory_space=spec.memory_space.value, + config=spec.config, + valid_shape=spec.valid_shape, + ) + for name, spec in descriptor.specializations + ) + source_info = descriptor._source_info + local_bindings = _collect_source_local_bindings(source_info) + global_literal_constants = _collect_module_literal_constants( + source_info, + module_globals=getattr(descriptor._py_fn, "__globals__", None), + local_bindings=local_bindings, + ) + sorted_inline_procs = tuple(sorted(descriptor.inline_procs.items(), key=lambda item: item[0])) + sorted_internal_inline_procs = tuple( + sorted(descriptor.internal_inline_procs.items(), key=lambda item: item[0]) + ) + context = _FrontendBuildContext( + source_info=source_info, + module_globals=getattr(descriptor._py_fn, "__globals__", None), + templates=descriptor.templates, + selected_op=descriptor.selected_op, + advanced_enabled=descriptor.advanced_enabled, + inline_procs={ + name: _FrontendInlineProc( + name=name, + source_info=proc.source_info, + signature=proc.signature, + ) + for name, proc in sorted_inline_procs + }, + kernel_family=getattr(descriptor, "kernel_family", "vector"), + global_literal_constants=global_literal_constants, + local_bindings=local_bindings, + ) + body = () + if source_info is not None: + body = _build_stmt_list(source_info.function_def.body, context) + + inline_proc_descriptors = {name: descriptor for name, descriptor in sorted_inline_procs} + inline_proc_names = set(inline_proc_descriptors) + root_inline_calls: set[str] = set() + for stmt in body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, root_inline_calls) + + inline_proc_nodes_by_name: dict[str, FrontendInlineProcNode] = {} + inline_proc_source_infos: dict[str, Any] = {} + pending = list(sorted(root_inline_calls)) + while pending: + name = pending.pop() + if name in inline_proc_nodes_by_name: + continue + inline_proc_descriptor = inline_proc_descriptors.get(name) + if inline_proc_descriptor is None: + continue + inline_source = inline_proc_descriptor.source_info + if inline_source is None: + if source_info is not None: + raise context.error( + source_info.function_def, + f"inline_proc `{name}` requires source-visible Python functions", + ) + raise ValueError( + f"inline_proc `{name}` requires source-visible Python functions" + ) + inline_proc_source_infos[name] = inline_source + helper_context = context.enter_inline_proc(name, inline_source) + helper_body = _build_stmt_list(inline_source.function_def.body, helper_context) + parameter_specs = _inline_proc_param_specs( + _FrontendInlineProc( + name=name, + source_info=inline_source, + signature=inline_proc_descriptor.signature, + ) + ) + inline_proc_node = FrontendInlineProcNode( + name=name, + parameters=tuple( + FrontendInlineProcParameterNode( + name=param_name, + annotation=arg.annotation, + default=None + if default_node is None + else _build_expr(default_node, helper_context), + ) + for (param_name, default_node), arg in zip(parameter_specs, inline_source.function_def.args.args) + ), + body=helper_body, + ) + inline_proc_nodes_by_name[name] = inline_proc_node + nested_calls: set[str] = set() + for stmt in helper_body: + _collect_inline_proc_calls_stmt(stmt, inline_proc_names, nested_calls) + for nested in sorted(nested_calls): + if nested not in inline_proc_nodes_by_name: + pending.append(nested) + + reachable_inline_proc_nodes = tuple( + inline_proc_nodes_by_name[name] + for name, _ in sorted_inline_procs + if name in inline_proc_nodes_by_name + ) + for inline_proc_node in reachable_inline_proc_nodes: + source = inline_proc_source_infos[inline_proc_node.name] + helper_context = context.enter_inline_proc(inline_proc_node.name, source) + assigned_names: set[str] = set() + param_names = {parameter.name for parameter in inline_proc_node.parameters} + for stmt in inline_proc_node.body: + _validate_inline_capture( + stmt, + param_names, + assigned_names, + context=helper_context, + ) + + _validate_inline_proc_call_graph( + body, + reachable_inline_proc_nodes, + inline_proc_source_infos, + ) + + internal_inline_proc_nodes: tuple[FrontendInlineProcNode, ...] = () + if sorted_internal_inline_procs: + merged_inline_proc_descriptors = { + name: _FrontendInlineProc( + name=name, + source_info=proc.source_info, + signature=proc.signature, + ) + for name, proc in (*sorted_inline_procs, *sorted_internal_inline_procs) + } + internal_context = _FrontendBuildContext( + source_info=source_info, + module_globals=getattr(descriptor._py_fn, "__globals__", None), + templates=descriptor.templates, + selected_op=descriptor.selected_op, + advanced_enabled=descriptor.advanced_enabled, + kernel_family="vector", + inline_procs=merged_inline_proc_descriptors, + global_literal_constants=global_literal_constants, + local_bindings=local_bindings, + ) + internal_nodes: list[FrontendInlineProcNode] = [] + internal_source_infos: dict[str, Any] = {} + for name, inline_proc_descriptor in sorted_internal_inline_procs: + inline_source = inline_proc_descriptor.source_info + if inline_source is None: + if source_info is not None: + raise context.error( + source_info.function_def, + f"inline_proc `{name}` requires source-visible Python functions", + ) + raise ValueError( + f"inline_proc `{name}` requires source-visible Python functions" + ) + internal_source_infos[name] = inline_source + helper_context = internal_context.enter_inline_proc(name, inline_source) + helper_body = _build_stmt_list(inline_source.function_def.body, helper_context) + parameter_specs = _inline_proc_param_specs( + _FrontendInlineProc( + name=name, + source_info=inline_source, + signature=inline_proc_descriptor.signature, + ) + ) + inline_proc_node = FrontendInlineProcNode( + name=name, + parameters=tuple( + FrontendInlineProcParameterNode( + name=param_name, + annotation=arg.annotation, + default=None + if default_node is None + else _build_expr(default_node, helper_context), + ) + for (param_name, default_node), arg in zip( + parameter_specs, + inline_source.function_def.args.args, + ) + ), + body=helper_body, + ) + internal_nodes.append(inline_proc_node) + + internal_inline_proc_nodes = tuple(internal_nodes) + for inline_proc_node in internal_inline_proc_nodes: + source = internal_source_infos[inline_proc_node.name] + helper_context = internal_context.enter_inline_proc(inline_proc_node.name, source) + assigned_names: set[str] = set() + param_names = {parameter.name for parameter in inline_proc_node.parameters} + for stmt in inline_proc_node.body: + _validate_inline_capture( + stmt, + param_names, + assigned_names, + context=helper_context, + ) + + dtype_signature = descriptor._selected_dtype_signature + if dtype_signature is None and getattr(descriptor, "kernel_family", "vector") != "cube": + dtype_signature = descriptor.dtype_signature + + return FrontendKernelNode( + target=descriptor.target, + op=descriptor.op, + name=descriptor.name, + kernel_family=getattr(descriptor, "kernel_family", "vector"), + verify_enabled=descriptor.verify_enabled, + advanced_enabled=descriptor.advanced_enabled, + dtype_signature=dtype_signature, + parameters=parameters, + tile_specializations=tile_specializations, + body=body, + context_attrs=tuple( + sorted(descriptor.constraint_context_attrs.items(), key=lambda item: item[0]) + ), + inline_procs=reachable_inline_proc_nodes, + internal_inline_procs=internal_inline_proc_nodes, + ) + + +__all__ = [ + "FrontendAssignStmt", + "FrontendAttributeExpr", + "FrontendBinaryExpr", + "FrontendCallExpr", + "FrontendConstantExpr", + "FrontendExprNode", + "FrontendExprStmt", + "FrontendForStmt", + "FrontendIfStmt", + "FrontendInlineProcNode", + "FrontendInlineProcParameterNode", + "FrontendKernelNode", + "FrontendNameExpr", + "FrontendNameTarget", + "FrontendNoOpStmt", + "FrontendParameterNode", + "FrontendReturnStmt", + "FrontendSliceExpr", + "FrontendVecscopeStmt", + "FrontendStrictVecscopeStmt", + "FrontendStmtNode", + "FrontendSubscriptExpr", + "FrontendSymbolExpr", + "FrontendTargetNode", + "FrontendTileSpecializationNode", + "FrontendTupleExpr", + "FrontendTupleTarget", + "build_frontend_kernel_node", +] diff --git a/tilelang-dsl/python/tilelang_dsl/kernel.py b/tilelang-dsl/python/tilelang_dsl/kernel.py new file mode 100644 index 000000000..06ab10c0a --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/kernel.py @@ -0,0 +1,2894 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Kernel descriptor surface for TileLang DSL v1.""" + +from __future__ import annotations + +import os +import inspect +import ast +import importlib.util +import sys +import textwrap +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Mapping + +from .types import ( + AnyMask, + AnyType, + MaskType, + MemorySpace, + PartitionTensorView, + PointerType, + ScalarType, + TensorView, + Tile, + TileConfig, + TileSpecialization, + TypeVariable, + VectorType, + WildcardType, + is_integer_dtype, +) +from .frontend_ast import _DMA_CALL_KEYWORDS, build_frontend_kernel_node +from .lowering import lower_semantic_kernel +from .semantic import analyze_frontend_kernel +from .support_matrix import ( + ADVANCED_EXPR_PTO_CALLS, + ADVANCED_TOPLEVEL_PTO_CALLS, + ADVANCED_VECSCOPE_PTO_CALLS, + CUBE_ONLY_PTO_CALLS, + DEFERRED_PTO_SURFACES, + SUPPORTED_TOPLEVEL_PTO_CALLS, + SUPPORTED_VECSCOPE_PTO_CALLS, + advanced_mode_message, + deferred_surface_message, +) + + +_UNSET = object() +_INTERNAL_SOFT_MATH_MODULE_NAME = "tilelang_dsl._internal_soft_math" +_SUPPORTED_TEMPLATE_PTO_CALLS = frozenset( + SUPPORTED_TOPLEVEL_PTO_CALLS + | SUPPORTED_VECSCOPE_PTO_CALLS + | ADVANCED_VECSCOPE_PTO_CALLS + | ADVANCED_EXPR_PTO_CALLS + | ADVANCED_TOPLEVEL_PTO_CALLS + | CUBE_ONLY_PTO_CALLS +) + +_DSL_DTYPE_NAMES = frozenset( + { + "i1", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + "f16", + "bf16", + "f32", + } +) + + +_INLINE_PROC_REGISTRY: dict[tuple[str, str], "InlineProcDescriptor"] = {} +_INTERNAL_INLINE_PROC_CACHE: tuple[tuple[str, "InlineProcDescriptor"], ...] | None = None + + +@dataclass(frozen=True) +class InlineProcDescriptor: + """Descriptor returned by @tilelang_dsl.inline_proc.""" + + name: str + py_fn: Callable[..., Any] = field(repr=False) + signature: inspect.Signature = field(repr=False) + source_info: "_FunctionSourceInfo | None" = field(repr=False, default=None) + + +class _InlineProcValidator(ast.NodeVisitor): + def __init__(self, source_info: "_FunctionSourceInfo"): + self.source_info = source_info + + def validate(self) -> None: + fn = self.source_info.function_def + args = fn.args + if args.posonlyargs: + raise self.source_info.error(args.posonlyargs[0], "inline_proc does not support positional-only parameters in TileLang DSL v1") + if args.vararg is not None: + raise self.source_info.error(args.vararg, "inline_proc does not support *args in TileLang DSL v1") + if args.kwarg is not None: + raise self.source_info.error(args.kwarg, "inline_proc does not support **kwargs in TileLang DSL v1") + if args.kwonlyargs: + raise self.source_info.error(args.kwonlyargs[0], "inline_proc does not support keyword-only parameters in TileLang DSL v1") + tail_return: ast.Return | None = fn.body[-1] if fn.body and isinstance(fn.body[-1], ast.Return) else None + for node in ast.walk(fn): + if not isinstance(node, ast.Return): + continue + if node is tail_return: + continue + raise self.source_info.error( + node, + "inline_proc only supports an optional trailing `return` in TileLang DSL v1", + ) + + for stmt in fn.body: + self.visit(stmt) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if node is self.source_info.function_def: + for stmt in node.body: + self.visit(stmt) + return + raise self.source_info.error(node, "nested function definitions are not supported inside inline_proc in TileLang DSL v1") + + +def _inline_proc_registry_key(fn: Callable[..., Any]) -> tuple[str, str]: + return (fn.__module__, fn.__name__) + + +def _find_inline_proc(name: str, *, module_name: str | None) -> InlineProcDescriptor | None: + if module_name is None: + return None + descriptor = _INLINE_PROC_REGISTRY.get((module_name, name)) + if descriptor is not None: + return descriptor + module = sys.modules.get(module_name) + if module is None: + return None + value = getattr(module, name, None) + if isinstance(value, InlineProcDescriptor): + return value + return None + + +def _validate_inline_proc_call_surface( + source_info: _FunctionSourceInfo, + node: ast.Call, + inline_proc: InlineProcDescriptor, +) -> None: + if any(keyword.arg is None for keyword in node.keywords): + keyword = next(keyword for keyword in node.keywords if keyword.arg is None) + raise source_info.error( + keyword.value, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + seen_keywords: set[str] = set() + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen_keywords: + raise source_info.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for inline_proc `{inline_proc.name}` in TileLang DSL v1", + ) + seen_keywords.add(keyword.arg) + positional_placeholders = [object() for _ in node.args] + keyword_placeholders = {keyword.arg: object() for keyword in node.keywords if keyword.arg is not None} + try: + inline_proc.signature.bind(*positional_placeholders, **keyword_placeholders) + except TypeError as exc: + raise source_info.error( + node, + f"invalid inline_proc call `{inline_proc.name}` in TileLang DSL v1: {exc}", + ) from exc + + +def _same_inline_proc_descriptor( + lhs: InlineProcDescriptor, + rhs: InlineProcDescriptor, +) -> bool: + return lhs is rhs or lhs.py_fn is rhs.py_fn + + +def _format_inline_proc_origin(descriptor: InlineProcDescriptor) -> str: + return f"{descriptor.py_fn.__module__}.{descriptor.py_fn.__name__}" + + +def _add_collected_inline_proc( + collected: dict[str, InlineProcDescriptor], + symbol: str, + descriptor: InlineProcDescriptor, +) -> None: + existing = collected.get(symbol) + if existing is None: + collected[symbol] = descriptor + return + if _same_inline_proc_descriptor(existing, descriptor): + return + raise ValueError( + "ambiguous inline_proc name " + f"`{symbol}` in TileLang DSL module: " + f"{_format_inline_proc_origin(existing)} conflicts with " + f"{_format_inline_proc_origin(descriptor)}" + ) + + +def _collect_inline_procs(module_name: str) -> tuple[tuple[str, InlineProcDescriptor], ...]: + collected: dict[str, InlineProcDescriptor] = {} + for (registered_module, symbol), descriptor in _INLINE_PROC_REGISTRY.items(): + if registered_module == module_name: + _add_collected_inline_proc(collected, symbol, descriptor) + + module = sys.modules.get(module_name) + if module is not None: + for symbol, value in vars(module).items(): + if not isinstance(value, InlineProcDescriptor): + continue + _add_collected_inline_proc(collected, symbol, value) + origin_module = value.py_fn.__module__ + for (registered_module, helper_name), helper in _INLINE_PROC_REGISTRY.items(): + if registered_module == origin_module: + _add_collected_inline_proc(collected, helper_name, helper) + + return tuple(sorted(collected.items(), key=lambda item: item[0])) + + +def _load_module_from_path(module_name: str, path: Path) -> Any: + module = sys.modules.get(module_name) + if module is not None: + return module + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise ImportError(f"unable to load module {module_name!r} from {path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _find_internal_soft_math_path() -> Path | None: + module_path = Path(__file__).resolve() + candidate_suffixes = ( + ("lib", "TileOps", "math.py"), + ("share", "ptoas", "TileOps", "math.py"), + ) + for root in (module_path.parent, *module_path.parents): + for suffix in candidate_suffixes: + candidate = root.joinpath(*suffix) + if candidate.exists(): + return candidate + return None + + +def _collect_internal_inline_procs() -> tuple[tuple[str, InlineProcDescriptor], ...]: + global _INTERNAL_INLINE_PROC_CACHE + if _INTERNAL_INLINE_PROC_CACHE is not None: + return _INTERNAL_INLINE_PROC_CACHE + + soft_math_path = _find_internal_soft_math_path() + if soft_math_path is None: + _INTERNAL_INLINE_PROC_CACHE = () + return _INTERNAL_INLINE_PROC_CACHE + + try: + module = _load_module_from_path(_INTERNAL_SOFT_MATH_MODULE_NAME, soft_math_path) + except Exception: + _INTERNAL_INLINE_PROC_CACHE = () + return _INTERNAL_INLINE_PROC_CACHE + + collected: dict[str, InlineProcDescriptor] = {} + for symbol, value in vars(module).items(): + if isinstance(value, InlineProcDescriptor): + collected.setdefault(symbol, value) + + _INTERNAL_INLINE_PROC_CACHE = tuple(sorted(collected.items(), key=lambda item: item[0])) + return _INTERNAL_INLINE_PROC_CACHE + + +def _register_inline_proc(descriptor: InlineProcDescriptor) -> InlineProcDescriptor: + _INLINE_PROC_REGISTRY[_inline_proc_registry_key(descriptor.py_fn)] = descriptor + return descriptor + + +def inline_proc( + py_fn: Callable[..., Any] | None = None, +) -> InlineProcDescriptor | Callable[[Callable[..., Any]], InlineProcDescriptor]: + """Register a top-level compile-time inline procedure for TileLang DSL kernels.""" + + def wrap(fn: Callable[..., Any]) -> InlineProcDescriptor: + if not callable(fn): + raise TypeError("@inline_proc can only decorate callables") + source_info = _load_function_source_info(fn) + if source_info is None: + raise TypeError("@inline_proc requires source-visible Python functions") + _InlineProcValidator(source_info).validate() + return _register_inline_proc( + InlineProcDescriptor( + name=fn.__name__, + py_fn=fn, + source_info=source_info, + signature=inspect.signature(fn), + ) + ) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +def _validate_dtype_pattern(dtype: Any) -> ScalarType | MaskType | WildcardType | TypeVariable: + if isinstance(dtype, (ScalarType, MaskType, WildcardType, TypeVariable)): + return dtype + raise TypeError(f"unsupported dtype pattern {dtype!r}") + + +class TileLangFrontendError(ValueError): + """Source-located frontend diagnostic for TileLang DSL.""" + + def __init__(self, path: str, line: int, column: int, message: str): + self.path = path + self.line = line + self.column = column + self.message = message + super().__init__(f"{path}:{line}:{column}: {message}") + + +@dataclass(frozen=True) +class _FunctionSourceInfo: + path: str + start_line: int + function_def: ast.FunctionDef + + def location(self, node: ast.AST) -> tuple[int, int]: + line = self.start_line + getattr(node, "lineno", 1) - 1 + column = getattr(node, "col_offset", 0) + 1 + return line, column + + def error(self, node: ast.AST, message: str) -> TileLangFrontendError: + line, column = self.location(node) + return TileLangFrontendError(self.path, line, column, message) + + def parameter_node(self, param_name: str) -> ast.AST | None: + for arg in self.function_def.args.args: + if arg.arg == param_name: + return arg.annotation or arg + return None + + +class _KernelBodyValidator(ast.NodeVisitor): + def __init__( + self, + source_info: _FunctionSourceInfo, + *, + advanced_enabled: bool, + module_name: str | None, + kernel_family: str, + ): + self.source_info = source_info + self.advanced_enabled = advanced_enabled + self.module_name = module_name + self.kernel_family = kernel_family + self._vecscope_depth = 0 + self._static_dtype_bindings: set[str] = set() + + def validate(self) -> None: + for stmt in self.source_info.function_def.body: + self.visit(stmt) + + def visit_While(self, node: ast.While) -> None: + raise self.source_info.error(node, "unsupported Python syntax `while` in TileLang DSL v1") + + def visit_ListComp(self, node: ast.ListComp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `list comprehension` in TileLang DSL v1" + ) + + def visit_DictComp(self, node: ast.DictComp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `dict comprehension` in TileLang DSL v1" + ) + + def visit_SetComp(self, node: ast.SetComp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `set comprehension` in TileLang DSL v1" + ) + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None: + raise self.source_info.error( + node, "unsupported Python syntax `generator expression` in TileLang DSL v1" + ) + + def visit_For(self, node: ast.For) -> None: + if not isinstance(node.target, ast.Name): + raise self.source_info.error(node.target, "for target must be a single name") + if not isinstance(node.iter, ast.Call) or not isinstance(node.iter.func, ast.Name): + raise self.source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") + if node.iter.func.id != "range": + raise self.source_info.error(node.iter, "only Python range(lb, ub, step) loops are supported") + if node.iter.keywords: + raise self.source_info.error( + node.iter, + "range() does not support keyword arguments in TileLang DSL v1", + ) + if len(node.iter.args) != 3: + raise self.source_info.error(node.iter, "range() expects exactly 3 arguments in TileLang DSL v1") + for stmt in node.body: + self.visit(stmt) + for stmt in node.orelse: + self.visit(stmt) + + def visit_If(self, node: ast.If) -> None: + for stmt in node.body: + self.visit(stmt) + for stmt in node.orelse: + self.visit(stmt) + + def visit_Assign(self, node: ast.Assign) -> None: + self.visit(node.value) + is_static_dtype = self._expr_is_static_dtype_expr(node.value) + for target in node.targets: + self._update_static_dtype_bindings(target, is_static_dtype=is_static_dtype) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + if node.value is not None: + self.visit(node.value) + is_static_dtype = node.value is not None and self._expr_is_static_dtype_expr(node.value) + self._update_static_dtype_bindings(node.target, is_static_dtype=is_static_dtype) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + self.visit(node.value) + self._update_static_dtype_bindings(node.target, is_static_dtype=False) + + def visit_With(self, node: ast.With) -> None: + if len(node.items) != 1: + raise self.source_info.error(node, "only single with item is supported in TileLang DSL v1") + item = node.items[0] + if not isinstance(item.context_expr, ast.Call): + raise self.source_info.error(item.context_expr, "with context must be a call in TileLang DSL v1") + if not ( + isinstance(item.context_expr.func, ast.Attribute) + and isinstance(item.context_expr.func.value, ast.Name) + and item.context_expr.func.value.id == "pto" + ): + raise self.source_info.error( + item.context_expr, + "only pto.vecscope/pto.strict_vecscope are supported as with-contexts in TileLang DSL v1", + ) + with_name = item.context_expr.func.attr + if self.kernel_family == "cube" and with_name in {"vecscope", "strict_vecscope"}: + raise self.source_info.error( + item.context_expr, + "@pto.ckernel does not support pto.vecscope()/pto.strict_vecscope(); " + "cube kernels must use linear control flow without vecscope", + ) + if with_name == "vecscope": + if item.context_expr.args or item.context_expr.keywords: + raise self.source_info.error( + item.context_expr, + "pto.vecscope() does not accept positional or keyword arguments in TileLang DSL v1", + ) + if item.optional_vars is not None: + raise self.source_info.error( + item, + "pto.vecscope() does not support `as` bindings in TileLang DSL v1", + ) + elif with_name == "strict_vecscope": + if not self.advanced_enabled: + raise self.source_info.error( + item.context_expr, + advanced_mode_message("strict_vecscope"), + ) + if not isinstance(item.optional_vars, ast.Tuple): + raise self.source_info.error(item, "pto.strict_vecscope requires tuple binding in 'as'") + for elt in item.optional_vars.elts: + if not isinstance(elt, ast.Name): + raise self.source_info.error(elt, "pto.strict_vecscope bindings must be names") + else: + raise self.source_info.error( + item.context_expr, + "only pto.vecscope/pto.strict_vecscope are supported as with-contexts in TileLang DSL v1", + ) + self._vecscope_depth += 1 + try: + for stmt in node.body: + self.visit(stmt) + finally: + self._vecscope_depth -= 1 + + def _validate_no_keyword_unpacking(self, node: ast.Call) -> None: + for keyword in node.keywords: + if keyword.arg is None: + raise self.source_info.error( + keyword.value, + "keyword unpacking via `**` is not supported in TileLang DSL v1", + ) + + def _reject_cube_vector_surface(self, node: ast.Call, surface_name: str) -> None: + raise self.source_info.error( + node, + f"vector-only surface `{surface_name}` is not part of the @pto.ckernel contract", + ) + + def _reject_vector_cube_surface(self, node: ast.Call, surface_name: str) -> None: + raise self.source_info.error( + node, + f"cube-only surface `{surface_name}` is not part of the @pto.vkernel contract", + ) + + def _validate_call_keywords(self, node: ast.Call) -> None: + if not node.keywords: + return + self._validate_no_keyword_unpacking(node) + + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + namespace = node.func.value.id + name = node.func.attr + elif isinstance(node.func, ast.Name): + namespace = None + name = node.func.id + else: + raise self.source_info.error( + node, + "unsupported call surface in TileLang DSL v1", + ) + + allowed_keywords = _DMA_CALL_KEYWORDS.get(name) if namespace == "pto" else None + if allowed_keywords is None: + call_name = f"{namespace + '.' if namespace else ''}{name}" + raise self.source_info.error( + node, + f"`{call_name}` does not support keyword arguments in TileLang DSL v1; " + "keyword arguments are only supported on selected public call surfaces", + ) + + seen: set[str] = set() + for keyword in node.keywords: + assert keyword.arg is not None + if keyword.arg in seen: + raise self.source_info.error( + keyword.value, + f"duplicate keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + if keyword.arg not in allowed_keywords: + raise self.source_info.error( + keyword.value, + f"unsupported keyword `{keyword.arg}` for `pto.{name}` in TileLang DSL v1", + ) + seen.add(keyword.arg) + + def visit_Call(self, node: ast.Call) -> None: + if isinstance(node.func, ast.Attribute) and node.func.attr == "eval": + if node.keywords: + raise self.source_info.error( + node, + "`eval` does not support keyword arguments in TileLang DSL v1", + ) + if len(node.args) > 1: + raise self.source_info.error( + node, + "`eval()` accepts at most one positional dtype argument in TileLang DSL v1", + ) + return + + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.attr == "as_ptr": + if node.keywords: + raise self.source_info.error( + node, + "`as_ptr` does not support keyword arguments in TileLang DSL v1", + ) + if node.args: + raise self.source_info.error( + node, + "`as_ptr()` does not accept positional arguments in TileLang DSL v1", + ) + if self.kernel_family == "cube": + return + if self.advanced_enabled: + return + raise self.source_info.error( + node, + "surface `as_ptr` requires advanced=True in TileLang DSL v1", + ) + if node.func.attr == "astype": + if node.keywords: + raise self.source_info.error( + node, + "`astype` does not support keyword arguments in TileLang DSL v1", + ) + if len(node.args) != 1: + raise self.source_info.error( + node, + "`astype()` expects exactly 1 positional argument (target dtype) in TileLang DSL v1", + ) + # Type checking will be done during semantic analysis + return + if node.func.value.id == "pto" and node.func.attr == "tpl": + if self.kernel_family == "cube": + self._validate_no_keyword_unpacking(node) + else: + self._validate_call_keywords(node) + return + if node.func.value.id == "pto" and node.func.attr == "Tile": + self._validate_no_keyword_unpacking(node) + return + if node.func.value.id == "pto" and self.kernel_family == "cube": + self._validate_no_keyword_unpacking(node) + if node.func.attr in SUPPORTED_VECSCOPE_PTO_CALLS: + self._reject_cube_vector_surface(node, f"pto.{node.func.attr}") + if node.func.attr in ADVANCED_VECSCOPE_PTO_CALLS: + self._reject_cube_vector_surface(node, f"pto.{node.func.attr}") + if node.func.attr in ADVANCED_TOPLEVEL_PTO_CALLS: + self._reject_cube_vector_surface(node, f"pto.{node.func.attr}") + if node.func.attr in DEFERRED_PTO_SURFACES: + raise self.source_info.error( + node, + deferred_surface_message(node.func.attr), + ) + return + if node.func.value.id == "pto" and self.kernel_family != "cube": + if node.func.attr in CUBE_ONLY_PTO_CALLS: + self._reject_vector_cube_surface(node, f"pto.{node.func.attr}") + if node.func.value.id == "pto" and node.func.attr in SUPPORTED_TOPLEVEL_PTO_CALLS: + self._validate_call_keywords(node) + return + if node.func.value.id == "pto" and node.func.attr in SUPPORTED_VECSCOPE_PTO_CALLS: + self._validate_call_keywords(node) + return + if node.func.value.id == "pto" and node.func.attr in ADVANCED_VECSCOPE_PTO_CALLS: + if self.advanced_enabled: + self._validate_call_keywords(node) + return + raise self.source_info.error( + node, + advanced_mode_message(node.func.attr), + ) + if node.func.value.id == "pto" and ( + node.func.attr in ADVANCED_EXPR_PTO_CALLS + or node.func.attr in ADVANCED_TOPLEVEL_PTO_CALLS + ): + if self.advanced_enabled: + self._validate_call_keywords(node) + return + raise self.source_info.error( + node, + advanced_mode_message(node.func.attr), + ) + if node.func.value.id == "pto" and node.func.attr in DEFERRED_PTO_SURFACES: + raise self.source_info.error( + node, + deferred_surface_message(node.func.attr), + ) + if node.func.value.id == "pto": + raise self.source_info.error( + node, + f"unsupported op surface `pto.{node.func.attr}` in TileLang DSL v1", + ) + raise self.source_info.error( + node, + f"arbitrary external call `{node.func.value.id}.{node.func.attr}` is not supported " + "in TileLang DSL v1", + ) + + if isinstance(node.func, ast.Name): + if node.func.id == "range": + self._validate_call_keywords(node) + return + if node.func.id in self._static_dtype_bindings: + self._validate_call_keywords(node) + return + inline_proc = _find_inline_proc(node.func.id, module_name=self.module_name) + if inline_proc is not None: + _validate_inline_proc_call_surface(self.source_info, node, inline_proc) + return + raise self.source_info.error( + node, + f"arbitrary external call `{node.func.id}` is not supported in TileLang DSL v1", + ) + + raise self.source_info.error( + node, + "unsupported call surface in TileLang DSL v1", + ) + + def _expr_is_static_dtype_expr(self, node: ast.AST) -> bool: + if isinstance(node, ast.Name): + return node.id in self._static_dtype_bindings + if isinstance(node, ast.Attribute): + if ( + isinstance(node.value, ast.Name) + and node.value.id == "pto" + and node.attr in _DSL_DTYPE_NAMES + ): + return True + if node.attr == "element_type": + return True + return False + + def _update_static_dtype_bindings(self, target: ast.expr, *, is_static_dtype: bool) -> None: + if isinstance(target, ast.Name): + if is_static_dtype: + self._static_dtype_bindings.add(target.id) + else: + self._static_dtype_bindings.discard(target.id) + return + if isinstance(target, (ast.Tuple, ast.List)): + for element in target.elts: + self._update_static_dtype_bindings(element, is_static_dtype=False) + + +def _load_function_source_info(py_fn: Callable[..., Any]) -> _FunctionSourceInfo | None: + try: + source_lines, start_line = inspect.getsourcelines(py_fn) + path = inspect.getsourcefile(py_fn) or inspect.getfile(py_fn) + except (OSError, IOError, TypeError): + return None + + source = textwrap.dedent("".join(source_lines)) + module = ast.parse(source) + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == py_fn.__name__: + return _FunctionSourceInfo(path=path, start_line=start_line, function_def=node) + return None + + +def _validate_function_body( + source_info: _FunctionSourceInfo | None, + *, + advanced_enabled: bool, + module_name: str | None, + kernel_family: str, +) -> None: + if source_info is None: + return + _KernelBodyValidator( + source_info, + advanced_enabled=advanced_enabled, + module_name=module_name, + kernel_family=kernel_family, + ).validate() + + +def _raise_tile_param_error( + source_info: _FunctionSourceInfo | None, + param_name: str, + message: str, + fallback_exception: type[Exception] = ValueError, +) -> None: + if source_info is not None: + node = source_info.parameter_node(param_name) + if node is not None: + raise source_info.error(node, message) + raise fallback_exception(message) + + +def _freeze_dtypes(dtypes: Any) -> tuple[tuple[Any, ...], ...]: + if not isinstance(dtypes, (list, tuple)): + raise TypeError("dtypes must be a sequence of signature tuples") + + frozen_signatures = [] + for signature in dtypes: + if not isinstance(signature, (list, tuple)): + raise TypeError("each dtypes entry must be a signature tuple") + frozen_signature = tuple(signature) + for dtype in frozen_signature: + _validate_dtype_pattern(dtype) + frozen_signatures.append(frozen_signature) + + if not frozen_signatures: + raise ValueError("dtypes must contain at least one signature tuple") + + return tuple(frozen_signatures) + + +@dataclass(frozen=True) +class BoundKernelParameter: + """One parameter after v1 monomorphic dtype binding.""" + + name: str + kind: str + annotation: Any + dtype: Any + + @property + def element_dtype(self) -> ScalarType | None: + if self.kind in ("tensorview", "partition_tensor_view", "tile", "ptr", "vector"): + return self.dtype + return None + + +@dataclass(frozen=True) +class KernelParameterSpec: + """One validated Python function parameter before dtype selection.""" + + name: str + kind: str + annotation: Any + + +@dataclass(frozen=True) +class _ConstraintValue: + value: Any | None + + def _coerce_other(self, other: Any) -> Any | None: + if isinstance(other, _ConstraintValue): + return other.value + return other + + def _arith(self, other: Any, fn: Callable[[Any, Any], Any]) -> "_ConstraintValue": + other_value = self._coerce_other(other) + if self.value is None or other_value is None: + return _ConstraintValue(None) + return _ConstraintValue(fn(self.value, other_value)) + + def _compare(self, other: Any, fn: Callable[[Any, Any], bool]) -> bool: + other_value = self._coerce_other(other) + if self.value is None or other_value is None: + return True + return fn(self.value, other_value) + + def __add__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs + rhs) + + def __radd__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__add__(self) + + def __sub__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs - rhs) + + def __rsub__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__sub__(self) + + def __mul__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs * rhs) + + def __rmul__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__mul__(self) + + def __floordiv__(self, other: Any) -> "_ConstraintValue": + return self._arith(other, lambda lhs, rhs: lhs // rhs) + + def __rfloordiv__(self, other: Any) -> "_ConstraintValue": + return _ConstraintValue(self._coerce_other(other)).__floordiv__(self) + + def __eq__(self, other: Any) -> bool: # type: ignore[override] + return self._compare(other, lambda lhs, rhs: lhs == rhs) + + def __ne__(self, other: Any) -> bool: # type: ignore[override] + return self._compare(other, lambda lhs, rhs: lhs != rhs) + + def __le__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs <= rhs) + + def __lt__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs < rhs) + + def __ge__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs >= rhs) + + def __gt__(self, other: Any) -> bool: + return self._compare(other, lambda lhs, rhs: lhs > rhs) + + def __bool__(self) -> bool: + if self.value is None: + return True + return bool(self.value) + + def __repr__(self) -> str: + return "?" if self.value is None else repr(self.value) + + +class _ConstraintSequenceView: + def __init__(self, values: tuple[Any | None, ...]): + self._values = tuple(_ConstraintValue(value) for value in values) + + def __getitem__(self, index: int) -> _ConstraintValue: + if -len(self._values) <= index < len(self._values): + return self._values[index] + return _ConstraintValue(None) + + def __len__(self) -> int: + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def __repr__(self) -> str: + return repr(tuple(self._values)) + + +class _ConstraintParamView: + def __init__(self, name: str, attrs: Mapping[str, Any]): + self._name = name + self._attrs = dict(attrs) + + def _sequence_attr(self, attr_name: str) -> _ConstraintSequenceView: + values = self._attrs.get(attr_name) + if values is None: + rank = self._attrs.get("rank") + if isinstance(rank, int) and rank > 0: + values = (None,) * rank + else: + values = () + return _ConstraintSequenceView(tuple(values)) + + @property + def shape(self) -> _ConstraintSequenceView: + return self._sequence_attr("shape") + + @property + def valid_shape(self) -> _ConstraintSequenceView: + return self._sequence_attr("valid_shape") + + @property + def strides(self) -> _ConstraintSequenceView: + return self._sequence_attr("strides") + + @property + def rank(self) -> _ConstraintValue: + rank = self._attrs.get("rank") + if rank is None: + shape = self._attrs.get("shape") + if shape is not None: + rank = len(shape) + return _ConstraintValue(rank) + + @property + def dtype(self) -> Any: + return self._attrs.get("dtype") + + @property + def memory_space(self) -> Any: + memory_space = self._attrs.get("memory_space") + if memory_space is None and self._attrs.get("kind") == "tile": + return MemorySpace.UB + if memory_space is None: + return None + if isinstance(memory_space, MemorySpace): + return memory_space + return MemorySpace(memory_space) + + @property + def config(self) -> TileConfig | None: + config = self._attrs.get("config") + if config is None: + if self._attrs.get("kind") == "tile": + return TileConfig() + return None + if isinstance(config, TileConfig): + return config + if isinstance(config, Mapping): + return TileConfig.from_mapping(config) + raise TypeError(f"unsupported Tile config payload {config!r} in constraint view") + + def __repr__(self) -> str: + return f"{self._name}<{self._attrs!r}>" + + +@dataclass(frozen=True) +class VKernelDescriptor: + """Descriptor returned by `@tilelang_dsl.vkernel`.""" + + target: str + match_ops: tuple[str, ...] + dtypes: tuple[tuple[Any, ...], ...] + name: str + verify_enabled: bool + advanced_enabled: bool + _parameter_specs: tuple[KernelParameterSpec, ...] + _py_fn: Callable[..., Any] = field(repr=False) + kernel_family: str = "vector" + _source_info: _FunctionSourceInfo | None = field(repr=False, compare=False, default=None) + specializations: tuple[tuple[str, TileSpecialization], ...] = () + constraints: tuple[Callable[[Mapping[str, Any]], Any], ...] = field(default=(), repr=False) + priority: int = 0 + _templates: tuple[tuple[str, tuple[tuple[str, str], ...]], ...] = field(default=(), repr=False) + _inline_procs: tuple[tuple[str, InlineProcDescriptor], ...] = field(default=(), repr=False) + _internal_inline_procs: tuple[tuple[str, InlineProcDescriptor], ...] = field(default=(), repr=False) + _selected_op: str | None = None + _selected_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None + _parameters: tuple[BoundKernelParameter, ...] | None = field(default=None, repr=False) + _constraint_context_attrs: tuple[tuple[str, Any], ...] = field(default=(), repr=False) + + @property + def py_fn(self) -> Callable[..., Any]: + return self._py_fn + + @property + def op(self) -> str: + if self._selected_op is None: + raise ValueError( + "descriptor requires pto.select_kernel(...) to bind a concrete op " + "before reading descriptor.op" + ) + return self._selected_op + + @property + def selected_op(self) -> str | None: + return self._selected_op + + @property + def templates(self) -> dict[str, dict[str, str]]: + return { + slot: dict(op_bindings) + for slot, op_bindings in self._templates + } + + @property + def inline_procs(self) -> dict[str, InlineProcDescriptor]: + return {name: descriptor for name, descriptor in self._inline_procs} + + @property + def internal_inline_procs(self) -> dict[str, InlineProcDescriptor]: + return {name: descriptor for name, descriptor in self._internal_inline_procs} + + @property + def dtype_signature(self) -> tuple[ScalarType | MaskType, ...]: + if self._selected_dtype_signature is None: + raise ValueError( + "descriptor requires pto.select_kernel(...) to choose a concrete dtype signature " + "before materialization" + ) + return self._selected_dtype_signature + + @property + def parameters(self) -> tuple[BoundKernelParameter, ...]: + if self._parameters is None: + raise ValueError( + "descriptor requires pto.select_kernel(...) to bind concrete parameter dtypes " + "before materialization" + ) + return self._parameters + + @property + def metadata(self) -> dict[str, Any]: + return { + "target": self.target, + "op": self._selected_op, + "match_ops": self.match_ops, + "selected_op": self._selected_op, + "dtypes": self.dtypes, + "name": self.name, + "verify": self.verify_enabled, + "advanced": self.advanced_enabled, + "constraints": self.constraints, + "priority": self.priority, + "templates": self.templates, + "inline_procs": tuple(sorted(self.inline_procs.keys())), + } + + @property + def tile_parameters(self) -> tuple[BoundKernelParameter, ...]: + return tuple(param for param in self.parameters if param.kind == "tile") + + @property + def specializations_by_name(self) -> dict[str, TileSpecialization]: + return dict(self.specializations) + + @property + def constraint_context_attrs(self) -> dict[str, Any]: + return dict(self._constraint_context_attrs) + + def _tile_parameter_names(self) -> tuple[str, ...]: + return tuple(param.name for param in self._parameter_specs if param.kind == "tile") + + def _bind_constraint_context_attrs( + self, + context_attrs: Mapping[str, Any], + ) -> "VKernelDescriptor": + frozen_context_attrs = tuple( + sorted(dict(context_attrs).items(), key=lambda item: item[0]) + ) + if self._constraint_context_attrs == frozen_context_attrs: + return self + return VKernelDescriptor( + target=self.target, + match_ops=self.match_ops, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, + kernel_family=self.kernel_family, + _parameter_specs=self._parameter_specs, + _py_fn=self._py_fn, + _source_info=self._source_info, + specializations=self.specializations, + constraints=self.constraints, + priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, + _selected_op=self._selected_op, + _selected_dtype_signature=self._selected_dtype_signature, + _parameters=self._parameters, + _constraint_context_attrs=frozen_context_attrs, + ) + + def _bind_selected_dtype_signature( + self, + dtype_signature: tuple[ScalarType | MaskType, ...], + ) -> "VKernelDescriptor": + bound_parameters = _bind_parameters(self._parameter_specs, dtype_signature) + return VKernelDescriptor( + target=self.target, + match_ops=self.match_ops, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, + kernel_family=self.kernel_family, + _parameter_specs=self._parameter_specs, + _py_fn=self._py_fn, + _source_info=self._source_info, + specializations=self.specializations, + constraints=self.constraints, + priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, + _selected_op=self._selected_op, + _selected_dtype_signature=dtype_signature, + _parameters=bound_parameters, + _constraint_context_attrs=self._constraint_context_attrs, + ) + + def _bind_selected_op(self, op: str) -> "VKernelDescriptor": + normalized_op = _validate_op(op) + if normalized_op not in self.match_ops: + raise ValueError( + f"selected op {normalized_op!r} is not in descriptor matcher set {self.match_ops!r}" + ) + if self._selected_op == normalized_op: + return self + return VKernelDescriptor( + target=self.target, + match_ops=self.match_ops, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, + kernel_family=self.kernel_family, + _parameter_specs=self._parameter_specs, + _py_fn=self._py_fn, + _source_info=self._source_info, + specializations=self.specializations, + constraints=self.constraints, + priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, + _selected_op=normalized_op, + _selected_dtype_signature=self._selected_dtype_signature, + _parameters=self._parameters, + _constraint_context_attrs=self._constraint_context_attrs, + ) + + def specialize(self, **bindings: Any) -> "VKernelDescriptor": + tile_param_names = set(self._tile_parameter_names()) + if not tile_param_names: + if bindings: + unknown = ", ".join(sorted(bindings)) + raise TypeError( + f"specialize() received bindings for non-Tile parameters: {unknown}" + ) + return self + + unknown = sorted(set(bindings) - tile_param_names) + if unknown: + unknown_names = ", ".join(unknown) + raise TypeError( + f"specialize() only accepts bare Tile parameters; got: {unknown_names}" + ) + + updated = self.specializations_by_name + for name, binding in bindings.items(): + updated[name] = _coerce_tile_specialization( + name, + binding, + self._source_info, + kernel_family=self.kernel_family, + ) + + return VKernelDescriptor( + target=self.target, + match_ops=self.match_ops, + dtypes=self.dtypes, + name=self.name, + verify_enabled=self.verify_enabled, + advanced_enabled=self.advanced_enabled, + kernel_family=self.kernel_family, + _parameter_specs=self._parameter_specs, + _source_info=self._source_info, + specializations=tuple(sorted(updated.items())), + constraints=self.constraints, + priority=self.priority, + _templates=self._templates, + _inline_procs=self._inline_procs, + _internal_inline_procs=self._internal_inline_procs, + _selected_op=self._selected_op, + _selected_dtype_signature=self._selected_dtype_signature, + _parameters=self._parameters, + _py_fn=self._py_fn, + _constraint_context_attrs=self._constraint_context_attrs, + ) + + def _require_specialized_tiles(self, api_name: str) -> None: + tile_names = list(self._tile_parameter_names()) + if not tile_names: + return + + specialized = self.specializations_by_name + missing = [name for name in tile_names if name not in specialized] + if missing: + missing_names = ", ".join(missing) + _raise_tile_param_error( + self._source_info, + missing[0], + f"{api_name}() requires specialize() bindings for bare Tile parameters: " + f"{missing_names}", + ) + + def _require_materialization_binding(self, api_name: str) -> None: + self.parameters + if len(self.match_ops) > 1 and self._selected_op is None: + raise ValueError( + f"{api_name}() requires pto.select_kernel(...) to bind a concrete op " + "before materialization" + ) + + def _constraint_context_for_evaluation( + self, + extra_context_attrs: Mapping[str, Any] | None = None, + ) -> dict[str, Any]: + attrs = dict(self._constraint_context_attrs) + if extra_context_attrs is not None: + attrs.update(extra_context_attrs) + attrs.setdefault("target", self.target) + if self._selected_op is not None: + attrs.setdefault("op", self._selected_op) + attrs.setdefault("selected_op", self._selected_op) + + for index, spec in enumerate(self._parameter_specs): + existing = attrs.get(spec.name) + param_attrs = {} if not isinstance(existing, dict) else dict(existing) + positional_prefix = f"arg{index}" + param_attrs.setdefault("kind", spec.kind) + attrs.setdefault(f"{spec.name}_kind", spec.kind) + + def set_sequence_attr(attr_name: str) -> None: + named_key = f"{spec.name}_{attr_name}" + positional_key = f"{positional_prefix}_{attr_name}" + if named_key in attrs: + value = tuple(attrs[named_key]) + elif positional_key in attrs: + value = tuple(attrs[positional_key]) + attrs.setdefault(named_key, value) + else: + return + param_attrs.setdefault(attr_name, value) + + def set_scalar_attr(attr_name: str) -> None: + named_key = f"{spec.name}_{attr_name}" + positional_key = f"{positional_prefix}_{attr_name}" + if named_key in attrs: + value = attrs[named_key] + elif positional_key in attrs: + value = attrs[positional_key] + attrs.setdefault(named_key, value) + else: + return + param_attrs.setdefault(attr_name, value) + + set_sequence_attr("shape") + set_sequence_attr("valid_shape") + set_sequence_attr("strides") + set_scalar_attr("rank") + set_scalar_attr("memory_space") + set_scalar_attr("config") + + if spec.kind in ("tensorview", "partition_tensor_view"): + # TensorView authoring form is normalized to 5D in the current DSL spec. + param_attrs.setdefault("rank", 5) + param_attrs.setdefault("memory_space", "gm") + attrs.setdefault(f"{spec.name}_rank", 5) + attrs.setdefault(f"{spec.name}_memory_space", "gm") + attrs[spec.name] = param_attrs + + if self._parameters is not None: + for param in self._parameters: + param_attrs = attrs.get(param.name) + if not isinstance(param_attrs, dict): + param_attrs = {"kind": param.kind} + param_attrs.setdefault("dtype", param.dtype) + attrs[param.name] = param_attrs + attrs.setdefault(f"{param.name}_dtype", param.dtype) + + for name, spec in self.specializations_by_name.items(): + effective_valid_shape = spec.shape if spec.valid_shape is None else spec.valid_shape + param_attrs = attrs.get(name) + if not isinstance(param_attrs, dict): + param_attrs = {"kind": "tile"} + param_attrs.update( + { + "shape": spec.shape, + "rank": len(spec.shape), + "memory_space": spec.memory_space.value, + "valid_shape": effective_valid_shape, + "config": spec.config, + } + ) + attrs[name] = param_attrs + attrs[f"{name}_shape"] = spec.shape + attrs[f"{name}_rank"] = len(spec.shape) + attrs[f"{name}_memory_space"] = spec.memory_space.value + attrs[f"{name}_valid_shape"] = effective_valid_shape + if len(spec.shape) == 1: + attrs[f"{name}_extent"] = spec.shape[0] + attrs[f"{name}_valid_extent"] = effective_valid_shape[0] + elif len(spec.shape) == 2: + attrs[f"{name}_rows"] = spec.shape[0] + attrs[f"{name}_cols"] = spec.shape[1] + attrs[f"{name}_valid_rows"] = effective_valid_shape[0] + attrs[f"{name}_valid_cols"] = effective_valid_shape[1] + return attrs + + def _validate_materialization_constraints(self, api_name: str) -> None: + if not self.constraints: + return + context_attrs = self._constraint_context_for_evaluation() + evaluation = _evaluate_constraints(self, context_attrs) + _raise_constraint_evaluation_error(evaluation) + if evaluation.passed: + return + raise LookupError( + f"{api_name}() constraint evaluation rejected kernel {self.name!r} " + "for the current specialization/context attributes" + ) + + def _build_authoring_module(self): + self.parameters + frontend_kernel = build_frontend_kernel_node(self) + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + return lower_semantic_kernel(semantic_kernel) + + def mlir_text(self) -> str: + self._require_materialization_binding("mlir_text") + self._require_specialized_tiles("mlir_text") + self._validate_materialization_constraints("mlir_text") + return self._build_authoring_module().render() + + def mlir_module(self) -> "MaterializedMLIRModule": + self._require_materialization_binding("mlir_module") + self._require_specialized_tiles("mlir_module") + return MaterializedMLIRModule(text=self.mlir_text(), target=self.target) + + def emit(self, path: str | Path) -> None: + self._require_materialization_binding("emit") + self._require_specialized_tiles("emit") + self._validate_materialization_constraints("emit") + output_path = Path(path) + output_path.write_text(self.mlir_text(), encoding="utf-8") + + +@dataclass(frozen=True) +class KernelSelectionCandidateMetadata: + """Structured selection diagnostics for one target/op-matched kernel candidate.""" + + descriptor: VKernelDescriptor + status: str + selected_op: str | None = None + matched_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None + reason: str | None = None + failed_constraint_index: int | None = None + failed_constraint_name: str | None = None + failed_constraint_location: str | None = None + error_type: str | None = None + error_message: str | None = None + mlir_text: str | None = None + mlir_error: str | None = None + + @property + def name(self) -> str: + return self.descriptor.name + + @property + def priority(self) -> int: + return self.descriptor.priority + + @property + def match_ops(self) -> tuple[str, ...]: + return self.descriptor.match_ops + + @property + def dtype_signatures(self) -> tuple[tuple[Any, ...], ...]: + return self.descriptor.dtypes + + +@dataclass(frozen=True) +class KernelSelectionReport: + """Structured selector result returned by the opt-in metadata path.""" + + target: str + op: str + operand_types: tuple[ScalarType | MaskType, ...] + selected: VKernelDescriptor | None + candidates: tuple[KernelSelectionCandidateMetadata, ...] = () + final_status: str = "no_candidate" + final_error: str | None = None + _context_attrs: tuple[tuple[str, Any], ...] = field(default=(), repr=False) + + @property + def context_attrs(self) -> dict[str, Any]: + return dict(self._context_attrs) + + @property + def ok(self) -> bool: + return self.final_status == "selected" and self.selected is not None + + +@dataclass(frozen=True) +class _TargetOpSelectionCandidate: + descriptor: VKernelDescriptor + + +@dataclass(frozen=True) +class _DtypeSelectionCandidate: + descriptor: VKernelDescriptor + matched_descriptor: VKernelDescriptor | None = None + matched_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None + + @property + def matched(self) -> bool: + return self.matched_descriptor is not None + + +@dataclass(frozen=True) +class _ConstraintSelectionCandidate: + descriptor: VKernelDescriptor + passed: bool + evaluation: "_ConstraintEvaluationResult" + bound_descriptor: VKernelDescriptor | None = None + + +@dataclass(frozen=True) +class _PrioritySelectionResult: + candidates: tuple[VKernelDescriptor, ...] + highest_priority: int | None + winners: tuple[VKernelDescriptor, ...] + + @property + def has_tie(self) -> bool: + return len(self.winners) > 1 + + @property + def winner(self) -> VKernelDescriptor | None: + if len(self.winners) != 1: + return None + return self.winners[0] + + +@dataclass(frozen=True) +class _MaterializationSelectionCandidate: + descriptor: VKernelDescriptor + mlir_text: str | None = None + mlir_error: str | None = None + + +@dataclass(frozen=True) +class _ConstraintEvaluationResult: + passed: bool + failed_constraint_index: int | None = None + failed_constraint_name: str | None = None + failed_constraint_location: str | None = None + error_type: str | None = None + error_message: str | None = None + + @property + def raised_error(self) -> bool: + return self.error_type is not None + + +class KernelRegistry: + """Explicit registry for TileLang kernel descriptors.""" + + def __init__(self, descriptors: tuple[VKernelDescriptor, ...] = ()): + self._descriptors: list[VKernelDescriptor] = [] + for descriptor in descriptors: + self.register(descriptor) + + def register(self, descriptor: VKernelDescriptor) -> VKernelDescriptor: + if not isinstance(descriptor, VKernelDescriptor): + raise TypeError("KernelRegistry.register() expects a VKernelDescriptor") + self._descriptors.append(descriptor) + return descriptor + + @property + def descriptors(self) -> tuple[VKernelDescriptor, ...]: + return tuple(self._descriptors) + + def __iter__(self): + return iter(self._descriptors) + + def __len__(self) -> int: + return len(self._descriptors) + + +_DEFAULT_KERNEL_REGISTRY = KernelRegistry() + + +@dataclass(frozen=True) +class MaterializedMLIRModule: + text: str + target: str = "a5" + + def __str__(self) -> str: + return self.text + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[3] + + +def _validate_target(target: str) -> str: + if not isinstance(target, str): + raise TypeError("target must be a string") + if target != "a5": + raise ValueError("TileLang DSL v1 currently only supports target='a5'") + return target + + +def _validate_op(op: Any) -> str: + if not isinstance(op, str) or not op: + raise TypeError("op must be a non-empty string") + return op + + +def _is_schema_form_match_op(op: str) -> bool: + normalized = " ".join(op.split()) + return " ins(" in normalized and "->" in normalized and " outs(" in normalized + + +def _freeze_match_ops(*, op: Any, ops: Any) -> tuple[str, ...]: + if op is not None and ops is not None: + raise ValueError("vkernel() accepts either op= or ops=, but not both") + if op is None and ops is None: + raise ValueError("vkernel() requires exactly one of op= or ops=") + if op is not None: + return (_validate_op(op),) + if not isinstance(ops, (list, tuple)): + raise TypeError("ops must be a sequence of non-empty strings") + if not ops: + raise ValueError("ops must contain at least one op") + normalized_ops = tuple(_validate_op(candidate) for candidate in ops) + if len(set(normalized_ops)) != len(normalized_ops): + raise ValueError("ops must not contain duplicates") + return normalized_ops + + +def _validate_template_slot_name(slot: Any) -> str: + if not isinstance(slot, str) or not slot: + raise TypeError("template slot names must be non-empty strings") + return slot + + +def _validate_template_value(slot: str, op_name: str, value: Any) -> str: + if not isinstance(value, str) or not value: + raise TypeError( + f"templates[{slot!r}][{op_name!r}] must be a non-empty pto op name string" + ) + if value not in _SUPPORTED_TEMPLATE_PTO_CALLS: + raise ValueError( + f"templates[{slot!r}][{op_name!r}] maps to unsupported pto op {value!r}" + ) + return value + + +def _freeze_templates( + templates: Any, + *, + match_ops: tuple[str, ...], +) -> tuple[tuple[str, tuple[tuple[str, str], ...]], ...]: + if templates in (_UNSET, None): + return () + if not isinstance(templates, Mapping): + raise TypeError("templates must be a mapping of slot names to per-op mappings") + + frozen_templates = [] + for slot, op_bindings in templates.items(): + normalized_slot = _validate_template_slot_name(slot) + if not isinstance(op_bindings, Mapping): + raise TypeError( + f"templates[{normalized_slot!r}] must be a mapping of concrete ops to pto op names" + ) + if not op_bindings: + raise ValueError( + f"templates[{normalized_slot!r}] must contain at least one concrete-op mapping" + ) + + frozen_bindings = [] + for concrete_op, real_op in op_bindings.items(): + normalized_concrete_op = _validate_op(concrete_op) + if normalized_concrete_op not in match_ops: + raise ValueError( + f"templates[{normalized_slot!r}] references op {normalized_concrete_op!r} " + f"outside descriptor matcher set {match_ops!r}" + ) + frozen_bindings.append( + ( + normalized_concrete_op, + _validate_template_value(normalized_slot, normalized_concrete_op, real_op), + ) + ) + frozen_templates.append((normalized_slot, tuple(frozen_bindings))) + + return tuple(frozen_templates) + + +def _validate_name(py_fn: Callable[..., Any], name: Any) -> str: + if name is None: + return py_fn.__name__ + if not isinstance(name, str) or not name: + raise TypeError("name must be a non-empty string") + return name + + +def _validate_verify(verify: Any) -> bool: + if not isinstance(verify, bool): + raise TypeError("verify must be a bool") + return verify + + +def _validate_advanced(advanced: Any) -> bool: + if not isinstance(advanced, bool): + raise TypeError("advanced must be a bool") + return advanced + + +def _validate_constraints(constraints: Any) -> tuple[Callable[[Mapping[str, Any]], Any], ...]: + if constraints is _UNSET: + return () + if not isinstance(constraints, (list, tuple)): + raise TypeError("constraints must be a sequence of predicate callables") + + frozen_constraints = [] + for index, constraint in enumerate(constraints): + if not callable(constraint): + raise TypeError(f"constraints[{index}] must be callable") + frozen_constraints.append(constraint) + return tuple(frozen_constraints) + + +def _validate_priority(priority: Any) -> int: + if priority is _UNSET: + return 0 + if isinstance(priority, bool) or not isinstance(priority, int): + raise TypeError("priority must be an int") + return priority + + +def _coerce_memory_space(value: Any, param_name: str) -> MemorySpace: + if isinstance(value, MemorySpace): + return value + if isinstance(value, str): + normalized = value.strip().upper() + try: + return MemorySpace[normalized] + except KeyError as exc: + raise ValueError( + f"specialization for '{param_name}' uses unsupported memory_space {value!r}" + ) from exc + raise TypeError( + f"specialization for '{param_name}' must provide MemorySpace or string memory_space" + ) + + +def _coerce_tile_config(value: Any, param_name: str) -> TileConfig | None: + if value is None: + return None + if isinstance(value, TileConfig): + return value + if isinstance(value, dict): + return TileConfig.from_mapping(value) + raise TypeError( + f"specialization for '{param_name}' must provide TileConfig, dict, or None for config" + ) + + +def _coerce_tile_valid_shape( + shape: tuple[int, ...], + value: Any, + param_name: str, + source_info: _FunctionSourceInfo | None, +) -> tuple[int | None, ...] | None: + if value is None: + return None + if not isinstance(value, (list, tuple)): + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must provide valid_shape as a tuple/list", + TypeError, + ) + valid_shape = tuple(value) + if len(valid_shape) != len(shape): + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape rank must match shape rank", + ) + + normalized: list[int | None] = [] + for axis, (valid_dim, shape_dim) in enumerate(zip(valid_shape, shape)): + if isinstance(valid_dim, bool): + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis} must not be bool", + TypeError, + ) + if isinstance(valid_dim, int): + if valid_dim <= 0: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis} must be positive", + ) + if valid_dim > shape_dim: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis}={valid_dim} " + f"must be <= shape axis {axis}={shape_dim}", + ) + normalized.append(valid_dim) + continue + if valid_dim is None or isinstance(valid_dim, str): + normalized.append(None) + continue + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': valid_shape axis {axis} must be " + "a positive int, string symbol, or None", + TypeError, + ) + return tuple(normalized) + + +def _coerce_tile_specialization( + param_name: str, + binding: Any, + source_info: _FunctionSourceInfo | None, + *, + kernel_family: str = "vector", +) -> TileSpecialization: + if isinstance(binding, TileSpecialization): + spec = binding + elif isinstance(binding, dict): + if "shape" not in binding: + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must provide a static physical Tile shape", + TypeError, + ) + if "memory_space" not in binding: + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must provide memory_space", + TypeError, + ) + spec = TileSpecialization( + shape=tuple(binding["shape"]), + memory_space=_coerce_memory_space(binding["memory_space"], param_name), + config=_coerce_tile_config(binding.get("config"), param_name), + valid_shape=binding.get("valid_shape"), + ) + else: + _raise_tile_param_error( + source_info, + param_name, + f"specialization for '{param_name}' must be a TileSpecialization or dict", + TypeError, + ) + + if not spec.shape: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': shape must be non-empty", + ) + for dim in spec.shape: + if not isinstance(dim, int) or isinstance(dim, bool): + _raise_tile_param_error( + source_info, + param_name, + f"dynamic physical Tile shape is not supported for '{param_name}'", + TypeError, + ) + if dim <= 0: + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': dimensions must be positive", + ) + if len(spec.shape) not in (1, 2): + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': v1 only supports rank-1 or rank-2 Tile shapes", + ) + allowed_memory_spaces = ( + {MemorySpace.UB} + if kernel_family != "cube" + else { + MemorySpace.MAT, + MemorySpace.LEFT, + MemorySpace.RIGHT, + MemorySpace.ACC, + MemorySpace.BIAS, + MemorySpace.UB, + } + ) + if spec.memory_space not in allowed_memory_spaces: + if kernel_family == "cube": + allowed_text = "MemorySpace.MAT/LEFT/RIGHT/ACC/BIAS/UB" + else: + allowed_text = "MemorySpace.UB" + _raise_tile_param_error( + source_info, + param_name, + f"illegal Tile profile for '{param_name}': {kernel_family} v1 only supports {allowed_text}", + ) + valid_shape = _coerce_tile_valid_shape(spec.shape, spec.valid_shape, param_name, source_info) + return TileSpecialization( + shape=spec.shape, + memory_space=spec.memory_space, + config=spec.config, + valid_shape=valid_shape, + ) + + +def _validate_leaf_dtype(dtype: Any, param_name: str) -> ScalarType | MaskType: + if not isinstance(dtype, (ScalarType, MaskType)): + raise TypeError( + f"dtypes entry for parameter '{param_name}' must be a TileLang scalar or mask dtype" + ) + return dtype + + +def _freeze_operand_types(operand_types: Any) -> tuple[ScalarType | MaskType, ...]: + if not isinstance(operand_types, (list, tuple)): + raise TypeError("operand_types must be a sequence of TileLang scalar or mask dtypes") + return tuple(_validate_leaf_dtype(dtype, f"operand_types[{index}]") for index, dtype in enumerate(operand_types)) + + +def _matches_wildcard(pattern: WildcardType, actual: ScalarType | MaskType) -> bool: + if pattern.name == "AnyType": + return isinstance(actual, ScalarType) + if pattern.name == "AnyFloat": + return isinstance(actual, ScalarType) and actual.name in {"f16", "bf16", "f32"} + if pattern.name == "AnyInt": + return isinstance(actual, ScalarType) and is_integer_dtype(actual) + if pattern.name == "AnyMask": + return isinstance(actual, MaskType) + raise TypeError(f"unsupported wildcard matcher {pattern.name!r}") + + +def _matches_scalar_annotation( + annotation: ScalarType | MaskType | WildcardType | TypeVariable, + actual: ScalarType | MaskType, +) -> bool: + if isinstance(annotation, (ScalarType, MaskType)): + return annotation == actual + if isinstance(annotation, WildcardType): + return _matches_wildcard(annotation, actual) + if isinstance(annotation, TypeVariable): + return True + raise TypeError(f"unsupported scalar annotation {annotation!r}") + + +def _match_dtype_signature( + dtype_signature: tuple[Any, ...], + operand_types: tuple[ScalarType | MaskType, ...], +) -> tuple[ScalarType | MaskType, ...] | None: + if len(dtype_signature) != len(operand_types): + return None + + typevar_bindings: dict[str, ScalarType | MaskType] = {} + for pattern, actual in zip(dtype_signature, operand_types): + if isinstance(pattern, (ScalarType, MaskType)): + if pattern != actual: + return None + continue + if isinstance(pattern, WildcardType): + if not _matches_wildcard(pattern, actual): + return None + continue + if isinstance(pattern, TypeVariable): + bound = typevar_bindings.get(pattern.name) + if bound is None: + typevar_bindings[pattern.name] = actual + continue + if bound != actual: + return None + continue + raise TypeError(f"unsupported dtype pattern {pattern!r}") + return operand_types + + +def _match_descriptor_dtype_signature( + descriptor: VKernelDescriptor, + operand_types: tuple[ScalarType | MaskType, ...], +) -> tuple[ScalarType | MaskType, ...] | None: + for dtype_signature in descriptor.dtypes: + matched = _match_dtype_signature(dtype_signature, operand_types) + if matched is not None: + return matched + return None + + +def _validate_parameter_spec(param: inspect.Parameter) -> KernelParameterSpec: + if param.kind not in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + raise TypeError( + f"parameter '{param.name}' uses unsupported parameter kind for TileLang DSL v1" + ) + if param.default is not inspect._empty: + raise TypeError( + f"parameter '{param.name}' must not declare a default value in TileLang DSL v1" + ) + if param.annotation is inspect._empty: + raise TypeError( + f"parameter '{param.name}' must declare a TileLang DSL type annotation" + ) + + annotation = param.annotation + if annotation is TensorView: + return KernelParameterSpec( + name=param.name, + kind="tensorview", + annotation=annotation, + ) + if annotation is PartitionTensorView: + return KernelParameterSpec( + name=param.name, + kind="partition_tensor_view", + annotation=annotation, + ) + if annotation is Tile: + return KernelParameterSpec( + name=param.name, + kind="tile", + annotation=annotation, + ) + if isinstance(annotation, VectorType): + return KernelParameterSpec( + name=param.name, + kind="vector", + annotation=annotation, + ) + if isinstance(annotation, PointerType): + return KernelParameterSpec( + name=param.name, + kind="ptr", + annotation=annotation, + ) + if isinstance(annotation, MaskType): + return KernelParameterSpec( + name=param.name, + kind="mask", + annotation=annotation, + ) + if isinstance(annotation, WildcardType) and annotation.name == "AnyMask": + return KernelParameterSpec( + name=param.name, + kind="mask", + annotation=annotation, + ) + if isinstance(annotation, (ScalarType, WildcardType, TypeVariable)): + return KernelParameterSpec( + name=param.name, + kind="scalar", + annotation=annotation, + ) + + raise TypeError( + f"parameter '{param.name}' uses unsupported annotation {annotation!r}" + ) + + +def _collect_parameter_specs(py_fn: Callable[..., Any]) -> tuple[KernelParameterSpec, ...]: + signature = inspect.signature(py_fn) + return tuple(_validate_parameter_spec(param) for param in signature.parameters.values()) + + +def _default_dtype_signature( + parameter_specs: tuple[KernelParameterSpec, ...], +) -> tuple[Any, ...]: + defaults: list[Any] = [] + for param_spec in parameter_specs: + if param_spec.kind in {"tensorview", "partition_tensor_view", "tile"}: + defaults.append(AnyType) + continue + if param_spec.kind == "vector": + defaults.append(param_spec.annotation.element_dtype) + continue + if param_spec.kind == "ptr": + defaults.append(param_spec.annotation.element_dtype) + continue + if param_spec.kind == "mask": + defaults.append(param_spec.annotation if isinstance(param_spec.annotation, MaskType) else AnyMask) + continue + if isinstance(param_spec.annotation, (WildcardType, TypeVariable)): + defaults.append(AnyType) + continue + defaults.append(param_spec.annotation) + return tuple(defaults) + + +def _validate_dtype_arity( + parameter_specs: tuple[KernelParameterSpec, ...], + dtypes: tuple[tuple[Any, ...], ...], + *, + kernel_family: str = "vector", +) -> None: + if kernel_family == "cube": + # Cube v1 reuses the shared descriptor object, but its `dtypes` surface + # follows the cube op contract rather than Python parameter arity. + # Keep descriptor construction permissive here and defer concrete + # semantic checks to cube frontend/semantic analysis. + return + expected_arity = len(parameter_specs) + for dtype_signature in dtypes: + if len(dtype_signature) != expected_arity: + raise ValueError( + "each dtypes signature must match the decorated function parameter count" + ) + + +def _bind_parameter( + param_spec: KernelParameterSpec, + dtype: Any, +) -> BoundKernelParameter: + bound_dtype = _validate_leaf_dtype(dtype, param_spec.name) + if param_spec.kind in {"tensorview", "partition_tensor_view"}: + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=bound_dtype, + ) + if param_spec.kind == "tile": + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=bound_dtype, + ) + if param_spec.kind == "vector": + if param_spec.annotation.element_dtype != bound_dtype: + raise TypeError( + f"vector parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=bound_dtype, + ) + if param_spec.kind == "ptr": + if param_spec.annotation.element_dtype != bound_dtype: + raise TypeError( + f"pointer parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=bound_dtype, + ) + if param_spec.kind == "mask": + if not isinstance(bound_dtype, MaskType): + raise TypeError( + f"mask parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + if isinstance(param_spec.annotation, MaskType) and param_spec.annotation != bound_dtype: + raise TypeError( + f"mask parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + if isinstance(param_spec.annotation, WildcardType) and not _matches_wildcard(param_spec.annotation, bound_dtype): + raise TypeError( + f"mask parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=bound_dtype, + ) + if not _matches_scalar_annotation(param_spec.annotation, bound_dtype): + raise TypeError( + f"scalar parameter '{param_spec.name}' annotation {param_spec.annotation!r} " + f"does not match selected dtype {bound_dtype!r}" + ) + return BoundKernelParameter( + name=param_spec.name, + kind=param_spec.kind, + annotation=param_spec.annotation, + dtype=bound_dtype, + ) + + +def _bind_parameters( + parameter_specs: tuple[KernelParameterSpec, ...], + dtype_signature: tuple[ScalarType | MaskType, ...], +) -> tuple[BoundKernelParameter, ...]: + if len(dtype_signature) != len(parameter_specs): + raise ValueError( + "selected dtype signature must match the decorated function parameter count" + ) + return tuple( + _bind_parameter(param_spec, dtype) + for param_spec, dtype in zip(parameter_specs, dtype_signature) + ) + + +def _build_descriptor( + py_fn: Callable[..., Any], + *, + target: str, + op: Any, + ops: Any, + templates: Any, + dtypes: Any, + name: Any, + verify: Any, + advanced: Any, + constraints: Any, + priority: Any, + kernel_family: str = "vector", +) -> VKernelDescriptor: + if not callable(py_fn): + raise TypeError("@vkernel can only decorate callables") + + source_info = _load_function_source_info(py_fn) + advanced_enabled = _validate_advanced(advanced) + inline_procs = _collect_inline_procs(py_fn.__module__) + internal_inline_procs = _collect_internal_inline_procs() + _validate_function_body( + source_info, + advanced_enabled=advanced_enabled, + module_name=py_fn.__module__, + kernel_family=kernel_family, + ) + match_ops = _freeze_match_ops(op=op, ops=ops) + if kernel_family == "cube": + for match_op in match_ops: + if _is_schema_form_match_op(match_op): + raise ValueError( + "@pto.ckernel does not support schema-form op matching; " + "use concrete op strings such as 'pto.mad' or ops=[...]" + ) + frozen_templates = _freeze_templates(templates, match_ops=match_ops) + parameter_specs = _collect_parameter_specs(py_fn) + if dtypes is None: + dtypes = (_default_dtype_signature(parameter_specs),) + frozen_dtypes = _freeze_dtypes(dtypes) + _validate_dtype_arity(parameter_specs, frozen_dtypes, kernel_family=kernel_family) + + selected_op: str | None = None + selected_dtype_signature: tuple[ScalarType | MaskType, ...] | None = None + bound_parameters: tuple[BoundKernelParameter, ...] | None = None + if len(match_ops) == 1: + selected_op = match_ops[0] + if ( + len(frozen_dtypes) == 1 + and all(isinstance(dtype, (ScalarType, MaskType)) for dtype in frozen_dtypes[0]) + and (kernel_family != "cube" or len(frozen_dtypes[0]) == len(parameter_specs)) + ): + selected_dtype_signature = tuple(frozen_dtypes[0]) + bound_parameters = _bind_parameters(parameter_specs, selected_dtype_signature) + + return VKernelDescriptor( + target=_validate_target(target), + match_ops=match_ops, + dtypes=frozen_dtypes, + name=_validate_name(py_fn, name), + verify_enabled=_validate_verify(verify), + advanced_enabled=advanced_enabled, + kernel_family=kernel_family, + _parameter_specs=parameter_specs, + _py_fn=py_fn, + _source_info=source_info, + constraints=_validate_constraints(constraints), + priority=_validate_priority(priority), + _templates=frozen_templates, + _inline_procs=inline_procs, + _internal_inline_procs=internal_inline_procs, + _selected_op=selected_op, + _selected_dtype_signature=selected_dtype_signature, + _parameters=bound_parameters, + _constraint_context_attrs=(), + ) + + +def _evaluate_constraints( + descriptor: VKernelDescriptor, + context_attrs: Mapping[str, Any], +) -> _ConstraintEvaluationResult: + named_context: dict[str, Any] = { + "target": context_attrs.get("target"), + "op": context_attrs.get("op"), + "selected_op": context_attrs.get("selected_op"), + } + for spec in descriptor._parameter_specs: + param_attrs = context_attrs.get(spec.name) + if not isinstance(param_attrs, Mapping): + param_attrs = {} + named_context[spec.name] = _ConstraintParamView(spec.name, param_attrs) + + for index, constraint in enumerate(descriptor.constraints): + constraint_name = _constraint_callable_name(constraint) + constraint_location = _constraint_callable_location(constraint) + try: + signature = inspect.signature(constraint) + parameters = list(signature.parameters.values()) + kwargs: dict[str, Any] = {} + for parameter in parameters: + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + raise TypeError("constraint callables with *args are not supported") + if parameter.kind == inspect.Parameter.VAR_KEYWORD: + for key, value in named_context.items(): + kwargs.setdefault(key, value) + for key, value in context_attrs.items(): + kwargs.setdefault(key, value) + continue + if parameter.name in named_context: + kwargs[parameter.name] = named_context[parameter.name] + continue + if parameter.name in context_attrs: + kwargs[parameter.name] = context_attrs[parameter.name] + continue + if parameter.default is not inspect._empty: + continue + raise TypeError( + f"constraint {index} for kernel {descriptor.name!r} requires unsupported parameter " + f"{parameter.name!r}" + ) + result = constraint(**kwargs) + except Exception as exc: + return _ConstraintEvaluationResult( + passed=False, + failed_constraint_index=index, + failed_constraint_name=constraint_name, + failed_constraint_location=constraint_location, + error_type=type(exc).__name__, + error_message=( + f"constraint {index} for kernel {descriptor.name!r} " + f"raised {type(exc).__name__}: {exc}" + f"{_format_constraint_location_suffix(constraint_location)}" + ), + ) + if not result: + return _ConstraintEvaluationResult( + passed=False, + failed_constraint_index=index, + failed_constraint_name=constraint_name, + failed_constraint_location=constraint_location, + error_message=( + f"constraint {index} for kernel {descriptor.name!r} returned False" + f"{_format_constraint_location_suffix(constraint_location)}" + ), + ) + return _ConstraintEvaluationResult(passed=True) + + +def _constraint_callable_name(constraint: Callable[..., Any]) -> str | None: + qualname = getattr(constraint, "__qualname__", None) + if isinstance(qualname, str) and qualname: + return qualname + name = getattr(constraint, "__name__", None) + if isinstance(name, str) and name: + return name + return None + + +def _constraint_callable_location(constraint: Callable[..., Any]) -> str | None: + code = getattr(constraint, "__code__", None) + filename = getattr(code, "co_filename", None) + firstlineno = getattr(code, "co_firstlineno", None) + if isinstance(filename, str) and filename and isinstance(firstlineno, int) and firstlineno > 0: + return f"{filename}:{firstlineno}" + return None + + +def _format_constraint_location_suffix(location: str | None) -> str: + if location is None: + return "" + return f" at {location}" + + +def _raise_constraint_evaluation_error(result: _ConstraintEvaluationResult) -> None: + if not result.raised_error or result.error_message is None: + return + raise TypeError(result.error_message) + + +def _format_descriptor_identity(descriptor: VKernelDescriptor) -> str: + dtype_signature = descriptor._selected_dtype_signature + if dtype_signature is None: + dtype_signature = tuple("?" for _ in descriptor.dtypes[0]) if descriptor.dtypes else () + return f"{descriptor.name}(priority={descriptor.priority}, dtypes={dtype_signature!r})" + + +def _bind_descriptor_for_target_op( + descriptor: VKernelDescriptor, + *, + target: str, + op: str, +) -> VKernelDescriptor | None: + if descriptor.target != target: + return None + if op not in descriptor.match_ops: + return None + return descriptor._bind_selected_op(op) + + +def _collect_target_op_candidates( + registry: KernelRegistry, + *, + target: str, + op: str, +) -> tuple[_TargetOpSelectionCandidate, ...]: + candidates: list[_TargetOpSelectionCandidate] = [] + for descriptor in registry: + op_bound_descriptor = _bind_descriptor_for_target_op( + descriptor, + target=target, + op=op, + ) + if op_bound_descriptor is None: + continue + candidates.append(_TargetOpSelectionCandidate(descriptor=op_bound_descriptor)) + return tuple(candidates) + + +def _evaluate_dtype_candidate( + candidate: _TargetOpSelectionCandidate, + *, + operand_types: tuple[ScalarType | MaskType, ...], +) -> _DtypeSelectionCandidate: + matched_signature = _match_descriptor_dtype_signature(candidate.descriptor, operand_types) + if matched_signature is None: + return _DtypeSelectionCandidate(descriptor=candidate.descriptor) + if candidate.descriptor._selected_dtype_signature == matched_signature: + return _DtypeSelectionCandidate( + descriptor=candidate.descriptor, + matched_descriptor=candidate.descriptor, + matched_dtype_signature=matched_signature, + ) + return _DtypeSelectionCandidate( + descriptor=candidate.descriptor, + matched_descriptor=candidate.descriptor._bind_selected_dtype_signature(matched_signature), + matched_dtype_signature=matched_signature, + ) + + +def _evaluate_dtype_candidates( + candidates: tuple[_TargetOpSelectionCandidate, ...], + *, + operand_types: tuple[ScalarType | MaskType, ...], +) -> tuple[_DtypeSelectionCandidate, ...]: + return tuple( + _evaluate_dtype_candidate( + candidate, + operand_types=operand_types, + ) + for candidate in candidates + ) + + +def _match_descriptor_query( + descriptor: VKernelDescriptor, + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], +) -> VKernelDescriptor | None: + op_bound_descriptor = _bind_descriptor_for_target_op( + descriptor, + target=target, + op=op, + ) + if op_bound_descriptor is None: + return None + dtype_result = _evaluate_dtype_candidate( + _TargetOpSelectionCandidate(descriptor=op_bound_descriptor), + operand_types=operand_types, + ) + return dtype_result.matched_descriptor + + +def _evaluate_constraint_candidate( + descriptor: VKernelDescriptor, + *, + context_attrs: Mapping[str, Any], +) -> _ConstraintSelectionCandidate: + evaluation = _evaluate_constraints( + descriptor, + descriptor._constraint_context_for_evaluation(context_attrs), + ) + if not evaluation.passed: + return _ConstraintSelectionCandidate( + descriptor=descriptor, + passed=False, + evaluation=evaluation, + ) + return _ConstraintSelectionCandidate( + descriptor=descriptor, + passed=True, + evaluation=evaluation, + bound_descriptor=descriptor._bind_constraint_context_attrs(context_attrs), + ) + + +def _evaluate_constraint_candidates( + descriptors: tuple[VKernelDescriptor, ...], + *, + context_attrs: Mapping[str, Any], +) -> tuple[_ConstraintSelectionCandidate, ...]: + return tuple( + _evaluate_constraint_candidate( + descriptor, + context_attrs=context_attrs, + ) + for descriptor in descriptors + ) + + +def _resolve_priority_candidates( + descriptors: tuple[VKernelDescriptor, ...], +) -> _PrioritySelectionResult: + if not descriptors: + return _PrioritySelectionResult( + candidates=(), + highest_priority=None, + winners=(), + ) + highest_priority = max(descriptor.priority for descriptor in descriptors) + winners = tuple( + descriptor + for descriptor in descriptors + if descriptor.priority == highest_priority + ) + return _PrioritySelectionResult( + candidates=descriptors, + highest_priority=highest_priority, + winners=winners, + ) + + +def _materialize_selection_candidate( + descriptor: VKernelDescriptor, +) -> _MaterializationSelectionCandidate: + try: + return _MaterializationSelectionCandidate( + descriptor=descriptor, + mlir_text=descriptor.mlir_text(), + ) + except Exception as exc: + return _MaterializationSelectionCandidate( + descriptor=descriptor, + mlir_error=str(exc), + ) + + +def _collect_materialization_candidates( + descriptors: tuple[VKernelDescriptor, ...], +) -> tuple[_MaterializationSelectionCandidate, ...]: + return tuple( + _materialize_selection_candidate(descriptor) + for descriptor in descriptors + ) + + +def _select_kernel_no_candidate_error( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], +) -> str: + return ( + "select_kernel() found no registered kernel for " + f"target={target!r}, op={op!r}, operand_types={operand_types!r}" + ) + + +def _select_kernel_constraint_error( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], +) -> str: + return ( + "select_kernel() found no registered kernel after constraint evaluation for " + f"target={target!r}, op={op!r}, operand_types={operand_types!r}" + ) + + +def _select_kernel_priority_tie_error( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], + winners: tuple[VKernelDescriptor, ...], +) -> str: + winner_set = ", ".join(sorted(_format_descriptor_identity(descriptor) for descriptor in winners)) + return ( + "select_kernel() found multiple highest-priority kernels for " + f"target={target!r}, op={op!r}, operand_types={operand_types!r}: " + f"{winner_set}" + ) + + +def _build_selection_report( + *, + target: str, + op: str, + operand_types: tuple[ScalarType | MaskType, ...], + context_attrs: Mapping[str, Any], + dtype_results: tuple[_DtypeSelectionCandidate, ...], + constraint_results: tuple[_ConstraintSelectionCandidate, ...], + materialization_results: tuple[_MaterializationSelectionCandidate, ...], + priority_result: _PrioritySelectionResult, + final_status: str, + final_error: str | None, +) -> KernelSelectionReport: + constraint_by_descriptor_id = { + id(result.descriptor): result + for result in constraint_results + } + materialization_by_descriptor_id = { + id(result.descriptor): result + for result in materialization_results + } + winner_ids = {id(descriptor) for descriptor in priority_result.winners} + highest_priority = priority_result.highest_priority + candidates: list[KernelSelectionCandidateMetadata] = [] + + for dtype_result in dtype_results: + if dtype_result.matched_descriptor is None: + candidates.append( + KernelSelectionCandidateMetadata( + descriptor=dtype_result.descriptor, + status="dtype_mismatch", + selected_op=dtype_result.descriptor.selected_op, + reason=( + "no dtype signature matched " + f"operand_types={operand_types!r}" + ), + ) + ) + continue + + constraint_result = constraint_by_descriptor_id.get(id(dtype_result.matched_descriptor)) + if constraint_result is None: + continue + evaluation = constraint_result.evaluation + candidate_descriptor = constraint_result.bound_descriptor or dtype_result.matched_descriptor + materialization_result = materialization_by_descriptor_id.get(id(candidate_descriptor)) + base_kwargs = { + "descriptor": candidate_descriptor, + "selected_op": candidate_descriptor.selected_op, + "matched_dtype_signature": dtype_result.matched_dtype_signature, + "failed_constraint_index": evaluation.failed_constraint_index, + "failed_constraint_name": evaluation.failed_constraint_name, + "failed_constraint_location": evaluation.failed_constraint_location, + "error_type": evaluation.error_type, + "error_message": evaluation.error_message, + "mlir_text": None if materialization_result is None else materialization_result.mlir_text, + "mlir_error": None if materialization_result is None else materialization_result.mlir_error, + } + + if evaluation.raised_error: + candidates.append( + KernelSelectionCandidateMetadata( + status="constraint_error", + reason=evaluation.error_message, + **base_kwargs, + ) + ) + continue + if not evaluation.passed: + candidates.append( + KernelSelectionCandidateMetadata( + status="constraint_failed", + reason=evaluation.error_message, + **base_kwargs, + ) + ) + continue + if id(candidate_descriptor) in winner_ids: + status = "selected" if final_status == "selected" else "priority_tie" + reason = None if status == "selected" else final_error + else: + status = "priority_shadowed" + if highest_priority is None: + reason = "not selected" + else: + reason = f"shadowed by higher-priority candidate priority={highest_priority}" + candidates.append( + KernelSelectionCandidateMetadata( + status=status, + reason=reason, + **base_kwargs, + ) + ) + + frozen_context_attrs = tuple( + sorted(dict(context_attrs).items(), key=lambda item: item[0]) + ) + return KernelSelectionReport( + target=target, + op=op, + operand_types=operand_types, + selected=priority_result.winner if final_status == "selected" else None, + candidates=tuple(candidates), + final_status=final_status, + final_error=final_error, + _context_attrs=frozen_context_attrs, + ) + + +def select_kernel( + target: str, + op: str, + operand_types: Any, + context_attrs: Mapping[str, Any] | None = None, + registry: KernelRegistry | None = None, + *, + return_metadata: bool = False, + include_mlir: bool = True, +) -> VKernelDescriptor | KernelSelectionReport: + """Select one registered kernel descriptor for the given query.""" + + normalized_target = _validate_target(target) + normalized_op = _validate_op(op) + normalized_operand_types = _freeze_operand_types(operand_types) + + if context_attrs is None: + normalized_context_attrs: dict[str, Any] = {} + elif isinstance(context_attrs, Mapping): + normalized_context_attrs = dict(context_attrs) + else: + raise TypeError("context_attrs must be a mapping or None") + + active_registry = _DEFAULT_KERNEL_REGISTRY if registry is None else registry + if not isinstance(active_registry, KernelRegistry): + raise TypeError("registry must be a KernelRegistry or None") + if not isinstance(return_metadata, bool): + raise TypeError("return_metadata must be a bool") + if not isinstance(include_mlir, bool): + raise TypeError("include_mlir must be a bool") + + target_op_candidates = _collect_target_op_candidates( + active_registry, + target=normalized_target, + op=normalized_op, + ) + dtype_results = _evaluate_dtype_candidates( + target_op_candidates, + operand_types=normalized_operand_types, + ) + type_matched_candidates = tuple( + result.matched_descriptor + for result in dtype_results + if result.matched_descriptor is not None + ) + + if not type_matched_candidates: + no_candidate_error = _select_kernel_no_candidate_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + ) + if return_metadata: + return _build_selection_report( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + context_attrs=normalized_context_attrs, + dtype_results=dtype_results, + constraint_results=(), + materialization_results=(), + priority_result=_PrioritySelectionResult(candidates=(), highest_priority=None, winners=()), + final_status="no_candidate", + final_error=no_candidate_error, + ) + raise LookupError(no_candidate_error) + + constraint_results = _evaluate_constraint_candidates( + type_matched_candidates, + context_attrs=normalized_context_attrs, + ) + constrained_candidates = tuple( + result.bound_descriptor + for result in constraint_results + if result.bound_descriptor is not None + ) + if return_metadata: + priority_result = _resolve_priority_candidates(constrained_candidates) + materialization_results = ( + _collect_materialization_candidates(constrained_candidates) + if include_mlir + else () + ) + final_status = "selected" + final_error: str | None = None + if not constrained_candidates: + final_status = "no_candidate" + error_messages = [ + result.evaluation.error_message + for result in constraint_results + if result.evaluation.error_message is not None + ] + final_error = error_messages[0] if error_messages else _select_kernel_constraint_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + ) + elif priority_result.has_tie: + final_status = "priority_tie" + final_error = _select_kernel_priority_tie_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + winners=priority_result.winners, + ) + return _build_selection_report( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + context_attrs=normalized_context_attrs, + dtype_results=dtype_results, + constraint_results=constraint_results, + materialization_results=materialization_results, + priority_result=priority_result, + final_status=final_status, + final_error=final_error, + ) + for result in constraint_results: + _raise_constraint_evaluation_error(result.evaluation) + if not constrained_candidates: + raise LookupError( + _select_kernel_constraint_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + ) + ) + + priority_result = _resolve_priority_candidates(constrained_candidates) + if priority_result.has_tie: + raise LookupError( + _select_kernel_priority_tie_error( + target=normalized_target, + op=normalized_op, + operand_types=normalized_operand_types, + winners=priority_result.winners, + ) + ) + assert priority_result.winner is not None + return priority_result.winner + + +def vkernel( + py_fn: Callable[..., Any] | None = None, + *, + target: str = "a5", + op: str | None = None, + ops: tuple[str, ...] | list[str] | None = None, + templates: Any = _UNSET, + dtypes: Any = None, + name: str | None = None, + verify: bool = True, + advanced: bool = False, + constraints: Any = _UNSET, + priority: Any = _UNSET, +) -> VKernelDescriptor | Callable[[Callable[..., Any]], VKernelDescriptor]: + """Create a TileLang DSL v1 kernel descriptor. + + v1 keeps only the minimal descriptor metadata surface: + `target`, `op`/`ops`, `templates`, `dtypes`, `constraints`, `priority`, `name`, + `verify`, and opt-in `advanced`. + """ + + def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: + descriptor = _build_descriptor( + fn, + target=target, + op=op, + ops=ops, + templates=templates, + dtypes=dtypes, + name=name, + verify=verify, + advanced=advanced, + constraints=constraints, + priority=priority, + kernel_family="vector", + ) + return _DEFAULT_KERNEL_REGISTRY.register(descriptor) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +def ckernel( + py_fn: Callable[..., Any] | None = None, + *, + target: str = "a5", + op: str | None = None, + ops: tuple[str, ...] | list[str] | None = None, + templates: Any = _UNSET, + dtypes: Any = None, + name: str | None = None, + priority: Any = _UNSET, +) -> VKernelDescriptor | Callable[[Callable[..., Any]], VKernelDescriptor]: + """Create a TileLang DSL cube-kernel descriptor. + + This public entrypoint intentionally reuses the existing descriptor and + registry path first; cube-specific semantic and lowering behavior is added + incrementally in follow-up tasks. + """ + + def wrap(fn: Callable[..., Any]) -> VKernelDescriptor: + descriptor = _build_descriptor( + fn, + target=target, + op=op, + ops=ops, + templates=templates, + dtypes=dtypes, + name=name, + verify=True, + advanced=False, + constraints=_UNSET, + priority=priority, + kernel_family="cube", + ) + return _DEFAULT_KERNEL_REGISTRY.register(descriptor) + + if py_fn is None: + return wrap + return wrap(py_fn) + + +__all__ = [ + "BoundKernelParameter", + "InlineProcDescriptor", + "KernelRegistry", + "KernelSelectionCandidateMetadata", + "KernelSelectionReport", + "MaterializedMLIRModule", + "TileLangFrontendError", + "VKernelDescriptor", + "ckernel", + "inline_proc", + "select_kernel", + "vkernel", +] diff --git a/tilelang-dsl/python/tilelang_dsl/lowering.py b/tilelang-dsl/python/tilelang_dsl/lowering.py new file mode 100644 index 000000000..ec7dcab00 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/lowering.py @@ -0,0 +1,4639 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Authoring-form VPTO lowering skeleton for TileLang DSL v1.""" + +from __future__ import annotations + +import math +import re +import struct +from dataclasses import dataclass + +from .semantic import ( + SemanticAlignStoreStmt, + SemanticAlignType, + SemanticAssignStmt, + SemanticAttributeAccess, + SemanticBinaryExpr, + SemanticBindingRef, + SemanticCallExpr, + SemanticDmaConfigStmt, + SemanticDmaUnaryConfigStmt, + SemanticDmaLoadStmt, + SemanticDmaStoreStmt, + SemanticExpr, + SemanticExprStmt, + SemanticForStmt, + SemanticGetBufStmt, + SemanticIndexCastExpr, + SemanticIfStmt, + SemanticIndexType, + SemanticIfResult, + SemanticKernel, + SemanticLiteralExpr, + SemanticMemBarStmt, + SemanticLowLevelCopyStmt, + SemanticMaskType, + SemanticMetaType, + SemanticPadValueType, + SemanticPipeBarrierStmt, + SemanticPredicateStoreStmt, + SemanticPtrType, + SemanticReturnStmt, + SemanticRlsBufStmt, + SemanticScalarStoreStmt, + SemanticScalarType, + SemanticSetCrossCoreStmt, + SemanticSetFlagStmt, + SemanticSetIntraBlockStmt, + SemanticSetIntraCoreStmt, + SemanticShapeType, + SemanticStmt, + SemanticVecscopeStmt, + SemanticStrictVecscopeStmt, + SemanticSubscriptAccess, + SemanticSymbolExpr, + SemanticTensorSliceExpr, + SemanticTensorViewType, + SemanticPartitionTensorViewType, + SemanticTileType, + SemanticType, + SemanticTupleExpr, + SemanticTupleType, + SemanticVScatterStmt, + SemanticVRegType, + SemanticVectorType, + SemanticVectorPairStoreStmt, + SemanticVectorStoreStmt, + SemanticWaitFlagDevStmt, + SemanticWaitFlagStmt, + SemanticWaitIntraCoreStmt, +) +from .types import ( + MaskPattern, + PadValue, + ScalarType, + TileConfig, + bytewidth, + get_lanes, + integer_bitwidth, + integer_signedness, + is_float_dtype, + is_integer_dtype, +) + + +_I1_TYPE = SemanticScalarType(dtype=ScalarType("i1")) +_I32_TYPE = SemanticScalarType(dtype=ScalarType("i32")) +_I64_TYPE = SemanticScalarType(dtype=ScalarType("i64")) + + +def _signless_mov_pad_scalar_type(dtype: ScalarType) -> SemanticScalarType | None: + bitwidth = integer_bitwidth(dtype) + if bitwidth == 8: + return SemanticScalarType(dtype=ScalarType("i8")) + if bitwidth == 16: + return SemanticScalarType(dtype=ScalarType("i16")) + if bitwidth == 32: + return SemanticScalarType(dtype=ScalarType("i32")) + return None + + +def _format_symbol_name(symbol_name: str) -> str: + if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$.]*", symbol_name): + return f"@{symbol_name}" + escaped = symbol_name.replace("\\", "\\\\").replace('"', '\\"') + return f'@"{escaped}"' + + +@dataclass(frozen=True) +class AuthoringModule: + """Lowering result that owns authoring-form VPTO text emission.""" + + kernel: SemanticKernel + + def render(self) -> str: + kernel_text = _AuthoringRenderer(self.kernel).render() + if not self.kernel.inline_helpers: + return kernel_text + + base_lines = kernel_text.splitlines() + module_close_index = max( + (index for index, line in enumerate(base_lines) if line == "}"), + default=-1, + ) + if module_close_index < 0: + return kernel_text + + merged_lines = base_lines[:module_close_index] + for helper in self.kernel.inline_helpers: + helper_lines = _extract_single_function_lines( + _AuthoringRenderer(helper).render() + ) + if not helper_lines: + continue + helper_lines[0] = _rewrite_inline_helper_attrs(helper_lines[0]) + merged_lines.extend(helper_lines) + + merged_lines.append("}") + merged_lines.append("") + return "\n".join(merged_lines) + + +def _extract_single_function_lines(rendered_text: str) -> list[str]: + lines = rendered_text.splitlines() + try: + function_start = next( + index for index, line in enumerate(lines) if line.lstrip().startswith("func.func ") + ) + except StopIteration: + return [] + module_close_index = max( + (index for index, line in enumerate(lines) if line == "}"), + default=-1, + ) + if module_close_index <= function_start: + return [] + return lines[function_start:module_close_index] + + +def _rewrite_inline_helper_attrs(function_line: str) -> str: + helper_attr = "private " + helper_marker_attr = "pto.tilelang.inline_proc" + if function_line.lstrip().startswith("func.func "): + after_keyword = function_line.lstrip()[len("func.func ") :] + if not after_keyword.startswith(("private ", "public ")): + function_line = function_line.replace("func.func ", f"func.func {helper_attr}", 1) + if "pto.tilelang.instance" in function_line: + return function_line.replace("pto.tilelang.instance", helper_marker_attr) + if "attributes {" in function_line: + return function_line + if function_line.rstrip().endswith("{"): + stripped = function_line.rstrip() + if stripped.lstrip().startswith("func.func "): + after_keyword = stripped.lstrip()[len("func.func ") :] + if not after_keyword.startswith(("private ", "public ")): + stripped = stripped.replace("func.func ", f"func.func {helper_attr}", 1) + return stripped[:-1] + f" attributes {{ {helper_marker_attr} }} {{" + return function_line + + +@dataclass(frozen=True) +class _RenderedValue: + name: str + type: SemanticType + + +@dataclass(frozen=True) +class _RenderedTextualType(SemanticType): + text: str + + +@dataclass(frozen=True) +class _DmaTransferConfig: + n_burst: _RenderedValue + len_burst: _RenderedValue + copy_src_stride: _RenderedValue + copy_dst_stride: _RenderedValue + loop_src_stride: _RenderedValue + loop_dst_stride: _RenderedValue + + +@dataclass(frozen=True) +class _DmaLoadPaddingProfile: + pad_mode_name: str + left_padding: int + right_padding: int + init_out_buffer: bool + pad_value: SemanticExpr | None + + +@dataclass(frozen=True) +class _DmaStoreTrimProfile: + left_padding: int + right_padding: int + + +class _AuthoringRenderer: + def __init__(self, kernel: SemanticKernel): + self.kernel = kernel + self._constant_lines: list[str] = [] + self._constant_cache: dict[tuple[str, object], str] = {} + self._castptr_cache: dict[tuple[str, str], str] = {} + self._tile_memref_cache: dict[str, _RenderedValue] = {} + self._tile_valid_dim_cache: dict[tuple[str, int], _RenderedValue] = {} + self._used_tile_buffers = self._collect_used_tile_buffers(kernel.body) + self._temp_counter = 0 + self._loop_counter = 0 + + def render(self) -> str: + parameter_list = ", ".join( + f"{param.ssa_name}: {self._render_type(param.type)}" + for param in self.kernel.parameters + if param.kind != "tile_valid_shape" and self._should_materialize_function_boundary_type(param.type) + ) + result_sig = "" + if self.kernel.body and isinstance(self.kernel.body[-1], SemanticReturnStmt): + return_value = self.kernel.body[-1].value + if return_value is not None: + result_sig = f" -> {self._render_type(return_value.type)}" + env = { + param.name: _RenderedValue(name=param.ssa_name, type=param.type) + for param in self.kernel.parameters + if param.kind != "tile_valid_shape" + } + entry_lines: list[str] = [] + for param in self.kernel.parameters: + if param.kind != "tile": + continue + if param.name in self._used_tile_buffers: + self._materialize_tile_memref( + env[param.name], + indent=4, + into=entry_lines, + ) + body_lines = self._render_block(self.kernel.body, env, indent=4) + + lines = [ + f"// tilelang.target = {self.kernel.target}", + f"// tilelang.op = {self.kernel.op}", + f"// tilelang.dtypes = {self.kernel.dtype_signature}", + f"// tilelang.verify = {self.kernel.verify_enabled}", + f"// tilelang.advanced = {self.kernel.advanced_enabled}", + ] + for binding in self.kernel.tile_bindings: + valid_shape = "" + if binding.valid_shape is not None: + valid_shape = f" valid_shape={self._format_shape_tuple(binding.valid_shape)}" + lines.append( + "// tilelang.specialize " + f"{binding.name} shape={binding.shape} memory_space={binding.memory_space} " + f"config={binding.config}{valid_shape}" + ) + kernel_kind = "cube" if self.kernel.kernel_family == "cube" else "vector" + function_attrs = ( + "attributes { " + f"pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind<{kernel_kind}> " + "}" + ) + lines.append(f'module attributes {{pto.target_arch = "{self.kernel.target}"}} {{') + lines.append( + " func.func " + f"{_format_symbol_name(self.kernel.symbol_name)}({parameter_list}){result_sig} " + f"{function_attrs} {{" + ) + lines.extend(self._constant_lines) + lines.extend(entry_lines) + lines.extend(body_lines) + lines.append(" }") + lines.append("}") + lines.append("") + return "\n".join(lines) + + def _should_materialize_function_boundary_type(self, ty: SemanticType) -> bool: + return not isinstance(ty, (SemanticMetaType, SemanticPadValueType)) + + def _collect_used_tile_buffers( + self, + statements: tuple[SemanticStmt, ...], + ) -> set[str]: + used: set[str] = set() + for stmt in statements: + self._collect_used_tile_buffers_from_stmt(stmt, used) + return used + + def _collect_used_tile_buffers_from_stmt( + self, + stmt: SemanticStmt, + used: set[str], + ) -> None: + if isinstance(stmt, SemanticAssignStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + return + if isinstance(stmt, SemanticExprStmt): + self._collect_used_tile_buffers_from_expr(stmt.expr, used) + return + if isinstance(stmt, SemanticDmaLoadStmt): + self._record_tile_buffer_use(stmt.dst, used) + self._collect_used_tile_buffers_from_expr(stmt.src, used) + return + if isinstance(stmt, SemanticDmaStoreStmt): + self._record_tile_buffer_use(stmt.src, used) + self._collect_used_tile_buffers_from_expr(stmt.dst, used) + return + if isinstance(stmt, SemanticVectorStoreStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + for index in stmt.indices: + self._collect_used_tile_buffers_from_expr(index, used) + self._collect_used_tile_buffers_from_expr(stmt.mask, used) + return + if isinstance(stmt, SemanticVScatterStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + self._collect_used_tile_buffers_from_expr(stmt.offsets, used) + self._collect_used_tile_buffers_from_expr(stmt.mask, used) + return + if isinstance(stmt, SemanticPredicateStoreStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + for index in stmt.indices: + self._collect_used_tile_buffers_from_expr(index, used) + self._collect_used_tile_buffers_from_expr(stmt.dist, used) + return + if isinstance(stmt, SemanticAlignStoreStmt): + self._collect_used_tile_buffers_from_expr(stmt.value, used) + self._record_tile_buffer_use(stmt.destination, used) + for index in stmt.indices: + self._collect_used_tile_buffers_from_expr(index, used) + if stmt.offset is not None: + self._collect_used_tile_buffers_from_expr(stmt.offset, used) + return + if isinstance(stmt, SemanticVecscopeStmt): + for nested in stmt.body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticStrictVecscopeStmt): + for capture in stmt.captures: + self._record_tile_buffer_use(capture, used) + self._collect_used_tile_buffers_from_expr(capture, used) + for nested in stmt.body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticForStmt): + self._collect_used_tile_buffers_from_expr(stmt.lower_bound, used) + self._collect_used_tile_buffers_from_expr(stmt.upper_bound, used) + self._collect_used_tile_buffers_from_expr(stmt.step, used) + for nested in stmt.body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticIfStmt): + self._collect_used_tile_buffers_from_expr(stmt.condition, used) + for nested in stmt.then_body: + self._collect_used_tile_buffers_from_stmt(nested, used) + for nested in stmt.else_body: + self._collect_used_tile_buffers_from_stmt(nested, used) + return + if isinstance(stmt, SemanticReturnStmt) and stmt.value is not None: + self._collect_used_tile_buffers_from_expr(stmt.value, used) + + def _collect_used_tile_buffers_from_expr( + self, + expr: SemanticExpr, + used: set[str], + ) -> None: + if isinstance(expr, SemanticCallExpr): + if expr.namespace == "pto" and expr.name in {"vlds", "vldas", "vldus"} and expr.args: + self._record_tile_buffer_use(expr.args[0], used) + for arg in expr.args: + self._collect_used_tile_buffers_from_expr(arg, used) + return + if isinstance(expr, SemanticBinaryExpr): + self._collect_used_tile_buffers_from_expr(expr.lhs, used) + self._collect_used_tile_buffers_from_expr(expr.rhs, used) + return + if isinstance(expr, SemanticTupleExpr): + for element in expr.elements: + self._collect_used_tile_buffers_from_expr(element, used) + return + if isinstance(expr, SemanticTensorSliceExpr): + self._collect_used_tile_buffers_from_expr(expr.base, used) + for slice_expr in expr.slices: + if slice_expr.start is not None: + self._collect_used_tile_buffers_from_expr(slice_expr.start, used) + if slice_expr.stop is not None: + self._collect_used_tile_buffers_from_expr(slice_expr.stop, used) + if slice_expr.step is not None: + self._collect_used_tile_buffers_from_expr(slice_expr.step, used) + return + if isinstance(expr, SemanticAttributeAccess): + if expr.attr not in {"shape", "valid_shape", "strides", "element_type"}: + self._collect_used_tile_buffers_from_expr(expr.base, used) + return + if isinstance(expr, SemanticSubscriptAccess): + self._collect_used_tile_buffers_from_expr(expr.base, used) + self._collect_used_tile_buffers_from_expr(expr.index, used) + + def _record_tile_buffer_use( + self, + expr: SemanticExpr, + used: set[str], + ) -> None: + if isinstance(expr, SemanticBindingRef) and isinstance(expr.type, SemanticTileType): + used.add(expr.binding.name) + + def _render_block( + self, + statements: tuple[SemanticStmt, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + for stmt in statements: + lines.extend(self._render_stmt(stmt, env, indent=indent)) + return lines + + def _render_stmt( + self, + stmt: SemanticStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if isinstance(stmt, SemanticAssignStmt): + return self._render_assign(stmt, env, indent=indent) + if isinstance(stmt, SemanticExprStmt): + lines: list[str] = [] + self._lower_expr(stmt.expr, env, indent=indent, into=lines) + return lines + if isinstance(stmt, SemanticDmaLoadStmt): + return self._render_dma_load(stmt, env, indent=indent) + if isinstance(stmt, SemanticDmaStoreStmt): + return self._render_dma_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticVectorStoreStmt): + return self._render_vector_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticVScatterStmt): + return self._render_vscatter(stmt, env, indent=indent) + if isinstance(stmt, SemanticVectorPairStoreStmt): + return self._render_vector_pair_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticPredicateStoreStmt): + return self._render_predicate_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticAlignStoreStmt): + return self._render_align_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticScalarStoreStmt): + return self._render_scalar_store(stmt, env, indent=indent) + if isinstance(stmt, SemanticSetFlagStmt): + return [ + self._indent(indent) + + f'pto.set_flag["{stmt.src_pipe}", "{stmt.dst_pipe}", "{stmt.event}"]' + ] + if isinstance(stmt, SemanticWaitFlagStmt): + return [ + self._indent(indent) + + f'pto.wait_flag["{stmt.src_pipe}", "{stmt.dst_pipe}", "{stmt.event}"]' + ] + if isinstance(stmt, SemanticPipeBarrierStmt): + return [self._indent(indent) + f"pto.barrier #pto.pipe<{stmt.pipe}>"] + if isinstance(stmt, SemanticGetBufStmt): + return self._render_buffer_sync_stmt("get_buf", stmt.pipe, stmt.buf_id, stmt.mode, env, indent=indent) + if isinstance(stmt, SemanticRlsBufStmt): + return self._render_buffer_sync_stmt("rls_buf", stmt.pipe, stmt.buf_id, stmt.mode, env, indent=indent) + if isinstance(stmt, SemanticMemBarStmt): + return [self._indent(indent) + f'pto.mem_bar "{stmt.barrier_type}"'] + if isinstance(stmt, SemanticSetCrossCoreStmt): + return self._render_i64_pair_stmt("set_cross_core", stmt.core_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticSetIntraBlockStmt): + return self._render_i64_pair_stmt("set_intra_block", stmt.block_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticSetIntraCoreStmt): + return self._render_i32_stmt("set_intra_core", stmt.config, env, indent=indent) + if isinstance(stmt, SemanticWaitFlagDevStmt): + return self._render_i64_pair_stmt("wait_flag_dev", stmt.core_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticWaitIntraCoreStmt): + return self._render_i64_pair_stmt("wait_intra_core", stmt.block_id, stmt.event_id, env, indent=indent) + if isinstance(stmt, SemanticDmaUnaryConfigStmt): + return self._render_dma_unary_config(stmt, env, indent=indent) + if isinstance(stmt, SemanticDmaConfigStmt): + return self._render_dma_config(stmt, env, indent=indent) + if isinstance(stmt, SemanticLowLevelCopyStmt): + return self._render_low_level_copy(stmt, env, indent=indent) + if isinstance(stmt, SemanticReturnStmt): + lines: list[str] = [] + if stmt.value is None: + return [self._indent(indent) + "return"] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + lines.append(self._indent(indent) + f"return {value.name} : {self._render_type(value.type)}") + return lines + if isinstance(stmt, SemanticVecscopeStmt): + return self._render_vecscope(stmt, env, indent=indent) + if isinstance(stmt, SemanticStrictVecscopeStmt): + return self._render_strict_vecscope(stmt, env, indent=indent) + if isinstance(stmt, SemanticForStmt): + return self._render_for(stmt, env, indent=indent) + if isinstance(stmt, SemanticIfStmt): + return self._render_if(stmt, env, indent=indent) + raise ValueError(f"unsupported semantic statement {type(stmt).__name__}") + + def _render_dma_config( + self, + stmt: SemanticDmaConfigStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + first = self._lower_to_i64(stmt.first, env, indent=indent, into=lines) + second = self._lower_to_i64(stmt.second, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.{stmt.name} {first.name}, {second.name} : i64, i64" + ) + return lines + + def _render_dma_unary_config( + self, + stmt: SemanticDmaUnaryConfigStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + if ( + stmt.name == "set_mov_pad_val" + and isinstance(value.type, SemanticScalarType) + and is_integer_dtype(value.type.dtype) + ): + signless_type = _signless_mov_pad_scalar_type(value.type.dtype) + if signless_type is not None: + value = self._coerce_rendered_value( + value, + signless_type, + indent=indent, + into=lines, + ) + lines.append( + self._indent(indent) + + f"pto.{stmt.name} {value.name} : {self._render_type(value.type)}" + ) + return lines + + def _render_buffer_sync_stmt( + self, + name: str, + pipe: str, + buf_id: SemanticExpr, + mode: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + rendered_buf_id = self._lower_to_i64(buf_id, env, indent=indent, into=lines) + rendered_mode = self._lower_to_i64(mode, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f'pto.{name} "{pipe}", {rendered_buf_id.name}, {rendered_mode.name} : i64, i64' + ) + return lines + + def _render_i64_pair_stmt( + self, + name: str, + first: SemanticExpr, + second: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + rendered_first = self._lower_to_i64(first, env, indent=indent, into=lines) + rendered_second = self._lower_to_i64(second, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.{name} {rendered_first.name}, {rendered_second.name} : i64, i64" + ) + return lines + + def _render_i32_stmt( + self, + name: str, + value: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + rendered_value = self._lower_to_i32(value, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.{name} {rendered_value.name} : i32" + ) + return lines + + def _render_low_level_copy( + self, + stmt: SemanticLowLevelCopyStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + source = self._lower_expr(stmt.source, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + + rendered_operands = [] + rendered_types = [] + for index, operand in enumerate(stmt.operands): + if stmt.name == "copy_gm_to_ubuf" and index == 5: + lowered = self._lower_to_i1(operand, env, indent=indent, into=lines) + else: + lowered = self._lower_to_i64(operand, env, indent=indent, into=lines) + rendered_operands.append(lowered.name) + rendered_types.append(self._render_type(lowered.type)) + + operand_text = ", ".join([source.name, destination.name, *rendered_operands]) + type_text = ", ".join( + [self._render_type(source.type), self._render_type(destination.type), *rendered_types] + ) + lines.append( + self._indent(indent) + + f"pto.{stmt.name} {operand_text} : {type_text}" + ) + return lines + + def _render_assign( + self, + stmt: SemanticAssignStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if len(stmt.targets) != 1: + if isinstance(stmt.value, SemanticTupleExpr) or ( + isinstance(stmt.value, SemanticAttributeAccess) + and isinstance(stmt.value.type, SemanticShapeType) + ): + return self._render_tuple_expr_assign(stmt, env, indent=indent) + return self._render_multi_result_assign(stmt, env, indent=indent) + target = stmt.targets[0] + if isinstance(target.type, (SemanticMetaType, SemanticPadValueType)): + env[target.name] = _RenderedValue(name=target.ssa_name, type=target.type) + return [] + lines: list[str] = [] + lowered = self._lower_expr( + stmt.value, + env, + indent=indent, + desired_name=target.ssa_name, + into=lines, + ) + env[target.name] = lowered + return lines + + def _render_tuple_expr_assign( + self, + stmt: SemanticAssignStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if isinstance(stmt.value, SemanticTupleExpr): + elements = stmt.value.elements + elif isinstance(stmt.value, SemanticAttributeAccess) and isinstance(stmt.value.type, SemanticShapeType): + elements = tuple( + SemanticSubscriptAccess( + base=stmt.value, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + for axis in range(stmt.value.type.rank) + ) + else: + raise NotImplementedError( + "tuple expression assignment expects a SemanticTupleExpr or shape-like attribute value" + ) + if len(stmt.targets) != len(elements): + raise NotImplementedError("tuple expression assignment arity mismatch") + + lines: list[str] = [] + for target, element in zip(stmt.targets, elements): + lowered = self._lower_expr( + element, + env, + indent=indent, + desired_name=target.ssa_name, + into=lines, + ) + env[target.name] = lowered + return lines + + def _render_multi_result_assign( + self, + stmt: SemanticAssignStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + if not isinstance(stmt.value, SemanticCallExpr): + raise NotImplementedError("multi-result assignment expects a call expression in TileLang DSL v1") + if stmt.value.namespace != "pto": + raise NotImplementedError( + f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" + ) + if not isinstance(stmt.value.type, SemanticTupleType): + raise NotImplementedError("multi-result lowering expects a tuple result type") + if len(stmt.targets) != len(stmt.value.type.elements): + raise NotImplementedError("multi-result lowering expects assignment arity to match result arity") + + if stmt.value.name in {"make_mask", "plt_b8", "plt_b16", "plt_b32"}: + if len(stmt.targets) != 2 or len(stmt.value.type.elements) != 2: + raise NotImplementedError("mask multi-result lowering expects exactly two results") + lines: list[str] = [] + if stmt.value.name == "make_mask": + dtype_expr, remaining_expr = stmt.value.args + if not self._is_dtype_meta_expr(dtype_expr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") + remaining = self._lower_remaining_to_i32(remaining_expr, env, indent=indent, into=lines) + op_name = None + else: + remaining = self._lower_remaining_to_i32(stmt.value.args[0], env, indent=indent, into=lines) + op_name = stmt.value.name + mask_target, remaining_target = stmt.targets + mask_type, remaining_type = stmt.value.type.elements + suffix = self._mask_suffix(mask_type) + lowered_op = op_name or f"plt_{suffix}" + lines.append( + self._indent(indent) + + f"{mask_target.ssa_name}, {remaining_target.ssa_name} = pto.{lowered_op} {remaining.name} : " + + f"i32 -> {self._render_type(mask_type)}, {self._render_type(remaining_type)}" + ) + env[mask_target.name] = _RenderedValue(name=mask_target.ssa_name, type=mask_type) + env[remaining_target.name] = _RenderedValue(name=remaining_target.ssa_name, type=remaining_type) + return lines + + if stmt.value.name in {"vaddc", "vsubc"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + mask = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + result_target, carry_target = stmt.targets + result_type, carry_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{result_target.ssa_name}, {carry_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(result_type)}, {self._render_type(carry_type)}" + ) + env[result_target.name] = _RenderedValue(name=result_target.ssa_name, type=result_type) + env[carry_target.name] = _RenderedValue(name=carry_target.ssa_name, type=carry_type) + return lines + + if stmt.value.name in {"vaddcs", "vsubcs"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + carry_in = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + mask = self._lower_expr(stmt.value.args[3], env, indent=indent, into=lines) + result_target, carry_target = stmt.targets + result_type, carry_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{result_target.ssa_name}, {carry_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name}, {carry_in.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, " + + f"{self._render_type(carry_in.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(result_type)}, {self._render_type(carry_type)}" + ) + env[result_target.name] = _RenderedValue(name=result_target.ssa_name, type=result_type) + env[carry_target.name] = _RenderedValue(name=carry_target.ssa_name, type=carry_type) + return lines + + if stmt.value.name in {"vintlv", "vdintlv"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name} : {self._render_type(lhs.type)}, {self._render_type(rhs.type)} " + + f"-> {self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) + return lines + + if stmt.value.name in {"pdintlv_b8", "pdintlv_b16", "pdintlv_b32", "pintlv_b8", "pintlv_b16", "pintlv_b32"}: + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.{stmt.value.name} " + + f"{lhs.name}, {rhs.name} : {self._render_type(lhs.type)}, {self._render_type(rhs.type)} " + + f"-> {self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) + return lines + + if stmt.value.name == "vmull": + lines = [] + lhs = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + rhs = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + mask = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.vmull " + + f"{lhs.name}, {rhs.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) + return lines + + if stmt.value.name == "vldsx2": + lines = [] + source = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=lines) + index_args = stmt.value.args[1:-1] + if ( + isinstance(stmt.value.args[0].type, SemanticTileType) + and stmt.value.args[0].type.rank == 2 + and len(index_args) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + stmt.value.args[0].type, + index_args, + env, + indent=indent, + into=lines, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(index_args, env, indent=indent, into=lines) + dist = self._render_string_literal(stmt.value.args[-1]) + low_target, high_target = stmt.targets + low_type, high_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{low_target.ssa_name}, {high_target.ssa_name} = pto.vldsx2 " + + f"{source.name}[{rendered_indices}], {dist} : " + + f"{self._render_type(source.type)}, index -> " + + f"{self._render_type(low_type)}, {self._render_type(high_type)}" + ) + env[low_target.name] = _RenderedValue(name=low_target.ssa_name, type=low_type) + env[high_target.name] = _RenderedValue(name=high_target.ssa_name, type=high_type) + return lines + + if stmt.value.name == "vldus": + lines = [] + source = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + index_args = stmt.value.args[1:-1] + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=lines) + if ( + isinstance(stmt.value.args[0].type, SemanticTileType) + and stmt.value.args[0].type.rank == 2 + and len(index_args) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + stmt.value.args[0].type, + index_args, + env, + indent=indent, + into=lines, + ) + if self._is_memref_like_type(source.type): + ptr_name, ptr_type = self._materialize_copy_buffer_ptr(source, indent=indent, into=lines) + source = _RenderedValue(name=ptr_name, type=_RenderedTextualType(ptr_type)) + align = self._lower_expr(stmt.value.args[-1], env, indent=indent, into=lines) + result_target, align_target = stmt.targets + result_type, align_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{result_target.ssa_name}, {align_target.ssa_name} = pto.vldus " + + f"{source.name}, {align.name} : " + + f"{self._render_type(source.type)}, {self._render_type(align.type)} -> " + + f"{self._render_type(result_type)}, {self._render_type(align_type)}" + ) + env[result_target.name] = _RenderedValue(name=result_target.ssa_name, type=result_type) + env[align_target.name] = _RenderedValue(name=align_target.ssa_name, type=align_type) + return lines + + if stmt.value.name == "pstu": + lines = [] + align_in = self._lower_expr(stmt.value.args[0], env, indent=indent, into=lines) + value = self._lower_expr(stmt.value.args[1], env, indent=indent, into=lines) + base = self._lower_expr(stmt.value.args[2], env, indent=indent, into=lines) + align_target, base_target = stmt.targets + align_type, base_type = stmt.value.type.elements + lines.append( + self._indent(indent) + + f"{align_target.ssa_name}, {base_target.ssa_name} = pto.pstu " + + f"{align_in.name}, {value.name}, {base.name} : " + + f"{self._render_type(align_in.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {self._render_type(align_type)}, {self._render_type(base_type)}" + ) + env[align_target.name] = _RenderedValue(name=align_target.ssa_name, type=align_type) + env[base_target.name] = _RenderedValue(name=base_target.ssa_name, type=base_type) + return lines + + if stmt.value.name == "get_vms4_sr": + lines = [] + result_names = ", ".join(target.ssa_name for target in stmt.targets) + result_types = ", ".join(self._render_type(result_type) for result_type in stmt.value.type.elements) + lines.append( + self._indent(indent) + + f"{result_names} = pto.get_vms4_sr : {result_types}" + ) + for target, result_type in zip(stmt.targets, stmt.value.type.elements): + env[target.name] = _RenderedValue(name=target.ssa_name, type=result_type) + return lines + + raise NotImplementedError( + f"multi-result assignment for `pto.{stmt.value.name}` is not supported in TileLang DSL v1" + ) + + def _render_dma_load( + self, + stmt: SemanticDmaLoadStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + profile = self._resolve_dma_load_padding_profile(stmt.options) + src = self._lower_expr(stmt.src.base, env, indent=indent, into=lines) + dst = self._lower_expr(stmt.dst, env, indent=indent, into=lines) + src_name, src_type = self._materialize_tensor_slice_ptr( + stmt.src, + src, + env, + indent=indent, + into=lines, + ) + dst_name, dst_type = self._materialize_tile_window_ptr( + dst, + col_offset=profile.left_padding, + indent=indent, + into=lines, + ) + transfer = self._infer_dma_load_transfer(stmt.src, stmt.dst.type, src, env, indent=indent, into=lines) + + copy_lines = self._render_dma_load_copy_ops( + src_name, + src_type, + dst_name, + dst_type, + transfer, + indent=indent, + ) + prefill_lines = self._render_dma_load_prefill( + stmt.dst, + dst, + env, + profile, + indent=indent, + ) + if profile.pad_mode_name == "PadFirstElem": + lines.extend(copy_lines) + lines.extend(prefill_lines) + if profile.init_out_buffer: + lines.extend(copy_lines) + return lines + + lines.extend(prefill_lines) + lines.extend(copy_lines) + return lines + + def _render_dma_store( + self, + stmt: SemanticDmaStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + profile = self._resolve_dma_store_trim_profile(stmt.options) + src = self._lower_expr(stmt.src, env, indent=indent, into=lines) + dst = self._lower_expr(stmt.dst.base, env, indent=indent, into=lines) + src_name, src_type = self._materialize_tile_window_ptr( + src, + col_offset=profile.left_padding, + indent=indent, + into=lines, + ) + dst_name, dst_type = self._materialize_tensor_slice_ptr( + stmt.dst, + dst, + env, + indent=indent, + into=lines, + ) + transfer = self._infer_dma_store_transfer(stmt.dst, stmt.src.type, dst, env, indent=indent, into=lines) + + c0_i64 = self._materialize_constant(0, _I64_TYPE) + c1_i64 = self._materialize_constant(1, _I64_TYPE) + + lines.extend( + [ + self._indent(indent) + + f"pto.set_loop_size_ubtoout {c1_i64}, {c1_i64} : i64, i64", + self._indent(indent) + + f"pto.set_loop1_stride_ubtoout {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + + f"pto.set_loop2_stride_ubtoout {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + + "pto.copy_ubuf_to_gm " + + f"{src_name}, {dst_name}, {c0_i64}, {transfer.n_burst.name}, {transfer.len_burst.name}, {c0_i64}, " + + f"{transfer.copy_dst_stride.name}, {transfer.copy_src_stride.name} : {src_type}, {dst_type}, " + + "i64, i64, i64, i64, i64, i64", + ] + ) + return lines + + def _render_vector_store( + self, + stmt: SemanticVectorStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(stmt.indices, env, indent=indent, into=lines) + mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) + attrs = "" + if stmt.dist is not None: + dist = self._render_string_literal(stmt.dist) + attrs = f" {{dist = {dist}}}" + lines.append( + self._indent(indent) + + "pto.vsts " + + f"{value.name}, {destination.name}[{rendered_indices}], {mask.name}{attrs} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(mask.type)}" + ) + return lines + + def _render_vector_pair_store( + self, + stmt: SemanticVectorPairStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + low = self._lower_expr(stmt.low, env, indent=indent, into=lines) + high = self._lower_expr(stmt.high, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(stmt.indices, env, indent=indent, into=lines) + dist = self._render_string_literal(stmt.dist) + mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + "pto.vstsx2 " + + f"{low.name}, {high.name}, {destination.name}[{rendered_indices}], {dist}, {mask.name} : " + + f"{self._render_type(low.type)}, {self._render_type(high.type)}, " + + f"{self._render_type(destination.type)}, {self._render_type(mask.type)}" + ) + return lines + + def _render_vscatter( + self, + stmt: SemanticVScatterStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + offsets = self._lower_expr(stmt.offsets, env, indent=indent, into=lines) + mask = self._lower_expr(stmt.mask, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + "pto.vscatter " + + f"{value.name}, {destination.name}, {offsets.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, " + + f"{self._render_type(offsets.type)}, {self._render_type(mask.type)}" + ) + return lines + + def _render_predicate_store( + self, + stmt: SemanticPredicateStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + rendered_offset = self._materialize_constant(0, SemanticIndexType()) + else: + if stmt.op_name == "psti": + rendered_offset = self._lower_to_index(stmt.indices[0], env, indent=indent, into=lines) + else: + rendered_offset = self._lower_expr(stmt.indices[0], env, indent=indent, into=lines) + dist = self._render_string_literal(stmt.dist) + lines.append( + self._indent(indent) + + f"pto.{stmt.op_name} {value.name}, {destination.name}[{rendered_offset.name}], {dist} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(rendered_offset.type)}" + ) + return lines + + def _render_align_store( + self, + stmt: SemanticAlignStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + if isinstance(destination.type, SemanticTileType): + destination = self._materialize_tile_memref(destination, indent=indent, into=lines) + if ( + isinstance(stmt.destination.type, SemanticTileType) + and stmt.destination.type.rank == 2 + and len(stmt.indices) == 2 + ): + destination = self._materialize_rank2_tile_subview( + destination, + stmt.destination.type, + stmt.indices, + env, + indent=indent, + into=lines, + ) + if stmt.op_name == "vstar": + lines.append( + self._indent(indent) + + f"pto.vstar {value.name}, {destination.name} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}" + ) + return lines + if stmt.offset is None: + raise NotImplementedError("vstas lowering expects an explicit offset operand") + offset = self._lower_expr(stmt.offset, env, indent=indent, into=lines) + offset = self._coerce_rendered_value(offset, _I32_TYPE, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.vstas {value.name}, {destination.name}, {offset.name} : " + + f"{self._render_type(value.type)}, {self._render_type(destination.type)}, {self._render_type(offset.type)}" + ) + return lines + + def _render_scalar_store( + self, + stmt: SemanticScalarStoreStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + value = self._lower_expr(stmt.value, env, indent=indent, into=lines) + destination = self._lower_expr(stmt.destination, env, indent=indent, into=lines) + offset = self._lower_expr(stmt.offset, env, indent=indent, into=lines) + lines.append( + self._indent(indent) + + f"pto.store_scalar {value.name}, {destination.name}[{offset.name}] : " + + f"{self._render_type(destination.type)}, {self._render_type(value.type)}" + ) + return lines + + def _render_index_list( + self, + indices: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> str: + rendered = [ + self._lower_expr(index, env, indent=indent, into=into).name for index in indices + ] + return ", ".join(rendered) + + def _render_rank2_subview_result_type( + self, + *, + element_dtype: str, + memory_space: str, + ) -> _RenderedTextualType: + return _RenderedTextualType( + f"memref, " + f"{self._render_memref_memory_space(memory_space)}>" + ) + + def _materialize_rank2_tile_subview( + self, + base: _RenderedValue, + tile_type: SemanticTileType, + indices: tuple[SemanticExpr, ...], + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + row_index, col_index = indices + row_value = self._lower_expr(row_index, env, indent=indent, into=into) + col_value = self._lower_expr(col_index, env, indent=indent, into=into) + one = self._materialize_constant(1, SemanticIndexType()) + total_cols = self._materialize_constant(tile_type.shape[1], SemanticIndexType()) + remaining_cols = self._new_temp() + into.append( + self._indent(indent) + + f"{remaining_cols} = arith.subi {total_cols}, {col_value.name} : index" + ) + subview_type = self._render_rank2_subview_result_type( + element_dtype=tile_type.element_dtype.name, + memory_space=tile_type.memory_space or "ub", + ) + subview_name = self._new_temp() + into.append( + self._indent(indent) + + f"{subview_name} = memref.subview {base.name}[{row_value.name}, {col_value.name}] " + + f"[{one}, {remaining_cols}] [{one}, {one}] : " + + f"{self._render_type(base.type)} to {self._render_type(subview_type)}" + ) + return _RenderedValue(name=subview_name, type=subview_type) + + def _tensor_slice_extents(self, expr: SemanticTensorSliceExpr) -> tuple[int, int]: + if expr.type.rank != 2 or len(expr.type.extents) != 2: + raise NotImplementedError("TileLang DSL v1 DMA lowering currently only supports rank-2 TensorView slices") + return expr.type.extents + + def _materialize_tensor_slice_axis_size( + self, + slice_axis: object, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if slice_axis.extent is not None: + return _RenderedValue( + name=self._materialize_constant(slice_axis.extent, SemanticIndexType()), + type=SemanticIndexType(), + ) + distance = self._emit_binary_value( + "sub", + self._lower_expr(slice_axis.stop, env, indent=indent, into=into), + self._lower_expr(slice_axis.start, env, indent=indent, into=into), + SemanticIndexType(), + indent=indent, + into=into, + ) + step_value = self._static_expr_value(slice_axis.step, default=1) + if not isinstance(step_value, int) or step_value <= 0: + raise NotImplementedError( + "partition_view lowering currently expects a static positive slice step in TileLang DSL v1" + ) + if step_value == 1: + return distance + numerator = self._emit_binary_value( + "add", + distance, + _RenderedValue( + name=self._materialize_constant(step_value - 1, SemanticIndexType()), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + return self._emit_binary_value( + "floordiv", + numerator, + _RenderedValue( + name=self._materialize_constant(step_value, SemanticIndexType()), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + + def _lower_tensor_slice_expr( + self, + expr: SemanticTensorSliceExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None, + into: list[str] | None, + ) -> _RenderedValue: + if into is None: + into = [] + tensor_base = self._lower_expr(expr.base, env, indent=indent, into=into) + if not isinstance(tensor_base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + raise NotImplementedError("partition_view lowering expects a TensorView/PartitionTensorView source") + + offsets: list[_RenderedValue] = [] + sizes: list[_RenderedValue] = [] + for axis_slice in expr.slices: + offsets.append(self._lower_expr(axis_slice.start, env, indent=indent, into=into)) + sizes.append( + self._materialize_tensor_slice_axis_size( + axis_slice, + env, + indent=indent, + into=into, + ) + ) + + result_name = desired_name or self._new_temp() + result_type_text = self._render_partition_tensor_view_type( + element_dtype=expr.type.element_dtype.name, + shape=tuple("?" if dim is None else dim for dim in expr.type.extents), + ) + into.append( + self._indent(indent) + + f"{result_name} = pto.partition_view {tensor_base.name}, " + + f"offsets = [{', '.join(value.name for value in offsets)}], " + + f"sizes = [{', '.join(value.name for value in sizes)}] : " + + f"{self._render_type(tensor_base.type)} -> {result_type_text}" + ) + return _RenderedValue(name=result_name, type=_RenderedTextualType(result_type_text)) + + def _resolve_dma_load_padding_profile(self, options: object) -> _DmaLoadPaddingProfile: + pad_mode_name = self._static_pad_mode_name(getattr(options, "pad_mode", None)) or "PadNull" + left_padding = self._static_expr_value(getattr(options, "left_padding", None), default=0) + right_padding = self._static_expr_value(getattr(options, "right_padding", None), default=0) + init_out_buffer = self._static_expr_value(getattr(options, "init_out_buffer", None), default=False) + if not isinstance(left_padding, int) or left_padding < 0: + raise NotImplementedError( + "pto.dma_load lowering currently expects `left_padding` to be a static non-negative index" + ) + if not isinstance(right_padding, int) or right_padding < 0: + raise NotImplementedError( + "pto.dma_load lowering currently expects `right_padding` to be a static non-negative index" + ) + if not isinstance(init_out_buffer, bool): + raise NotImplementedError( + "pto.dma_load lowering currently expects `init_out_buffer` to be a compile-time bool" + ) + if pad_mode_name not in {"PadNull", "PadFirstElem", "PadValue"}: + raise NotImplementedError( + f"pto.dma_load lowering does not recognize pad_mode `{pad_mode_name}` in TileLang DSL v1" + ) + if pad_mode_name == "PadNull" and init_out_buffer: + raise NotImplementedError( + "pto.dma_load lowering does not support `init_out_buffer=True` with `pad_mode=PadMode.PadNull`; " + "the stable frontend-only path has no explicit fill value for that combination" + ) + return _DmaLoadPaddingProfile( + pad_mode_name=pad_mode_name, + left_padding=left_padding, + right_padding=right_padding, + init_out_buffer=init_out_buffer, + pad_value=getattr(options, "pad_value", None), + ) + + def _resolve_dma_store_trim_profile(self, options: object) -> _DmaStoreTrimProfile: + pad_mode_name = self._static_pad_mode_name(getattr(options, "pad_mode", None)) or "PadNull" + left_padding = self._static_expr_value(getattr(options, "left_padding", None), default=0) + right_padding = self._static_expr_value(getattr(options, "right_padding", None), default=0) + if pad_mode_name != "PadNull": + raise NotImplementedError( + "pto.dma_store lowering only supports `pad_mode=PadMode.PadNull`; " + "non-PadNull store padding would require GM-side fill in the stable frontend-only path" + ) + if self._static_expr_value(getattr(options, "pad_value", None)) is not None: + raise NotImplementedError( + "pto.dma_store lowering does not support `pad_value`; GM-side fill is unsupported" + ) + if not isinstance(left_padding, int) or left_padding < 0: + raise NotImplementedError( + "pto.dma_store lowering currently expects `left_padding` to be a static non-negative index" + ) + if not isinstance(right_padding, int) or right_padding < 0: + raise NotImplementedError( + "pto.dma_store lowering currently expects `right_padding` to be a static non-negative index" + ) + return _DmaStoreTrimProfile( + left_padding=left_padding, + right_padding=right_padding, + ) + + def _require_default_dma_lowering_profile(self, options: object, op_name: str) -> None: + if not self._is_default_dma_lowering_profile(options): + raise NotImplementedError( + f"{op_name} lowering for padding/trim/init options is not implemented yet in TileLang DSL v1; " + "this stable frontend-only DMA path only lowers the default no-padding profile today" + ) + + def _is_default_dma_lowering_profile(self, options: object) -> bool: + return ( + self._static_pad_mode_name(getattr(options, "pad_mode", None)) in {None, "PadNull"} + and self._static_expr_value(getattr(options, "pad_value", None)) is None + and self._static_expr_value(getattr(options, "left_padding", None), default=0) == 0 + and self._static_expr_value(getattr(options, "right_padding", None), default=0) == 0 + and self._static_expr_value(getattr(options, "init_out_buffer", None), default=False) is False + ) + + def _static_pad_mode_name(self, expr: SemanticExpr | None) -> str | None: + value = self._static_expr_value(expr) + return None if value is None else getattr(value, "name", None) + + def _static_expr_value(self, expr: SemanticExpr | None, *, default: object = None) -> object: + if expr is None: + return default + if isinstance(expr, SemanticLiteralExpr): + return expr.value + if isinstance(expr, SemanticSymbolExpr): + return expr.value + if isinstance(expr, SemanticBindingRef): + return expr.binding.value + return None + + def _infer_dma_load_transfer( + self, + slice_expr: SemanticTensorSliceExpr, + tile_type: SemanticTileType, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _DmaTransferConfig: + element_bytes = self._dtype_byte_width(slice_expr.type.element_dtype) + row_count = self._materialize_dma_axis_extent(slice_expr, 0, env, indent=indent, into=into) + col_count = self._materialize_dma_axis_extent(slice_expr, 1, env, indent=indent, into=into) + gm_row_stride = self._materialize_tensor_row_stride_bytes( + slice_expr, + tensor_base, + element_bytes, + indent=indent, + into=into, + ) + row_step = self._materialize_dma_row_step(slice_expr, env, indent=indent, into=into) + copy_src_stride = self._emit_binary_value( + "mul", + gm_row_stride, + row_step, + _I64_TYPE, + indent=indent, + into=into, + ) + copy_dst_stride = self._materialize_tile_row_stride_bytes( + tile_type, + element_bytes, + indent=indent, + into=into, + ) + len_burst = self._materialize_dma_len_burst( + col_count, + element_bytes, + indent=indent, + into=into, + ) + loop_src_stride = self._emit_binary_value( + "mul", + row_count, + copy_src_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + loop_dst_stride = self._emit_binary_value( + "mul", + row_count, + copy_dst_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + return _DmaTransferConfig( + n_burst=row_count, + len_burst=len_burst, + copy_src_stride=copy_src_stride, + copy_dst_stride=copy_dst_stride, + loop_src_stride=loop_src_stride, + loop_dst_stride=loop_dst_stride, + ) + + def _infer_dma_store_transfer( + self, + slice_expr: SemanticTensorSliceExpr, + tile_type: SemanticTileType, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _DmaTransferConfig: + element_bytes = self._dtype_byte_width(slice_expr.type.element_dtype) + row_count = self._materialize_dma_axis_extent(slice_expr, 0, env, indent=indent, into=into) + col_count = self._materialize_dma_axis_extent(slice_expr, 1, env, indent=indent, into=into) + copy_src_stride = self._materialize_tile_row_stride_bytes( + tile_type, + element_bytes, + indent=indent, + into=into, + ) + gm_row_stride = self._materialize_tensor_row_stride_bytes( + slice_expr, + tensor_base, + element_bytes, + indent=indent, + into=into, + ) + row_step = self._materialize_dma_row_step(slice_expr, env, indent=indent, into=into) + copy_dst_stride = self._emit_binary_value( + "mul", + gm_row_stride, + row_step, + _I64_TYPE, + indent=indent, + into=into, + ) + len_burst = self._materialize_dma_len_burst( + col_count, + element_bytes, + indent=indent, + into=into, + ) + loop_src_stride = self._emit_binary_value( + "mul", + row_count, + copy_src_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + loop_dst_stride = self._emit_binary_value( + "mul", + row_count, + copy_dst_stride, + _I64_TYPE, + indent=indent, + into=into, + ) + return _DmaTransferConfig( + n_burst=row_count, + len_burst=len_burst, + copy_src_stride=copy_src_stride, + copy_dst_stride=copy_dst_stride, + loop_src_stride=loop_src_stride, + loop_dst_stride=loop_dst_stride, + ) + + def _render_dma_load_copy_ops( + self, + src_name: str, + src_type: str, + dst_name: str, + dst_type: str, + transfer: _DmaTransferConfig, + *, + indent: int, + ) -> list[str]: + c0_i64 = self._materialize_constant(0, _I64_TYPE) + c1_i64 = self._materialize_constant(1, _I64_TYPE) + false_bit = self._materialize_constant(False, _I1_TYPE) + return [ + self._indent(indent) + + f"pto.set_loop2_stride_outtoub {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + + f"pto.set_loop1_stride_outtoub {transfer.loop_src_stride.name}, {transfer.loop_dst_stride.name} : i64, i64", + self._indent(indent) + + f"pto.set_loop_size_outtoub {c1_i64}, {c1_i64} : i64, i64", + self._indent(indent) + + "pto.copy_gm_to_ubuf " + + f"{src_name}, {dst_name}, {c0_i64}, {transfer.n_burst.name}, {transfer.len_burst.name}, {c0_i64}, {c0_i64}, " + + f"{false_bit}, {c0_i64}, {transfer.copy_src_stride.name}, {transfer.copy_dst_stride.name} : " + + f"{src_type}, {dst_type}, " + + "i64, i64, i64, i64, i64, i1, i64, i64, i64", + ] + + def _render_dma_load_prefill( + self, + tile_expr: SemanticExpr, + tile_value: _RenderedValue, + env: dict[str, _RenderedValue], + profile: _DmaLoadPaddingProfile, + *, + indent: int, + ) -> list[str]: + fill_bands = profile.left_padding > 0 or profile.right_padding > 0 + if profile.pad_mode_name == "PadNull" and not profile.init_out_buffer: + return [] + if profile.pad_mode_name in {"PadValue", "PadFirstElem"} and not (profile.init_out_buffer or fill_bands): + return [] + + lines: list[str] = [] + tile_memref = self._materialize_tile_memref(tile_value, indent=indent, into=lines) + rows_upper = self._materialize_tile_window_extent( + tile_expr, + tile_value, + axis=0, + indent=indent, + into=lines, + ) + cols_upper = self._materialize_tile_window_extent( + tile_expr, + tile_value, + axis=1, + indent=indent, + into=lines, + ) + fill_vec = self._materialize_dma_load_prefill_vector( + tile_memref, + tile_value.type.element_dtype, + env, + profile, + indent=indent, + into=lines, + ) + + windows: list[tuple[_RenderedValue, _RenderedValue]] = [] + c0_index = _RenderedValue( + name=self._materialize_constant(0, SemanticIndexType()), + type=SemanticIndexType(), + ) + if profile.init_out_buffer: + windows.append((c0_index, cols_upper)) + else: + if profile.left_padding > 0: + windows.append( + ( + c0_index, + _RenderedValue( + name=self._materialize_constant(profile.left_padding, SemanticIndexType()), + type=SemanticIndexType(), + ), + ) + ) + if profile.right_padding > 0: + right_width = _RenderedValue( + name=self._materialize_constant(profile.right_padding, SemanticIndexType()), + type=SemanticIndexType(), + ) + right_start = self._emit_binary_value( + "sub", + cols_upper, + right_width, + SemanticIndexType(), + indent=indent, + into=lines, + ) + windows.append((right_start, cols_upper)) + + if not windows: + return [] + lines.extend( + self._render_tile_fill_windows( + tile_memref, + tile_value.type.element_dtype, + fill_vec, + rows_upper, + windows, + indent=indent, + ) + ) + return lines + + def _render_tile_fill_windows( + self, + tile_memref: _RenderedValue, + element_dtype: ScalarType, + fill_vec: _RenderedValue, + rows_upper: _RenderedValue, + windows: list[tuple[_RenderedValue, _RenderedValue]], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + c0 = self._materialize_constant(0, SemanticIndexType()) + c1 = self._materialize_constant(1, SemanticIndexType()) + vector_step = self._materialize_constant(get_lanes(element_dtype), SemanticIndexType()) + mask_type = SemanticMaskType(granularity=self._mask_granularity_for_dtype(element_dtype)) + lines.append(self._indent(indent) + "pto.vecscope {") + for start, stop in windows: + row_iv = self._new_temp() + lines.append( + self._indent(indent + 2) + + f"scf.for {row_iv} = {c0} to {rows_upper.name} step {c1} {{" + ) + col_iv = self._new_temp() + lines.append( + self._indent(indent + 4) + + f"scf.for {col_iv} = {start.name} to {stop.name} step {vector_step} {{" + ) + remaining = self._emit_binary_value( + "sub", + stop, + _RenderedValue(name=col_iv, type=SemanticIndexType()), + SemanticIndexType(), + indent=indent + 6, + into=lines, + ) + remaining_i32 = self._coerce_rendered_value( + remaining, + _I32_TYPE, + indent=indent + 6, + into=lines, + ) + mask_name = self._new_temp() + next_name = self._new_temp() + lines.append( + self._indent(indent + 6) + + f"{mask_name}, {next_name} = pto.plt_{mask_type.granularity} {remaining_i32.name} : " + + f"i32 -> {self._render_type(mask_type)}, i32" + ) + lines.append( + self._indent(indent + 6) + + f"pto.vsts {fill_vec.name}, {tile_memref.name}[{row_iv}, {col_iv}], {mask_name} : " + + f"{self._render_type(fill_vec.type)}, {self._render_type(tile_memref.type)}, {self._render_type(mask_type)}" + ) + lines.append(self._indent(indent + 4) + "}") + lines.append(self._indent(indent + 2) + "}") + lines.append(self._indent(indent) + "}") + return lines + + def _materialize_dma_load_prefill_vector( + self, + tile_memref: _RenderedValue, + element_dtype: ScalarType, + env: dict[str, _RenderedValue], + profile: _DmaLoadPaddingProfile, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + vec_type = SemanticVRegType(element_dtype=element_dtype, lanes=get_lanes(element_dtype)) + result_name = self._new_temp() + if profile.pad_mode_name == "PadValue": + scalar = self._materialize_dma_pad_value_scalar( + profile.pad_value, + element_dtype, + env, + indent=indent, + into=into, + ) + into.append( + self._indent(indent) + + f"{result_name} = pto.vbr {scalar.name} : {self._render_type(scalar.type)} -> {self._render_type(vec_type)}" + ) + return _RenderedValue(name=result_name, type=vec_type) + if profile.pad_mode_name == "PadFirstElem": + c0 = self._materialize_constant(0, SemanticIndexType()) + first_col = self._materialize_constant(profile.left_padding, SemanticIndexType()) + into.append( + self._indent(indent) + + f'{result_name} = pto.vlds {tile_memref.name}[{c0}, {first_col}] {{dist = "{self._broadcast_dist_for_dtype(element_dtype)}"}} : ' + + f"{self._render_type(tile_memref.type)} -> {self._render_type(vec_type)}" + ) + return _RenderedValue(name=result_name, type=vec_type) + raise NotImplementedError( + f"pto.dma_load lowering does not produce a prefill vector for pad_mode `{profile.pad_mode_name}`" + ) + + def _materialize_dma_pad_value_scalar( + self, + expr: SemanticExpr | None, + element_dtype: ScalarType, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + scalar_type = SemanticScalarType(dtype=element_dtype) + static_value = self._static_expr_value(expr) + if isinstance(static_value, (int, float)): + return _RenderedValue( + name=self._materialize_constant(static_value, scalar_type), + type=scalar_type, + ) + if expr is None: + raise NotImplementedError("pto.dma_load PadValue lowering requires a concrete `pad_value` expression") + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticScalarType) and value.type.dtype == element_dtype: + return value + raise NotImplementedError( + "pto.dma_load PadValue lowering currently expects `pad_value` to be a compile-time numeric literal " + "or a scalar value whose dtype matches the destination Tile element dtype" + ) + + def _materialize_tile_window_extent( + self, + tile_expr: SemanticExpr, + tile_value: _RenderedValue, + *, + axis: int, + indent: int, + into: list[str], + ) -> _RenderedValue: + if ( + isinstance(tile_expr, SemanticBindingRef) + and isinstance(tile_expr.type, SemanticTileType) + and tile_expr.type.valid_shape is not None + and tile_expr.type.valid_shape[axis] is None + ): + return self._materialize_tile_valid_dim( + tile_expr.binding, + axis, + indent=indent, + into=into, + ) + if not isinstance(tile_value.type, SemanticTileType): + raise NotImplementedError("DMA load prefill expects a Tile destination") + valid_shape = tile_value.type.valid_shape or tile_value.type.shape + if valid_shape is None: + raise NotImplementedError("DMA load prefill expects a statically known Tile shape or valid_shape") + extent = valid_shape[axis] + if extent is None: + raise NotImplementedError("DMA load prefill does not support dynamic Tile valid_shape on non-binding values") + return _RenderedValue( + name=self._materialize_constant(extent, SemanticIndexType()), + type=SemanticIndexType(), + ) + + def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: + int_bits = integer_bitwidth(dtype) + if dtype.name == "f32" or int_bits in {32, 64}: + return "b32" + if dtype.name in {"f16", "bf16"} or int_bits == 16: + return "b16" + if int_bits == 8: + return "b8" + raise NotImplementedError(f"dtype `{dtype.name}` is not supported by DMA load prefill lowering") + + def _broadcast_dist_for_dtype(self, dtype: ScalarType) -> str: + int_bits = integer_bitwidth(dtype) + if dtype.name == "f32" or int_bits == 32: + return "BRC_B32" + if dtype.name in {"f16", "bf16"} or int_bits == 16: + return "BRC_B16" + if int_bits == 8: + return "BRC_B8" + raise NotImplementedError(f"dtype `{dtype.name}` is not supported by DMA load broadcast lowering") + + def _materialize_tile_window_ptr( + self, + tile_value: _RenderedValue, + *, + col_offset: int, + indent: int, + into: list[str], + ) -> tuple[str, str]: + base_ptr_name, base_ptr_type = self._materialize_copy_buffer_ptr( + tile_value, + indent=indent, + into=into, + ) + if col_offset == 0: + return base_ptr_name, base_ptr_type + byte_ptr_type = "!pto.ptr" + byte_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{byte_ptr_name} = pto.castptr {base_ptr_name} : {base_ptr_type} -> {byte_ptr_type}" + ) + offset_bytes = self._materialize_constant( + col_offset * self._dtype_byte_width(tile_value.type.element_dtype), + SemanticIndexType(), + ) + offset_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{offset_ptr_name} = pto.addptr {byte_ptr_name}, {offset_bytes} : {byte_ptr_type} -> {byte_ptr_type}" + ) + typed_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{typed_ptr_name} = pto.castptr {offset_ptr_name} : {byte_ptr_type} -> {base_ptr_type}" + ) + return typed_ptr_name, base_ptr_type + + def _materialize_tensor_slice_ptr( + self, + slice_expr: SemanticTensorSliceExpr, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> tuple[str, str]: + base_ptr_name, base_ptr_type = self._materialize_copy_buffer_ptr( + tensor_base, + indent=indent, + into=into, + ) + if self._is_zero_index_expr(slice_expr.slices[0].start) and self._is_zero_index_expr(slice_expr.slices[1].start): + return base_ptr_name, base_ptr_type + + byte_ptr_type = "!pto.ptr" + byte_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{byte_ptr_name} = pto.castptr {base_ptr_name} : {base_ptr_type} -> {byte_ptr_type}" + ) + offset = self._materialize_tensor_slice_offset_bytes( + slice_expr, + tensor_base, + env, + indent=indent, + into=into, + ) + offset_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{offset_ptr_name} = pto.addptr {byte_ptr_name}, {offset.name} : " + + f"{byte_ptr_type} -> {byte_ptr_type}" + ) + typed_ptr_name = self._new_temp() + into.append( + self._indent(indent) + + f"{typed_ptr_name} = pto.castptr {offset_ptr_name} : {byte_ptr_type} -> {base_ptr_type}" + ) + return typed_ptr_name, base_ptr_type + + def _materialize_tensor_slice_offset_bytes( + self, + slice_expr: SemanticTensorSliceExpr, + tensor_base: _RenderedValue, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + offset_elems = _RenderedValue( + name=self._materialize_constant(0, SemanticIndexType()), + type=SemanticIndexType(), + ) + for axis_index, slice_axis in enumerate(slice_expr.slices): + axis_start = self._lower_expr(slice_axis.start, env, indent=indent, into=into) + axis_stride_elems = self._materialize_tensor_axis_stride_elems( + tensor_base, + axis=slice_expr.type.physical_axes[axis_index], + indent=indent, + into=into, + ) + axis_offset_elems = self._emit_binary_value( + "mul", + axis_start, + axis_stride_elems, + SemanticIndexType(), + indent=indent, + into=into, + ) + offset_elems = self._emit_binary_value( + "add", + offset_elems, + axis_offset_elems, + SemanticIndexType(), + indent=indent, + into=into, + ) + return self._emit_binary_value( + "mul", + offset_elems, + _RenderedValue( + name=self._materialize_constant( + self._dtype_byte_width(slice_expr.type.element_dtype), + SemanticIndexType(), + ), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + + def _is_zero_index_expr(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticLiteralExpr): + return isinstance(expr.value, int) and expr.value == 0 + if isinstance(expr, SemanticBindingRef): + return isinstance(expr.binding.value, int) and expr.binding.value == 0 + return False + + def _materialize_tensor_dim( + self, + tensor_base: _RenderedValue, + *, + axis: int, + indent: int, + into: list[str], + ) -> _RenderedValue: + dim_index = self._new_temp() + axis_value = self._materialize_constant(axis, SemanticIndexType()) + if isinstance(tensor_base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + into.append( + self._indent(indent) + + f"{dim_index} = pto.get_tensor_view_dim {tensor_base.name}, {axis_value} : " + + f"{self._render_type(tensor_base.type)} -> index" + ) + else: + into.append( + self._indent(indent) + + f"{dim_index} = memref.dim {tensor_base.name}, {axis_value} : {self._render_type(tensor_base.type)}" + ) + return _RenderedValue(name=dim_index, type=SemanticIndexType()) + + def _materialize_dma_axis_extent( + self, + slice_expr: SemanticTensorSliceExpr, + axis: int, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + axis_slice = slice_expr.slices[axis] + if axis_slice.extent is not None: + return _RenderedValue( + name=self._materialize_constant(axis_slice.extent, _I64_TYPE), + type=_I64_TYPE, + ) + + distance_expr = SemanticBinaryExpr( + lhs=axis_slice.stop, + op="sub", + rhs=axis_slice.start, + type=SemanticIndexType(), + ) + extent_expr: SemanticExpr = distance_expr + step_value = self._static_expr_value(axis_slice.step) + if not isinstance(step_value, int): + raise NotImplementedError("DMA lowering currently expects a static slice step") + if step_value != 1: + extent_expr = SemanticBinaryExpr( + lhs=SemanticBinaryExpr( + lhs=distance_expr, + op="add", + rhs=SemanticLiteralExpr(value=step_value - 1, type=SemanticIndexType()), + type=SemanticIndexType(), + ), + op="floordiv", + rhs=SemanticLiteralExpr(value=step_value, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + return self._lower_to_i64(extent_expr, env, indent=indent, into=into) + + def _materialize_dma_row_step( + self, + slice_expr: SemanticTensorSliceExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + return self._lower_to_i64(slice_expr.slices[0].step, env, indent=indent, into=into) + + def _materialize_tensor_axis_stride_elems( + self, + tensor_base: _RenderedValue, + axis: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + stride = _RenderedValue( + name=self._materialize_constant(1, SemanticIndexType()), + type=SemanticIndexType(), + ) + for dim_axis in range(axis + 1, tensor_base.type.rank): + dim_value = self._materialize_tensor_dim( + tensor_base, + axis=dim_axis, + indent=indent, + into=into, + ) + stride = self._emit_binary_value( + "mul", + stride, + dim_value, + SemanticIndexType(), + indent=indent, + into=into, + ) + return stride + + def _materialize_tensor_row_stride_bytes( + self, + slice_expr: SemanticTensorSliceExpr, + tensor_base: _RenderedValue, + element_bytes: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + stride_elems = self._materialize_tensor_axis_stride_elems( + tensor_base, + axis=slice_expr.type.physical_axes[0], + indent=indent, + into=into, + ) + dim_bytes = self._emit_binary_value( + "mul", + stride_elems, + _RenderedValue( + name=self._materialize_constant(element_bytes, SemanticIndexType()), + type=SemanticIndexType(), + ), + SemanticIndexType(), + indent=indent, + into=into, + ) + return self._coerce_rendered_to_i64(dim_bytes, indent=indent, into=into) + + def _materialize_tile_row_stride_bytes( + self, + tile_type: SemanticTileType, + element_bytes: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if tile_type.shape is None or len(tile_type.shape) != 2: + raise NotImplementedError("DMA lowering requires a statically specialized rank-2 Tile shape") + row_bytes = tile_type.shape[1] * element_bytes + return _RenderedValue( + name=self._materialize_constant(row_bytes, _I64_TYPE), + type=_I64_TYPE, + ) + + def _materialize_dma_len_burst( + self, + col_count: _RenderedValue, + element_bytes: int, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + return self._emit_binary_value( + "mul", + col_count, + _RenderedValue( + name=self._materialize_constant(element_bytes, _I64_TYPE), + type=_I64_TYPE, + ), + _I64_TYPE, + indent=indent, + into=into, + ) + + def _dma_transfer_extents( + self, + slice_expr: SemanticTensorSliceExpr, + tile_type: SemanticTileType, + ) -> tuple[int, int]: + row_count, col_count = self._tensor_slice_extents(slice_expr) + if row_count is not None and col_count is not None: + return row_count, col_count + if tile_type.shape is None or len(tile_type.shape) != 2: + raise NotImplementedError("DMA lowering requires a statically specialized rank-2 Tile shape") + return tile_type.shape + + def _emit_binary_value( + self, + op: str, + lhs: _RenderedValue, + rhs: _RenderedValue, + result_type: SemanticType, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + result_name = self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {self._render_binary_op(op, result_type)} " + f"{lhs.name}, {rhs.name} : {self._render_type(result_type)}" + ) + return _RenderedValue(name=result_name, type=result_type) + + def _render_strict_vecscope( + self, + stmt: SemanticStrictVecscopeStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + capture_values = [] + block_argument_values = [] + for expr, binding in zip(stmt.captures, stmt.block_arguments): + capture = self._lower_expr(expr, env, indent=indent, into=lines) + capture, block_arg = self._materialize_strict_vecscope_capture( + capture, + binding, + indent=indent, + into=lines, + ) + capture_values.append(capture) + block_argument_values.append(block_arg) + capture_names = ", ".join(value.name for value in capture_values) + block_args = ", ".join( + f"{binding.ssa_name}: {self._render_type(value.type)}" + for binding, value in zip(stmt.block_arguments, block_argument_values) + ) + function_type = ", ".join( + self._render_type(value.type) for value in block_argument_values + ) + + scope_env = { + binding.name: _RenderedValue(name=binding.ssa_name, type=value.type) + for binding, value in zip(stmt.block_arguments, block_argument_values) + } + + lines.append(self._indent(indent) + f"pto.strict_vecscope({capture_names}) {{") + lines.append(self._indent(indent) + f"^bb0({block_args}):") + lines.extend(self._render_block(stmt.body, scope_env, indent=indent + 2)) + lines.append(self._indent(indent) + f"}} : ({function_type}) -> ()") + return lines + + def _render_vecscope( + self, + stmt: SemanticVecscopeStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + scope_env = dict(env) + lines = [self._indent(indent) + "pto.vecscope {"] + lines.extend(self._render_block(stmt.body, scope_env, indent=indent + 2)) + lines.append(self._indent(indent) + "}") + return lines + + def _render_for( + self, + stmt: SemanticForStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + lines: list[str] = [] + lower_bound = self._lower_expr(stmt.lower_bound, env, indent=indent, into=lines) + upper_bound = self._lower_expr(stmt.upper_bound, env, indent=indent, into=lines) + step = self._lower_expr(stmt.step, env, indent=indent, into=lines) + + body_env = dict(env) + body_env[stmt.induction_variable.name] = _RenderedValue( + name=stmt.induction_variable.ssa_name, + type=stmt.induction_variable.type, + ) + + if not stmt.loop_carried: + lines.append( + self._indent(indent) + + f"scf.for {stmt.induction_variable.ssa_name} = {lower_bound.name} " + f"to {upper_bound.name} step {step.name} {{" + ) + lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) + lines.append(self._indent(indent) + "}") + return lines + + carried_bindings = tuple(stmt.loop_carried) + if len(carried_bindings) == 1: + carried_binding = carried_bindings[0] + initial_value = self._coerce_rendered_value( + env[carried_binding.name], + carried_binding.type, + indent=indent, + into=lines, + ) + iter_arg_name = f"%{carried_binding.name}_iter_{self._loop_counter}" + self._loop_counter += 1 + body_env[carried_binding.name] = _RenderedValue( + name=iter_arg_name, + type=carried_binding.type, + ) + + lines.append( + self._indent(indent) + + f"{carried_binding.ssa_name}:1 = scf.for {stmt.induction_variable.ssa_name} = " + f"{lower_bound.name} to {upper_bound.name} step {step.name} " + f"iter_args({iter_arg_name} = {initial_value.name}) -> " + f"({self._render_type(carried_binding.type)}) {{" + ) + lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) + yielded_value = self._coerce_rendered_value( + body_env[carried_binding.name], + carried_binding.type, + indent=indent + 2, + into=lines, + ) + lines.append( + self._indent(indent + 2) + + f"scf.yield {yielded_value.name} : {self._render_type(yielded_value.type)}" + ) + lines.append(self._indent(indent) + "}") + env[carried_binding.name] = _RenderedValue( + name=carried_binding.ssa_name, + type=carried_binding.type, + ) + return lines + + loop_id = self._loop_counter + self._loop_counter += 1 + + initial_values: list[_RenderedValue] = [] + iter_arg_names: list[str] = [] + for index, binding in enumerate(carried_bindings): + initial_values.append( + self._coerce_rendered_value( + env[binding.name], + binding.type, + indent=indent, + into=lines, + ) + ) + iter_arg_names.append(f"%{binding.name}_iter_{loop_id}_{index}") + body_env[binding.name] = _RenderedValue( + name=iter_arg_names[-1], + type=binding.type, + ) + + result_names = ", ".join(binding.ssa_name for binding in carried_bindings) + iter_args = ", ".join( + f"{iter_name} = {initial.name}" + for iter_name, initial in zip(iter_arg_names, initial_values) + ) + result_types = ", ".join(self._render_type(binding.type) for binding in carried_bindings) + + lines.append( + self._indent(indent) + + f"{result_names} = scf.for {stmt.induction_variable.ssa_name} = " + f"{lower_bound.name} to {upper_bound.name} step {step.name} " + f"iter_args({iter_args}) -> ({result_types}) {{" + ) + lines.extend(self._render_block(stmt.body, body_env, indent=indent + 2)) + yielded_values = [ + self._coerce_rendered_value( + body_env[binding.name], + binding.type, + indent=indent + 2, + into=lines, + ) + for binding in carried_bindings + ] + yielded_names = ", ".join(value.name for value in yielded_values) + yielded_types = ", ".join(self._render_type(value.type) for value in yielded_values) + lines.append( + self._indent(indent + 2) + + f"scf.yield {yielded_names} : {yielded_types}" + ) + lines.append(self._indent(indent) + "}") + for binding in carried_bindings: + env[binding.name] = _RenderedValue( + name=binding.ssa_name, + type=binding.type, + ) + return lines + + def _render_if( + self, + stmt: SemanticIfStmt, + env: dict[str, _RenderedValue], + *, + indent: int, + ) -> list[str]: + cond_lines: list[str] = [] + condition = self._lower_condition(stmt.condition, env, indent=indent, into=cond_lines) + then_env = dict(env) + else_env = dict(env) + + if not stmt.results: + lines = list(cond_lines) + lines.append(self._indent(indent) + f"scf.if {condition.name} {{") + lines.extend(self._render_block(stmt.then_body, then_env, indent=indent + 2)) + if stmt.else_body: + lines.append(self._indent(indent) + "} else {") + lines.extend(self._render_block(stmt.else_body, else_env, indent=indent + 2)) + lines.append(self._indent(indent) + "}") + return lines + + lines = list(cond_lines) + result_names = ", ".join(result.result_binding.ssa_name for result in stmt.results) + result_types = ", ".join( + self._render_type(result.result_binding.type) for result in stmt.results + ) + lines.append( + self._indent(indent) + + f"{result_names} = scf.if {condition.name} -> ({result_types}) {{" + ) + lines.extend(self._render_block(stmt.then_body, then_env, indent=indent + 2)) + then_values = [] + for result in stmt.results: + then_value = then_env.get( + result.result_binding.name, + then_env.get(result.then_binding.name), + ) + if then_value is None: + then_value = _RenderedValue(result.then_binding.ssa_name, result.then_binding.type) + then_values.append(then_value) + lines.append( + self._indent(indent + 2) + + "scf.yield " + + ", ".join(value.name for value in then_values) + + " : " + + ", ".join(self._render_type(value.type) for value in then_values) + ) + lines.append(self._indent(indent) + "} else {") + lines.extend(self._render_block(stmt.else_body, else_env, indent=indent + 2)) + else_values = [] + for result in stmt.results: + else_value = else_env.get( + result.result_binding.name, + else_env.get(result.else_binding.name), + ) + if else_value is None: + else_value = _RenderedValue(result.else_binding.ssa_name, result.else_binding.type) + else_values.append(else_value) + lines.append( + self._indent(indent + 2) + + "scf.yield " + + ", ".join(value.name for value in else_values) + + " : " + + ", ".join(self._render_type(value.type) for value in else_values) + ) + lines.append(self._indent(indent) + "}") + for result in stmt.results: + env[result.result_binding.name] = _RenderedValue( + name=result.result_binding.ssa_name, + type=result.result_binding.type, + ) + return lines + + def _lower_condition( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i1": + return value + + zero_type: SemanticType + predicate: str + if isinstance(value.type, SemanticIndexType): + zero_type = SemanticIndexType() + predicate = "arith.cmpi ne" + elif isinstance(value.type, SemanticScalarType): + zero_type = value.type + if value.type.dtype.name in {"f16", "bf16", "f32"}: + predicate = "arith.cmpf une" + else: + predicate = "arith.cmpi ne" + else: + raise NotImplementedError(f"unsupported if condition type {value.type!r}") + + zero = self._materialize_constant(0, zero_type) + result_name = self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {predicate}, {value.name}, {zero} : {self._render_type(value.type)}" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + + def _lower_expr( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None = None, + into: list[str] | None = None, + ) -> _RenderedValue: + if isinstance(expr, SemanticBindingRef): + return env.get(expr.binding.name, _RenderedValue(expr.binding.ssa_name, expr.type)) + if isinstance(expr, SemanticLiteralExpr): + return self._lower_literal_expr( + expr.value, + expr.type, + indent=indent, + desired_name=desired_name, + into=into, + ) + if isinstance(expr, SemanticIndexCastExpr): + if into is None: + into = [] + value = self._lower_expr(expr.value, env, indent=indent, into=into) + return self._coerce_rendered_to_index(value, indent=indent, into=into) + if isinstance(expr, SemanticSubscriptAccess): + return self._lower_subscript_access( + expr, + env, + indent=indent, + desired_name=desired_name, + into=into, + ) + if isinstance(expr, SemanticBinaryExpr): + if into is None: + into = [] + if expr.op in {"and", "or"}: + return self._lower_bool_expr( + expr.op, + expr.lhs, + expr.rhs, + env, + indent=indent, + desired_name=desired_name, + into=into, + ) + lhs = self._lower_expr(expr.lhs, env, indent=indent, into=into) + rhs = self._lower_expr(expr.rhs, env, indent=indent, into=into) + if expr.op in {"eq", "ne", "gt", "lt", "ge", "le"}: + return self._lower_compare_expr( + expr.op, + lhs, + rhs, + indent=indent, + desired_name=desired_name, + into=into, + ) + if isinstance(expr.type, SemanticScalarType): + lhs = self._coerce_rendered_value(lhs, expr.type, indent=indent, into=into) + rhs = self._coerce_rendered_value(rhs, expr.type, indent=indent, into=into) + result_name = desired_name or self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {self._render_binary_op(expr.op, expr.type)} " + f"{lhs.name}, {rhs.name} : {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + if isinstance(expr, SemanticCallExpr): + return self._lower_call_expr(expr, env, indent=indent, desired_name=desired_name, into=into) + if isinstance(expr, SemanticAttributeAccess): + raise NotImplementedError("bare shape attribute values are not materialized directly") + if isinstance(expr, SemanticTensorSliceExpr): + return self._lower_tensor_slice_expr( + expr, + env, + indent=indent, + desired_name=desired_name, + into=into, + ) + if isinstance(expr, SemanticSymbolExpr): + raise NotImplementedError("symbol expressions are only lowered through specialized TileLang DSL ops") + raise NotImplementedError(f"unsupported semantic expression {type(expr).__name__}") + + def _lower_call_expr( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None, + into: list[str] | None, + ) -> _RenderedValue: + if expr.namespace is None: + if into is None: + into = [] + rendered_args = [ + self._lower_expr(arg, env, indent=indent, into=into) + for arg in expr.args + if self._should_materialize_function_boundary_type(arg.type) + ] + rendered_arg_names = ", ".join(arg.name for arg in rendered_args) + rendered_arg_types = ", ".join(self._render_type(arg.type) for arg in rendered_args) + if not rendered_arg_types: + rendered_arg_types = "" + if expr.type is None: + into.append( + self._indent(indent) + + f"func.call {_format_symbol_name(expr.name)}({rendered_arg_names}) : " + + f"({rendered_arg_types}) -> ()" + ) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + result_name = desired_name or self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = func.call {_format_symbol_name(expr.name)}({rendered_arg_names}) : " + + f"({rendered_arg_types}) -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.namespace != "pto": + raise NotImplementedError(f"unsupported call namespace {expr.namespace!r}") + if isinstance(expr.type, SemanticTupleType): + raise NotImplementedError("multi-result call values must be assigned directly in TileLang DSL v1") + if into is None: + into = [] + result_name = desired_name or self._new_temp() + + if expr.name == "make_mask": + dtype_expr, pattern_expr = expr.args + if not self._is_dtype_meta_expr(dtype_expr): + raise NotImplementedError("make_mask dtype lowering expects a dtype symbol") + if not isinstance(pattern_expr, SemanticSymbolExpr) or not isinstance(pattern_expr.value, MaskPattern): + raise NotImplementedError("make_mask pattern lowering expects a MaskPattern symbol") + suffix = expr.type.granularity + into.append( + self._indent(indent) + + f'{result_name} = pto.pset_{suffix} "{pattern_expr.value.value}" : {self._render_type(expr.type)}' + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"pset_b8", "pset_b16", "pset_b32", "pge_b8", "pge_b16", "pge_b32"}: + if not isinstance(expr.args[0], SemanticSymbolExpr) or not isinstance(expr.args[0].value, MaskPattern): + raise NotImplementedError(f"{expr.name} lowering expects a MaskPattern symbol") + pattern_token = expr.args[0].value.value.replace("\\", "\\\\").replace('"', '\\"') + pattern = f'"{pattern_token}"' + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {pattern} : {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "init_align": + into.append( + self._indent(indent) + + f"{result_name} = pto.init_align : {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vlds": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=into) + index_args = expr.args[1:] + dist_suffix = "" + if index_args and self._has_optional_string_literal(index_args[-1]): + dist_suffix = f" {{dist = {self._render_string_literal(index_args[-1])}}}" + index_args = index_args[:-1] + if ( + isinstance(expr.args[0].type, SemanticTileType) + and expr.args[0].type.rank == 2 + and len(index_args) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + expr.args[0].type, + index_args, + env, + indent=indent, + into=into, + ) + rendered_indices = self._materialize_constant(0, SemanticIndexType()) + else: + rendered_indices = self._render_index_list(index_args, env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vlds {source.name}[{rendered_indices}]{dist_suffix} : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"plds", "pldi"}: + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + if expr.name == "pldi": + offset = self._lower_to_index(expr.args[1], env, indent=indent, into=into) + else: + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + dist = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {source.name}[{offset.name}], {dist} : " + + f"{self._render_type(source.type)}, {self._render_type(offset.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vldas": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + index_args = expr.args[1:] + if isinstance(source.type, SemanticTileType): + source = self._materialize_tile_memref(source, indent=indent, into=into) + if ( + isinstance(expr.args[0].type, SemanticTileType) + and expr.args[0].type.rank == 2 + and len(index_args) == 2 + ): + source = self._materialize_rank2_tile_subview( + source, + expr.args[0].type, + index_args, + env, + indent=indent, + into=into, + ) + if self._is_memref_like_type(source.type): + ptr_name, ptr_type = self._materialize_copy_buffer_ptr(source, indent=indent, into=into) + source = _RenderedValue(name=ptr_name, type=_RenderedTextualType(ptr_type)) + into.append( + self._indent(indent) + + f"{result_name} = pto.vldas {source.name} : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "load_scalar": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.load_scalar {source.name}[{offset.name}] : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vstus": + align_in = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + offset = self._coerce_rendered_value(offset, _I32_TYPE, indent=indent, into=into) + value = self._lower_expr(expr.args[2], env, indent=indent, into=into) + base = self._lower_expr(expr.args[3], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vstus {align_in.name}, {offset.name}, {value.name}, {base.name} : " + + f"{self._render_type(align_in.type)}, {self._render_type(offset.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vstur": + align_in = self._lower_expr(expr.args[0], env, indent=indent, into=into) + value = self._lower_expr(expr.args[1], env, indent=indent, into=into) + base = self._lower_expr(expr.args[2], env, indent=indent, into=into) + mode = self._render_string_literal(expr.args[3]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vstur {align_in.name}, {value.name}, {base.name}, {mode} : " + + f"{self._render_type(align_in.type)}, {self._render_type(value.type)}, {self._render_type(base.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in { + "get_block_idx", + "get_subblock_idx", + "get_block_num", + "get_subblock_num", + }: + into.append(self._indent(indent) + f"{result_name} = pto.{expr.name}") + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vbr": + scalar = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vbr {scalar.name} : " + + f"{self._render_type(scalar.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vdup": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) + if len(expr.args) == 3: + position = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vdup {value.name}, {mask.name} {{position = {position}}} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + else: + into.append( + self._indent(indent) + + f"{result_name} = pto.vdup {value.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vci": + index = self._lower_expr(expr.args[0], env, indent=indent, into=into) + order = self._render_string_literal(expr.args[1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vci {index.name} {{order = {order}}} : " + + f"{self._render_type(index.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "tensor_view_as_ptr": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.tensor_view_addr {source.name} : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "tile_as_ptr": + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.tile_buf_addr {source.name} : " + + f"{self._render_type(source.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "alloc_tile": + if not isinstance(expr.type, SemanticTileType): + raise NotImplementedError("pto.alloc_tile lowering expects a SemanticTileType result") + attrs: list[str] = [] + if len(expr.args) >= 1: + valid_shape_expr = expr.args[0] + valid_shape = expr.type.valid_shape + if valid_shape is None: + raise NotImplementedError("pto.alloc_tile lowering requires known tile valid_shape metadata") + if expr.type.rank >= 1 and valid_shape[0] is None: + valid_row = self._lower_expr(valid_shape_expr, env, indent=indent, into=into) + attrs.append(f"valid_row = {valid_row.name}") + if expr.type.rank >= 2 and valid_shape[1] is None: + if not isinstance(valid_shape_expr, SemanticTupleExpr): + raise NotImplementedError( + "pto.alloc_tile lowering expects tuple valid_shape when valid_col is dynamic" + ) + valid_col = self._lower_expr(valid_shape_expr.elements[1], env, indent=indent, into=into) + attrs.append(f"valid_col = {valid_col.name}") + if len(expr.args) >= 2: + addr = self._lower_expr(expr.args[1], env, indent=indent, into=into) + attrs.append(f"addr = {addr.name}") + op_text = f"{result_name} = pto.alloc_tile" + if attrs: + op_text += " " + " ".join(attrs) + into.append( + self._indent(indent) + + f"{op_text} : {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "castptr": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + if isinstance(expr.type, SemanticPtrType) and isinstance(value.type, SemanticIndexType): + value = self._coerce_rendered_to_i64(value, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.castptr {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in { + "i1", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + "f16", + "bf16", + "f32", + }: + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + return self._coerce_rendered_value(value, expr.type, indent=indent, into=into) + + if expr.name == "addptr": + pointer = self._lower_expr(expr.args[0], env, indent=indent, into=into) + offset = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.addptr {pointer.name}, {offset.name} : " + + f"{self._render_type(pointer.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"mad", "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias"}: + self._render_cube_mad_like(expr, env, indent=indent, into=into) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name in {"cube_load", "cube_store"}: + self._render_cube_load_store(expr, env, indent=indent, into=into) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name == "cube_load_frac": + self._render_cube_load_frac(expr, env, indent=indent, into=into) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name == "bias_load": + self._render_cube_bias_load(expr, env, indent=indent, into=into) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name in {"left_load", "right_load", "left_load_mx", "right_load_mx"}: + self._render_cube_stage_load(expr, env, indent=indent, into=into) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name in {"acc_store", "acc_store_gm", "acc_store_ub"}: + self._render_cube_acc_store(expr, env, indent=indent, into=into) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name in {"ppack", "punpack"}: + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {part} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "pnot": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.pnot {value.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"psel", "pand", "por", "pxor"}: + src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {src0.name}, {src1.name}, {mask.name} : " + + f"{self._render_type(src0.type)}, {self._render_type(src1.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vcmp": + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + seed = self._lower_expr(expr.args[2], env, indent=indent, into=into) + cmp_mode = self._render_string_literal(expr.args[3]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vcmp {lhs.name}, {rhs.name}, {seed.name}, {cmp_mode} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(seed.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vcmps": + vector = self._lower_expr(expr.args[0], env, indent=indent, into=into) + scalar = self._lower_expr(expr.args[1], env, indent=indent, into=into) + seed = self._lower_expr(expr.args[2], env, indent=indent, into=into) + cmp_mode = self._render_string_literal(expr.args[3]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vcmps {vector.name}, {scalar.name}, {seed.name}, {cmp_mode} : " + + f"{self._render_type(vector.type)}, {self._render_type(scalar.type)}, {self._render_type(seed.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vsel": + src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vsel {src0.name}, {src1.name}, {mask.name} : " + + f"{self._render_type(src0.type)}, {self._render_type(src1.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vselr", "vselrv2"}: + src0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + src1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {src0.name}, {src1.name} : " + + f"{self._render_type(src0.type)}, {self._render_type(src1.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vintlvv2", "vdintlvv2"}: + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {lhs.name}, {rhs.name}, {part} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vcvt": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + attr_parts: list[str] = [] + if self._has_optional_string_literal(expr.args[3]): + attr_parts.append(f"rnd = {self._render_string_literal(expr.args[3])}") + if self._has_optional_string_literal(expr.args[4]): + attr_parts.append(f"sat = {self._render_string_literal(expr.args[4])}") + if self._has_optional_string_literal(expr.args[5]): + attr_parts.append(f"part = {self._render_string_literal(expr.args[5])}") + attr_suffix = f" {{{', '.join(attr_parts)}}}" if attr_parts else "" + into.append( + self._indent(indent) + + f"{result_name} = pto.vcvt {value.name}, {mask.name}{attr_suffix} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vbitcast": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.vbitcast {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "pbitcast": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.pbitcast {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vbitsort": + destination = self._lower_expr(expr.args[0], env, indent=indent, into=into) + source = self._lower_expr(expr.args[1], env, indent=indent, into=into) + indices = self._lower_expr(expr.args[2], env, indent=indent, into=into) + repeat_times = self._lower_expr(expr.args[3], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"pto.vbitsort {destination.name}, {source.name}, {indices.name}, {repeat_times.name} : " + + f"{self._render_type(destination.type)}, {self._render_type(source.type)}, " + + f"{self._render_type(indices.type)}, {self._render_type(repeat_times.type)}" + ) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name == "vmrgsort4": + destination = self._lower_expr(expr.args[0], env, indent=indent, into=into) + source0 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + source1 = self._lower_expr(expr.args[2], env, indent=indent, into=into) + source2 = self._lower_expr(expr.args[3], env, indent=indent, into=into) + source3 = self._lower_expr(expr.args[4], env, indent=indent, into=into) + count = self._lower_expr(expr.args[5], env, indent=indent, into=into) + config = self._lower_expr(expr.args[6], env, indent=indent, into=into) + count = self._coerce_rendered_value(count, _I64_TYPE, indent=indent, into=into) + config = self._coerce_rendered_value(config, _I64_TYPE, indent=indent, into=into) + into.append( + self._indent(indent) + + f"pto.vmrgsort4 {destination.name}, {source0.name}, {source1.name}, {source2.name}, {source3.name}, " + + f"{count.name}, {config.name} : {self._render_type(destination.type)}, {self._render_type(source0.type)}, " + + f"{self._render_type(source1.type)}, {self._render_type(source2.type)}, {self._render_type(source3.type)}, " + + f"{self._render_type(count.type)}, {self._render_type(config.type)}" + ) + return _RenderedValue(name="__void_call__", type=SemanticMetaType(kind="void")) + + if expr.name == "vtrc": + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) + rnd = self._render_string_literal(expr.args[2]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vtrc {value.name}, {mask.name}, {rnd} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in { + "vabs", + "vrelu", + "vexp", + "vln", + "vsqrt", + "vrec", + "vnot", + "vcadd", + "vcmax", + "vbcnt", + "vneg", + "vcls", + "vcmin", + "vrsqrt", + "vmov", + "vsunpack", + "vzunpack", + "vusqz", + "vsqz", + "vcgadd", + "vcgmax", + "vcgmin", + "vcpadd", + "vsort32", + }: + if expr.name in {"vsunpack", "vzunpack"}: + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + part = self._lower_to_index(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {part.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[1], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vexpdif": + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[3]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vexpdif {lhs.name}, {rhs.name}, {mask.name}, {part} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in { + "vadd", + "vsub", + "vmul", + "vdiv", + "vmax", + "vmin", + "vand", + "vor", + "vxor", + "vaddrelu", + "vaddreluconv", + "vsubrelu", + "vmulconv", + "vshl", + "vshr", + "vprelu", + "vperm", + "vmrgsort", + }: + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {lhs.name}, {rhs.name}, {mask.name} : " + + f"{self._render_type(lhs.type)}, {self._render_type(rhs.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name == "vpack": + vector = self._lower_expr(expr.args[0], env, indent=indent, into=into) + part = self._render_string_literal(expr.args[1]) + into.append( + self._indent(indent) + + f"{result_name} = pto.vpack {vector.name}, {part} : " + + f"{self._render_type(vector.type)} -> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vshift", "vslide"}: + vector = self._lower_expr(expr.args[0], env, indent=indent, into=into) + immediate = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {vector.name}, {immediate.name}, {mask.name} : " + + f"{self._render_type(vector.type)}, {self._render_type(immediate.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vadds", "vsubs", "vmuls", "vdivs", "vmaxs", "vmins", "vlrelu", "vshls", "vshrs", "vands", "vors", "vxors"}: + value = self._lower_expr(expr.args[0], env, indent=indent, into=into) + scalar = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[2], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {value.name}, {scalar.name}, {mask.name} : " + + f"{self._render_type(value.type)}, {self._render_type(scalar.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + if expr.name in {"vaxpy", "vmula"}: + vec0 = self._lower_expr(expr.args[0], env, indent=indent, into=into) + vec1 = self._lower_expr(expr.args[1], env, indent=indent, into=into) + vec2 = self._lower_expr(expr.args[2], env, indent=indent, into=into) + mask = self._lower_expr(expr.args[3], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"{result_name} = pto.{expr.name} {vec0.name}, {vec1.name}, {vec2.name}, {mask.name} : " + + f"{self._render_type(vec0.type)}, {self._render_type(vec1.type)}, {self._render_type(vec2.type)}, {self._render_type(mask.type)} " + + f"-> {self._render_type(expr.type)}" + ) + return _RenderedValue(name=result_name, type=expr.type) + + raise NotImplementedError(f"unsupported pto call `{expr.name}` in lowering") + + def _render_cube_mad_like( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> None: + has_bias = "bias" in expr.name + lhs = self._lower_expr(expr.args[0], env, indent=indent, into=into) + rhs = self._lower_expr(expr.args[1], env, indent=indent, into=into) + dst = self._lower_expr(expr.args[2], env, indent=indent, into=into) + bias = None + dim_start = 3 + if has_bias: + bias = self._lower_expr(expr.args[3], env, indent=indent, into=into) + dim_start = 4 + m = self._lower_to_i64(expr.args[dim_start], env, indent=indent, into=into) + n = self._lower_to_i64(expr.args[dim_start + 1], env, indent=indent, into=into) + k = self._lower_to_i64(expr.args[dim_start + 2], env, indent=indent, into=into) + unit_flag_ctrl = self._extract_static_int(expr.args[-2], context=f"pto.{expr.name} unit_flag_ctrl") + disable_gemv = self._extract_static_bool(expr.args[-1], context=f"pto.{expr.name} disable_gemv") + + attr_parts: list[str] = [] + if unit_flag_ctrl != 0: + attr_parts.append(f"unit_flag_ctrl = {unit_flag_ctrl} : i32") + if disable_gemv is not True: + attr_parts.append(f"disable_gemv = {'true' if disable_gemv else 'false'}") + attr_suffix = f" {{{', '.join(attr_parts)}}}" if attr_parts else "" + + operands = [lhs.name, rhs.name, dst.name] + operand_types = [self._render_type(lhs.type), self._render_type(rhs.type), self._render_type(dst.type)] + if bias is not None: + operands.append(bias.name) + operand_types.append(self._render_type(bias.type)) + operands.extend([m.name, n.name, k.name]) + operand_types.extend([self._render_type(m.type), self._render_type(n.type), self._render_type(k.type)]) + into.append( + self._indent(indent) + + f"pto.{expr.name} " + + ", ".join(operands) + + attr_suffix + + " : " + + ", ".join(operand_types) + ) + + def _render_cube_load_store( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> None: + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination = self._lower_expr(expr.args[1], env, indent=indent, into=into) + len_burst = self._lower_to_i64(expr.args[2], env, indent=indent, into=into) + nburst = self._lower_cube_i64_tuple(expr.args[3], env, indent=indent, into=into, expected_len=3) + loop_groups = self._lower_cube_loop_groups(expr.args[4], env, indent=indent, into=into) + + op_text = ( + f"pto.{expr.name} {source.name}, {destination.name}, {len_burst.name}" + f" nburst({nburst[0].name}, {nburst[1].name}, {nburst[2].name})" + ) + type_text = ( + f"{self._render_type(source.type)}, {self._render_type(destination.type)}, " + f"{self._render_type(len_burst.type)}, {self._render_type(nburst[0].type)}, " + f"{self._render_type(nburst[1].type)}, {self._render_type(nburst[2].type)}" + ) + for count, src_stride, dst_stride in loop_groups: + op_text += f" loop({count.name}, {src_stride.name}, {dst_stride.name})" + type_text += ( + f", loop {self._render_type(count.type)}, {self._render_type(src_stride.type)}, " + f"{self._render_type(dst_stride.type)}" + ) + into.append(self._indent(indent) + op_text + " : " + type_text) + + def _render_cube_load_frac( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> None: + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination = self._lower_expr(expr.args[1], env, indent=indent, into=into) + mode = self._extract_static_string(expr.args[2], context="pto.cube_load_frac mode") + shape = self._lower_cube_i64_tuple(expr.args[3], env, indent=indent, into=into, expected_len=2) + src_layout = self._lower_cube_i64_tuple(expr.args[4], env, indent=indent, into=into, min_len=1, max_len=2) + dst_group = self._lower_cube_i64_tuple(expr.args[5], env, indent=indent, into=into, expected_len=4) + ctrl = self._lower_cube_tuple_elements(expr.args[6], env, indent=indent, into=into, expected_len=2) + l2_cache_ctrl = self._coerce_rendered_to_i64(ctrl[0], indent=indent, into=into) + smallc0_en = self._coerce_rendered_value(ctrl[1], _I1_TYPE, indent=indent, into=into) + + src_layout_operands = ", ".join(value.name for value in src_layout) + src_layout_types = ", ".join(self._render_type(value.type) for value in src_layout) + into.append( + self._indent(indent) + + f"pto.cube_load_frac {source.name}, {destination.name}, {mode}, " + + f"shape({shape[0].name}, {shape[1].name}), " + + f"src_layout({src_layout_operands}), " + + f"dst_group({dst_group[0].name}, {dst_group[1].name}, {dst_group[2].name}, {dst_group[3].name}), " + + f"ctrl({l2_cache_ctrl.name}, {smallc0_en.name}) : " + + f"{self._render_type(source.type)}, {self._render_type(destination.type)}, {mode}, " + + f"shape {self._render_type(shape[0].type)}, {self._render_type(shape[1].type)}, " + + f"src_layout({src_layout_types}), " + + f"dst_group {self._render_type(dst_group[0].type)}, {self._render_type(dst_group[1].type)}, " + + f"{self._render_type(dst_group[2].type)}, {self._render_type(dst_group[3].type)}, " + + f"ctrl {self._render_type(l2_cache_ctrl.type)}, {self._render_type(smallc0_en.type)}" + ) + + def _render_cube_bias_load( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> None: + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination = self._lower_expr(expr.args[1], env, indent=indent, into=into) + len_burst = self._lower_to_i64(expr.args[2], env, indent=indent, into=into) + nburst = self._lower_cube_i64_tuple(expr.args[3], env, indent=indent, into=into, expected_len=3) + into.append( + self._indent(indent) + + f"pto.bias_load {source.name}, {destination.name}, {len_burst.name}" + + f" nburst({nburst[0].name}, {nburst[1].name}, {nburst[2].name}) : " + + f"{self._render_type(source.type)}, {self._render_type(destination.type)}, " + + f"{self._render_type(len_burst.type)}, {self._render_type(nburst[0].type)}, " + + f"{self._render_type(nburst[1].type)}, {self._render_type(nburst[2].type)}" + ) + + def _render_cube_stage_load( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> None: + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination = self._lower_expr(expr.args[1], env, indent=indent, into=into) + first = self._lower_to_i64(expr.args[2], env, indent=indent, into=into) + second = self._lower_to_i64(expr.args[3], env, indent=indent, into=into) + into.append( + self._indent(indent) + + f"pto.{expr.name} {source.name}, {destination.name}, {first.name}, {second.name} : " + + f"{self._render_type(source.type)}, {self._render_type(destination.type)}, " + + f"{self._render_type(first.type)}, {self._render_type(second.type)}" + ) + + def _render_cube_acc_store( + self, + expr: SemanticCallExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> None: + source = self._lower_expr(expr.args[0], env, indent=indent, into=into) + destination = self._lower_expr(expr.args[1], env, indent=indent, into=into) + dims = [ + self._lower_to_i64(expr.args[i], env, indent=indent, into=into) for i in range(2, 6) + ] + cursor = 6 + sid_l2: tuple[_RenderedValue, _RenderedValue] | None = None + dual_sub: tuple[_RenderedValue, _RenderedValue] | None = None + if expr.name == "acc_store_gm": + sid = self._lower_to_i64(expr.args[cursor], env, indent=indent, into=into) + l2 = self._lower_to_i64(expr.args[cursor + 1], env, indent=indent, into=into) + sid_l2 = (sid, l2) + cursor += 2 + elif expr.name == "acc_store_ub": + dual = self._lower_to_i64(expr.args[cursor], env, indent=indent, into=into) + sub = self._lower_to_i64(expr.args[cursor + 1], env, indent=indent, into=into) + dual_sub = (dual, sub) + cursor += 2 + + mode = self._extract_static_string(expr.args[cursor], context=f"pto.{expr.name} mode") + cursor += 1 + + split_value: _RenderedValue | None = None + loop0_src_stride: _RenderedValue | None = None + loop3_values: tuple[_RenderedValue, _RenderedValue, _RenderedValue] | None = None + if mode == "nz2dn" and cursor < len(expr.args): + loop0_src_stride = self._lower_to_i64(expr.args[cursor], env, indent=indent, into=into) + cursor += 1 + elif mode == "nz2nz" and cursor < len(expr.args): + split_value = self._lower_to_i64(expr.args[cursor], env, indent=indent, into=into) + cursor += 1 + if cursor < len(expr.args): + lowered_loop3 = self._lower_cube_i64_tuple( + expr.args[cursor], env, indent=indent, into=into, expected_len=3 + ) + loop3_values = (lowered_loop3[0], lowered_loop3[1], lowered_loop3[2]) + + pieces: list[_RenderedValue] = [source, destination, *dims] + if sid_l2 is not None: + pieces.extend(sid_l2) + elif dual_sub is not None: + pieces.extend(dual_sub) + + operand_text = ", ".join(value.name for value in pieces) + type_text = ", ".join(self._render_type(value.type) for value in pieces) + op_text = f"pto.{expr.name} {operand_text}, {mode}" + extra_type_parts: list[str] = [] + if mode == "nz2dn" and loop0_src_stride is not None: + op_text = f"pto.{expr.name} {operand_text}, {mode}({loop0_src_stride.name})" + extra_type_parts.append(self._render_type(loop0_src_stride.type)) + elif mode == "nz2nz" and split_value is not None: + op_text = f"pto.{expr.name} {operand_text}, {mode}({split_value.name})" + extra_type_parts.append(self._render_type(split_value.type)) + if loop3_values is not None: + op_text += ( + f", loop3({loop3_values[0].name}, {loop3_values[1].name}, {loop3_values[2].name})" + ) + extra_type_parts.extend( + [ + self._render_type(loop3_values[0].type), + self._render_type(loop3_values[1].type), + self._render_type(loop3_values[2].type), + ] + ) + suffix = (", " + ", ".join(extra_type_parts)) if extra_type_parts else "" + into.append(self._indent(indent) + op_text + " : " + type_text + suffix) + + def _lower_cube_loop_groups( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> list[tuple[_RenderedValue, _RenderedValue, _RenderedValue]]: + if self._is_none_meta_expr(expr): + return [] + if not isinstance(expr, SemanticTupleExpr): + raise NotImplementedError("cube loop lowering expects a tuple of loop triples") + loop_groups: list[tuple[_RenderedValue, _RenderedValue, _RenderedValue]] = [] + for loop_expr in expr.elements: + lowered_loop = self._lower_cube_i64_tuple(loop_expr, env, indent=indent, into=into, expected_len=3) + loop_groups.append((lowered_loop[0], lowered_loop[1], lowered_loop[2])) + return loop_groups + + def _lower_cube_i64_tuple( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + expected_len: int | None = None, + min_len: int | None = None, + max_len: int | None = None, + ) -> tuple[_RenderedValue, ...]: + lowered = self._lower_cube_tuple_elements(expr, env, indent=indent, into=into, expected_len=expected_len, min_len=min_len, max_len=max_len) + return tuple(self._coerce_rendered_to_i64(value, indent=indent, into=into) for value in lowered) + + def _lower_cube_tuple_elements( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + expected_len: int | None = None, + min_len: int | None = None, + max_len: int | None = None, + ) -> tuple[_RenderedValue, ...]: + if not isinstance(expr, SemanticTupleExpr): + raise NotImplementedError("cube structured lowering expects a tuple expression") + elements = expr.elements + if expected_len is not None and len(elements) != expected_len: + raise NotImplementedError(f"cube structured lowering expects exactly {expected_len} tuple elements") + if min_len is not None and len(elements) < min_len: + raise NotImplementedError(f"cube structured lowering expects at least {min_len} tuple elements") + if max_len is not None and len(elements) > max_len: + raise NotImplementedError(f"cube structured lowering expects at most {max_len} tuple elements") + return tuple(self._lower_expr(element, env, indent=indent, into=into) for element in elements) + + def _extract_static_value(self, expr: SemanticExpr, *, context: str) -> object: + if isinstance(expr, SemanticLiteralExpr): + return expr.value + if isinstance(expr, SemanticBindingRef): + return expr.binding.value + raise NotImplementedError(f"{context} must be a compile-time constant in TileLang DSL v1 lowering") + + def _extract_static_int(self, expr: SemanticExpr, *, context: str) -> int: + value = self._extract_static_value(expr, context=context) + if isinstance(value, bool) or not isinstance(value, int): + raise NotImplementedError(f"{context} must be an integer constant in TileLang DSL v1 lowering") + return value + + def _extract_static_bool(self, expr: SemanticExpr, *, context: str) -> bool: + value = self._extract_static_value(expr, context=context) + if not isinstance(value, bool): + raise NotImplementedError(f"{context} must be a boolean constant in TileLang DSL v1 lowering") + return value + + def _extract_static_string(self, expr: SemanticExpr, *, context: str) -> str: + value = self._extract_static_value(expr, context=context) + if not isinstance(value, str): + raise NotImplementedError(f"{context} must be a string constant in TileLang DSL v1 lowering") + return value + + def _is_none_meta_expr(self, expr: SemanticExpr | None) -> bool: + return isinstance(expr, SemanticLiteralExpr) and expr.value is None + + def _lower_compare_expr( + self, + op: str, + lhs: _RenderedValue, + rhs: _RenderedValue, + *, + indent: int, + desired_name: str | None, + into: list[str], + ) -> _RenderedValue: + result_name = desired_name or self._new_temp() + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + index_predicates = { + "eq": "eq", + "ne": "ne", + "gt": "sgt", + "lt": "slt", + "ge": "sge", + "le": "sle", + } + predicate = index_predicates[op] + cmp_name = "arith.cmpi" + elif ( + isinstance(lhs.type, SemanticIndexType) + and isinstance(rhs.type, SemanticScalarType) + and is_integer_dtype(rhs.type.dtype) + and integer_bitwidth(rhs.type.dtype) in {8, 16, 32, 64} + ): + lhs = self._coerce_rendered_value(lhs, rhs.type, indent=indent, into=into) + elif ( + isinstance(rhs.type, SemanticIndexType) + and isinstance(lhs.type, SemanticScalarType) + and is_integer_dtype(lhs.type.dtype) + and integer_bitwidth(lhs.type.dtype) in {8, 16, 32, 64} + ): + rhs = self._coerce_rendered_value(rhs, lhs.type, indent=indent, into=into) + + if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + if lhs.type.dtype.name in {"f16", "bf16", "f32"}: + float_predicates = { + "eq": "oeq", + "ne": "une", + "gt": "ogt", + "lt": "olt", + "ge": "oge", + "le": "ole", + } + predicate = float_predicates[op] + cmp_name = "arith.cmpf" + else: + int_sign = integer_signedness(lhs.type.dtype) + int_predicates = { + "eq": "eq", + "ne": "ne", + "gt": "ugt" if int_sign == "unsigned" else "sgt", + "lt": "ult" if int_sign == "unsigned" else "slt", + "ge": "uge" if int_sign == "unsigned" else "sge", + "le": "ule" if int_sign == "unsigned" else "sle", + } + predicate = int_predicates[op] + cmp_name = "arith.cmpi" + into.append( + self._indent(indent) + + f"{result_name} = {cmp_name} {predicate}, {lhs.name}, {rhs.name} : " + f"{self._render_type(lhs.type)}" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + into.append( + self._indent(indent) + + f"{result_name} = {cmp_name} {predicate}, {lhs.name}, {rhs.name} : {self._render_type(lhs.type)}" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + + raise NotImplementedError( + f"comparison lowering requires matching scalar types or index operands, got {lhs.type!r} and {rhs.type!r}" + ) + + def _lower_bool_expr( + self, + op: str, + lhs_expr: SemanticExpr, + rhs_expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None, + into: list[str], + ) -> _RenderedValue: + lhs = self._lower_condition(lhs_expr, env, indent=indent, into=into) + rhs = self._lower_condition(rhs_expr, env, indent=indent, into=into) + result_name = desired_name or self._new_temp() + arith_op = "arith.andi" if op == "and" else "arith.ori" + into.append( + self._indent(indent) + + f"{result_name} = {arith_op} {lhs.name}, {rhs.name} : i1" + ) + return _RenderedValue(name=result_name, type=_I1_TYPE) + + def _render_string_literal(self, expr: SemanticExpr) -> str: + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.value, str): + escaped = expr.value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + if isinstance(expr, SemanticBindingRef) and isinstance(expr.binding.value, str): + escaped = expr.binding.value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + raise NotImplementedError("expected a string literal for TileLang DSL advanced-family lowering") + + def _has_optional_string_literal(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticLiteralExpr): + return isinstance(expr.value, str) + if isinstance(expr, SemanticBindingRef): + return isinstance(expr.binding.value, str) + return False + + def _render_dtype_symbol(self, expr: SemanticExpr, *, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and isinstance(expr.value, ScalarType): + return expr.value.name + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ): + return expr.binding.value.name + raise NotImplementedError(f"{context} expects a dtype symbol in TileLang DSL v1 lowering") + + def _lower_to_i1( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticScalarType) and value.type.dtype.name == "i1": + return value + raise NotImplementedError("expected an i1 operand during TileLang DSL v1 lowering") + + def _lower_to_i64( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + return self._coerce_rendered_to_i64(value, indent=indent, into=into) + + def _lower_to_i32( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticIndexType) or ( + isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype) + ): + return self._coerce_rendered_value(value, _I32_TYPE, indent=indent, into=into) + raise NotImplementedError("expected an i32 or index operand during TileLang DSL v1 lowering") + + def _lower_to_index( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + return self._coerce_rendered_to_index(value, indent=indent, into=into) + + def _coerce_rendered_to_index( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if isinstance(value.type, SemanticIndexType): + return value + if isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype): + bits = integer_bitwidth(value.type.dtype) + if bits in {8, 16}: + value = self._coerce_rendered_value(value, _I32_TYPE, indent=indent, into=into) + elif bits in {32, 64}: + value = self._bridge_rendered_to_signless_integer(value, indent=indent, into=into) + else: + value = None + if value is not None: + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.index_cast {value.name} : {value.type.dtype.name} to index" + ) + return _RenderedValue(name=cast_name, type=SemanticIndexType()) + raise NotImplementedError("expected an integer scalar or index operand during TileLang DSL v1 lowering") + + def _coerce_rendered_to_i64( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if isinstance(value.type, SemanticIndexType) or ( + isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype) + ): + return self._coerce_rendered_value(value, _I64_TYPE, indent=indent, into=into) + raise NotImplementedError("expected an i64 or index operand during TileLang DSL v1 lowering") + + def _lower_remaining_to_i32( + self, + expr: SemanticExpr, + env: dict[str, _RenderedValue], + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + value = self._lower_expr(expr, env, indent=indent, into=into) + if isinstance(value.type, SemanticIndexType) or ( + isinstance(value.type, SemanticScalarType) and is_integer_dtype(value.type.dtype) + ): + return self._coerce_rendered_value(value, _I32_TYPE, indent=indent, into=into) + raise NotImplementedError("tail make_mask lowering expects an i32 or index remaining operand") + + def _bridge_rendered_to_signless_integer( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if not isinstance(value.type, SemanticScalarType) or not is_integer_dtype(value.type.dtype): + return value + raw_type = self._signless_integer_scalar_type(value.type) + if raw_type is None: + return value + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = builtin.unrealized_conversion_cast {value.name} : " + f"{self._render_type(value.type)} to {self._render_type(raw_type)}" + ) + return _RenderedValue(name=cast_name, type=raw_type) + + def _bridge_rendered_integer_to_target( + self, + value: _RenderedValue, + target_type: SemanticScalarType, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + if value.type == target_type: + return value + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = builtin.unrealized_conversion_cast {value.name} : " + f"{self._render_type(value.type)} to {self._render_type(target_type)}" + ) + return _RenderedValue(name=cast_name, type=target_type) + + def _materialize_copy_buffer_ptr( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> tuple[str, str]: + ptr_type = self._render_copy_buffer_type(value.type) + cache_key = (value.name, ptr_type) + existing = self._castptr_cache.get(cache_key) + if existing is not None: + return existing, ptr_type + + if isinstance(value.type, SemanticTileType): + value = self._materialize_tile_memref(value, indent=indent, into=into) + + if self._is_memref_like_type(value.type): + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = pto.castptr {value.name} : {self._render_type(value.type)} -> {ptr_type}" + ) + self._castptr_cache[cache_key] = cast_name + return cast_name, ptr_type + + return value.name, ptr_type + + def _coerce_rendered_value( + self, + value: _RenderedValue, + target_type: SemanticType, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + def _scalar_int_bits(dtype: ScalarType) -> int | None: + if dtype.name == "i1": + return 1 + return integer_bitwidth(dtype) + + def _scalar_int_sign(dtype: ScalarType) -> str: + sign = integer_signedness(dtype) + return "signless" if sign is None else sign + + def _signless_int_type_for_bits(bits: int) -> SemanticScalarType: + if bits not in {8, 16, 32, 64}: + raise NotImplementedError( + f"unsupported integer bitwidth {bits!r} for signless coercion in TileLang DSL v1 lowering" + ) + return SemanticScalarType(dtype=ScalarType(f"i{bits}")) + + if type(value.type) is type(target_type) and value.type == target_type: + return value + if isinstance(value.type, SemanticIndexType) and isinstance(target_type, SemanticScalarType): + target_int_bits = _scalar_int_bits(target_type.dtype) + target_sign = _scalar_int_sign(target_type.dtype) + signless_target_type = self._signless_integer_scalar_type(target_type) + if signless_target_type is None and target_int_bits in {8, 16, 32, 64}: + signless_target_type = target_type + if target_int_bits in {8, 16, 32, 64} and signless_target_type is not None: + carrier_type = ( + _signless_int_type_for_bits(32) + if target_int_bits in {8, 16} + else _signless_int_type_for_bits(target_int_bits) + ) + op = "arith.index_castui" if target_sign == "unsigned" else "arith.index_cast" + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = {op} {value.name} : index to {carrier_type.dtype.name}" + ) + lowered_value = _RenderedValue(name=cast_name, type=carrier_type) + if target_int_bits in {8, 16}: + lowered_value = self._coerce_rendered_value( + lowered_value, + signless_target_type, + indent=indent, + into=into, + ) + if signless_target_type != target_type: + return self._coerce_rendered_value( + lowered_value, + target_type, + indent=indent, + into=into, + ) + return lowered_value + if target_type.dtype.name in {"f16", "bf16", "f32"}: + index_to_int_name = self._new_temp() + index_to_int_op = "arith.index_castui" + into.append( + self._indent(indent) + + f"{index_to_int_name} = {index_to_int_op} {value.name} : index to i64" + ) + cast_name = self._new_temp() + into.append( + self._indent(indent) + + f"{cast_name} = arith.uitofp {index_to_int_name} : i64 to {target_type.dtype.name}" + ) + return _RenderedValue(name=cast_name, type=target_type) + if isinstance(value.type, SemanticScalarType) and isinstance(target_type, SemanticScalarType): + src = value.type.dtype.name + dst = target_type.dtype.name + if src == dst: + return value + src_bits = _scalar_int_bits(value.type.dtype) + dst_bits = _scalar_int_bits(target_type.dtype) + if src_bits is not None and dst_bits is not None: + src_sign = _scalar_int_sign(value.type.dtype) + signless_value = self._bridge_rendered_to_signless_integer(value, indent=indent, into=into) + signless_target_type = self._signless_integer_scalar_type(target_type) or target_type + if signless_value.type == signless_target_type: + if signless_target_type != target_type: + return self._bridge_rendered_integer_to_target( + signless_value, + target_type, + indent=indent, + into=into, + ) + return signless_value + cast_name = self._new_temp() + if src_bits < dst_bits: + op = "arith.extui" if _scalar_int_sign(value.type.dtype) == "unsigned" else "arith.extsi" + elif src_bits > dst_bits: + op = "arith.trunci" + else: + raise NotImplementedError( + f"unsupported same-width integer coercion from {value.type!r} to {target_type!r} " + "in TileLang DSL v1 lowering" + ) + into.append( + self._indent(indent) + + f"{cast_name} = {op} {signless_value.name} : " + f"{self._render_type(signless_value.type)} to {self._render_type(signless_target_type)}" + ) + lowered_value = _RenderedValue(name=cast_name, type=signless_target_type) + if signless_target_type != target_type: + return self._bridge_rendered_integer_to_target( + lowered_value, + target_type, + indent=indent, + into=into, + ) + return lowered_value + if src_bits is not None and dst in {"f16", "bf16", "f32"}: + signless_value = self._bridge_rendered_to_signless_integer(value, indent=indent, into=into) + cast_name = self._new_temp() + op = "arith.uitofp" if _scalar_int_sign(value.type.dtype) == "unsigned" else "arith.sitofp" + into.append( + self._indent(indent) + + f"{cast_name} = {op} {signless_value.name} : " + f"{self._render_type(signless_value.type)} to {dst}" + ) + return _RenderedValue(name=cast_name, type=target_type) + if src in {"f16", "bf16", "f32"} and dst_bits is not None: + signless_target_type = self._signless_integer_scalar_type(target_type) or target_type + cast_name = self._new_temp() + op = "arith.fptoui" if _scalar_int_sign(target_type.dtype) == "unsigned" else "arith.fptosi" + into.append( + self._indent(indent) + + f"{cast_name} = {op} {value.name} : {src} to {self._render_type(signless_target_type)}" + ) + lowered_value = _RenderedValue(name=cast_name, type=signless_target_type) + if signless_target_type != target_type: + return self._bridge_rendered_integer_to_target( + lowered_value, + target_type, + indent=indent, + into=into, + ) + return lowered_value + cast_name = self._new_temp() + if src in {"f16", "bf16", "f32"} and dst in {"f16", "bf16", "f32"}: + op = "arith.extf" if src in {"f16", "bf16"} and dst == "f32" else "arith.truncf" + into.append( + self._indent(indent) + + f"{cast_name} = {op} {value.name} : {src} to {dst}" + ) + return _RenderedValue(name=cast_name, type=target_type) + raise NotImplementedError( + f"unsupported value coercion from {value.type!r} to {target_type!r} in TileLang DSL v1 lowering" + ) + + def _materialize_strict_vecscope_capture( + self, + capture: _RenderedValue, + binding: SemanticBinding, + *, + indent: int, + into: list[str], + ) -> tuple[_RenderedValue, _RenderedValue]: + if not self._is_memref_like_type(capture.type): + return capture, _RenderedValue(name=binding.ssa_name, type=binding.type) + + ptr_name, ptr_type = self._materialize_copy_buffer_ptr( + capture, + indent=indent, + into=into, + ) + rendered_ptr_type = _RenderedTextualType(ptr_type) + return ( + _RenderedValue(name=ptr_name, type=rendered_ptr_type), + _RenderedValue(name=binding.ssa_name, type=rendered_ptr_type), + ) + + def _mask_suffix(self, ty: SemanticType) -> str: + if not isinstance(ty, SemanticMaskType): + raise NotImplementedError("tail make_mask lowering expects a mask result type") + return ty.granularity + + def _is_dtype_meta_expr(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticSymbolExpr): + return isinstance(expr.value, ScalarType) and expr.type.kind == "dtype" + if isinstance(expr, SemanticBindingRef): + return ( + isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ) + return False + + def _lower_subscript_access( + self, + expr: SemanticSubscriptAccess, + env: dict[str, _RenderedValue], + *, + indent: int, + desired_name: str | None, + into: list[str] | None, + ) -> _RenderedValue: + if isinstance(expr.base, SemanticTupleExpr): + if not isinstance(expr.index, SemanticLiteralExpr) or not isinstance(expr.index.value, int): + raise NotImplementedError("tuple indices must be integer literals in TileLang DSL v1 lowering") + if expr.index.value < 0 or expr.index.value >= len(expr.base.elements): + raise NotImplementedError( + f"tuple subscript index {expr.index.value} is out of bounds for tuple length {len(expr.base.elements)}" + ) + return self._lower_expr( + expr.base.elements[expr.index.value], + env, + indent=indent, + desired_name=desired_name, + into=into, + ) + if ( + into is not None + and isinstance(expr.base, SemanticAttributeAccess) + and expr.base.attr == "valid_shape" + and isinstance(expr.base.base, SemanticBindingRef) + and isinstance(expr.base.base.type, SemanticTileType) + and isinstance(expr.index, SemanticLiteralExpr) + and isinstance(expr.index.value, int) + ): + return self._materialize_tile_valid_dim( + expr.base.base.binding, + expr.index.value, + indent=indent, + into=into, + desired_name=desired_name, + ) + if ( + into is not None + and isinstance(expr.base, SemanticAttributeAccess) + and expr.base.attr in {"shape", "valid_shape", "strides"} + and isinstance(expr.base.base, SemanticBindingRef) + and isinstance( + expr.base.base.type, + (SemanticTensorViewType, SemanticPartitionTensorViewType), + ) + and isinstance(expr.index, SemanticLiteralExpr) + and isinstance(expr.index.value, int) + ): + tensor_value = env.get( + expr.base.base.binding.name, + _RenderedValue(expr.base.base.binding.ssa_name, expr.base.base.type), + ) + result_name = desired_name or self._new_temp() + axis_value = self._materialize_constant(expr.index.value, SemanticIndexType()) + op_name = ( + "pto.get_tensor_view_stride" + if expr.base.attr == "strides" + else "pto.get_tensor_view_dim" + ) + into.append( + self._indent(indent) + + f"{result_name} = {op_name} {tensor_value.name}, {axis_value} : " + + f"{self._render_type(tensor_value.type)} -> index" + ) + return _RenderedValue(name=result_name, type=SemanticIndexType()) + value = self._extract_shape_subscript_value(expr, env) + if isinstance(value, _RenderedValue): + return value + if desired_name is not None and into is not None: + into.append( + self._indent(indent) + + f"{desired_name} = arith.constant {self._format_constant(value, expr.type)} : " + f"{self._render_arith_constant_type(expr.type)}" + ) + return _RenderedValue(name=desired_name, type=expr.type) + return _RenderedValue( + name=self._materialize_constant(value, expr.type), + type=expr.type, + ) + + def _tensor_shape_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__shape_{tensor_name}_{axis}" + + def _tensor_stride_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__stride_{tensor_name}_{axis}" + + def _materialize_tile_memref( + self, + value: _RenderedValue, + *, + indent: int, + into: list[str], + ) -> _RenderedValue: + existing = self._tile_memref_cache.get(value.name) + if existing is not None: + return existing + if not isinstance(value.type, SemanticTileType): + return value + memref_type = _RenderedTextualType( + self._render_memref_type( + element_dtype=value.type.element_dtype.name, + shape=value.type.shape if value.type.shape is not None else ("?",) * value.type.rank, + memory_space=value.type.memory_space or "ub", + ) + ) + memref_name = self._new_temp() + into.append( + self._indent(indent) + + f"{memref_name} = pto.tile_buf_addr {value.name} : " + + f"{self._render_type(value.type)} -> {self._render_type(memref_type)}" + ) + rendered = _RenderedValue(name=memref_name, type=memref_type) + self._tile_memref_cache[value.name] = rendered + return rendered + + def _materialize_tile_valid_dim( + self, + binding: object, + axis: int, + *, + indent: int, + into: list[str], + desired_name: str | None = None, + ) -> _RenderedValue: + cache_key = (binding.name, axis) + existing = self._tile_valid_dim_cache.get(cache_key) + if existing is not None: + return existing + source = _RenderedValue(name=binding.ssa_name, type=binding.type) + op_name = "pto.tile_valid_rows" if axis == 0 else "pto.tile_valid_cols" + result_name = desired_name or self._new_temp() + into.append( + self._indent(indent) + + f"{result_name} = {op_name} {source.name} : " + + f"{self._render_type(source.type)} -> index" + ) + rendered = _RenderedValue(name=result_name, type=SemanticIndexType()) + self._tile_valid_dim_cache[cache_key] = rendered + return rendered + + def _extract_shape_subscript_value( + self, + expr: SemanticSubscriptAccess, + env: dict[str, _RenderedValue], + ) -> int | _RenderedValue: + if not isinstance(expr.base, SemanticAttributeAccess): + raise NotImplementedError("only shape/stride indexing is supported in TileLang DSL v1 lowering") + if expr.base.attr not in {"shape", "valid_shape", "strides"}: + raise NotImplementedError( + "only `.shape[...]`, `.valid_shape[...]`, and `.strides[...]` indexing are supported in TileLang DSL v1 lowering" + ) + if not isinstance(expr.index, SemanticLiteralExpr) or not isinstance(expr.index.value, int): + raise NotImplementedError("shape/stride indices must be integer literals in TileLang DSL v1 lowering") + if not isinstance(expr.base.base, SemanticBindingRef): + raise NotImplementedError("shape/stride indexing expects a bound TensorView or Tile value") + + base_binding = expr.base.base.binding + base_value = env.get(base_binding.name, _RenderedValue(base_binding.ssa_name, base_binding.type)) + base_type = base_value.type + index = expr.index.value + + if isinstance(base_type, SemanticTileType): + if expr.base.attr == "shape": + if base_type.shape is None: + raise NotImplementedError("dynamic Tile shapes are not supported in TileLang DSL v1 lowering") + return base_type.shape[index] + if base_type.valid_shape is None: + raise NotImplementedError("dynamic Tile shapes are not supported in TileLang DSL v1 lowering") + valid_dim = base_type.valid_shape[index] + if valid_dim is not None: + return valid_dim + return _RenderedValue(name=base_binding.ssa_name, type=base_type) + + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + if expr.base.attr == "strides": + hidden_name = self._tensor_stride_binding_name(base_binding.name, index) + else: + hidden_name = self._tensor_shape_binding_name(base_binding.name, index) + hidden_value = env.get(hidden_name) + if hidden_value is None: + raise NotImplementedError( + f"missing TensorView/PartitionTensorView {expr.base.attr} binding for '{base_binding.name}.{expr.base.attr}[{index}]'" + ) + return hidden_value + + raise NotImplementedError("shape/stride indexing expects a Tile, TensorView, or PartitionTensorView operand") + + def _format_shape_tuple(self, shape: tuple[int | None, ...]) -> str: + return "(" + ", ".join("?" if dim is None else str(dim) for dim in shape) + ")" + + def _materialize_constant(self, value: object, ty: SemanticType) -> str: + cache_key = (self._render_type(ty), value) + if cache_key in self._constant_cache: + return self._constant_cache[cache_key] + + raw_type = self._signless_integer_scalar_type(ty) + if raw_type is not None: + raw_name = self._materialize_constant(value, raw_type) + name = self._constant_name(value, ty) + self._constant_cache[cache_key] = name + self._constant_lines.append( + self._indent(4) + + f"{name} = builtin.unrealized_conversion_cast {raw_name} : " + f"{self._render_type(raw_type)} to {self._render_type(ty)}" + ) + return name + + name = self._constant_name(value, ty) + self._constant_cache[cache_key] = name + self._constant_lines.append( + self._indent(4) + + f"{name} = arith.constant {self._format_constant(value, ty)} : " + f"{self._render_arith_constant_type(ty)}" + ) + return name + + def _signless_integer_scalar_type(self, ty: SemanticType) -> SemanticScalarType | None: + if not isinstance(ty, SemanticScalarType) or not is_integer_dtype(ty.dtype): + return None + signedness = integer_signedness(ty.dtype) + if signedness in {None, "signless"}: + return None + bitwidth = integer_bitwidth(ty.dtype) + if bitwidth not in {8, 16, 32, 64}: + raise NotImplementedError( + f"unsupported integer bitwidth {bitwidth!r} for signless literal lowering" + ) + return SemanticScalarType(dtype=ScalarType(f"i{bitwidth}")) + + def _lower_literal_expr( + self, + value: object, + ty: SemanticType, + *, + indent: int, + desired_name: str | None = None, + into: list[str] | None = None, + ) -> _RenderedValue: + raw_type = self._signless_integer_scalar_type(ty) or ty + if desired_name is not None and into is not None and raw_type == ty: + into.append( + self._indent(indent) + + f"{desired_name} = arith.constant {self._format_constant(value, ty)} : " + f"{self._render_arith_constant_type(ty)}" + ) + return _RenderedValue(name=desired_name, type=ty) + + if desired_name is not None and into is not None: + raw_name = self._new_temp() + into.append( + self._indent(indent) + + f"{raw_name} = arith.constant {self._format_constant(value, raw_type)} : " + f"{self._render_arith_constant_type(raw_type)}" + ) + into.append( + self._indent(indent) + + f"{desired_name} = builtin.unrealized_conversion_cast {raw_name} : " + f"{self._render_type(raw_type)} to {self._render_type(ty)}" + ) + return _RenderedValue(name=desired_name, type=ty) + + return _RenderedValue( + name=self._materialize_constant(value, ty), + type=ty, + ) + + def _constant_name(self, value: object, ty: SemanticType) -> str: + if isinstance(ty, SemanticIndexType): + stem = f"c{value}" + elif isinstance(ty, SemanticScalarType): + if ty.dtype.name == "i1" and isinstance(value, bool): + stem = "true" if value else "false" + else: + stem = f"c{value}_{ty.dtype.name}" + else: + stem = "cst" + # Keep generated SSA names MLIR-safe for constants whose textual value + # contains punctuation such as decimal points or scientific-notation + # exponents (for example f32 max -> `3.4028235e+38`). + stem = re.sub(r"[^0-9A-Za-z_]", "_", stem) + stem = re.sub(r"_+", "_", stem).strip("_") or "cst" + if stem[0].isdigit(): + stem = f"c_{stem}" + + name = f"%{stem}" + existing = {line.split(" = ", 1)[0].strip() for line in self._constant_lines} + if name not in existing: + return name + suffix = 0 + while f"{name}_{suffix}" in existing: + suffix += 1 + return f"{name}_{suffix}" + + def _format_constant(self, value: object, ty: SemanticType) -> str: + if isinstance(ty, SemanticIndexType): + return str(value) + if isinstance(ty, SemanticScalarType): + if ty.dtype.name in {"f16", "bf16", "f32"} and isinstance( + value, (bool, int, float) + ): + return self._format_float_constant(float(value), ty.dtype.name) + if ty.dtype.name == "i1" and isinstance(value, bool): + return "1" if value else "0" + return str(value) + raise NotImplementedError(f"unsupported constant type {ty!r}") + + def _render_arith_constant_type(self, ty: SemanticType) -> str: + if isinstance(ty, SemanticScalarType) and is_integer_dtype(ty.dtype): + width = integer_bitwidth(ty.dtype) + if width is None: + raise NotImplementedError( + f"unsupported integer dtype {ty.dtype.name!r} for arith.constant emission" + ) + return f"i{width}" + return self._render_type(ty) + + def _format_float_constant(self, value: float, dtype_name: str) -> str: + # Emit stable bit-pattern literals for values that are parse-sensitive + # (`inf`/`nan`) or sign-sensitive (`-0.0`). + if math.isnan(value): + return self._float_nan_bit_pattern(dtype_name) + if math.isinf(value): + sign_bit = value < 0.0 + return self._float_inf_bit_pattern(dtype_name, sign_bit=sign_bit) + if value == 0.0 and math.copysign(1.0, value) < 0.0: + return self._float_to_bit_pattern_literal(value, dtype_name) + return str(value) + + def _float_nan_bit_pattern(self, dtype_name: str) -> str: + if dtype_name == "f16": + return "0x7E00" + if dtype_name == "bf16": + return "0x7FC0" + if dtype_name == "f32": + return "0x7FC00000" + raise NotImplementedError( + f"unsupported float dtype {dtype_name!r} for NaN constant emission" + ) + + def _float_inf_bit_pattern(self, dtype_name: str, *, sign_bit: bool) -> str: + if dtype_name == "f16": + return "0xFC00" if sign_bit else "0x7C00" + if dtype_name == "bf16": + return "0xFF80" if sign_bit else "0x7F80" + if dtype_name == "f32": + return "0xFF800000" if sign_bit else "0x7F800000" + raise NotImplementedError( + f"unsupported float dtype {dtype_name!r} for inf constant emission" + ) + + def _float_to_bit_pattern_literal(self, value: float, dtype_name: str) -> str: + if dtype_name == "f16": + bits = struct.unpack(">H", struct.pack(">e", value))[0] + return f"0x{bits:04X}" + if dtype_name == "bf16": + bits = struct.unpack(">I", struct.pack(">f", value))[0] >> 16 + return f"0x{bits:04X}" + if dtype_name == "f32": + bits = struct.unpack(">I", struct.pack(">f", value))[0] + return f"0x{bits:08X}" + raise NotImplementedError( + f"unsupported float dtype {dtype_name!r} for bit-pattern emission" + ) + + def _render_binary_op(self, op: str, ty: SemanticType) -> str: + if isinstance(ty, SemanticIndexType): + if op == "add": + return "arith.addi" + if op == "sub": + return "arith.subi" + if op == "mul": + return "arith.muli" + if op == "mod": + if isinstance(ty, SemanticIndexType): + return "arith.remui" + if op == "floordiv": + return "arith.divui" + if isinstance(ty, SemanticScalarType): + dtype = ty.dtype + if is_float_dtype(dtype): + if op == "add": + return "arith.addf" + if op == "sub": + return "arith.subf" + if op == "mul": + return "arith.mulf" + if is_integer_dtype(dtype): + if op == "add": + return "arith.addi" + if op == "sub": + return "arith.subi" + if op == "mul": + return "arith.muli" + if op == "mod": + sign = integer_signedness(dtype) + return "arith.remui" if sign == "unsigned" else "arith.remsi" + if op == "floordiv": + sign = integer_signedness(dtype) + return "arith.divui" if sign == "unsigned" else "arith.floordivsi" + if op == "bitand": + return "arith.andi" + if op == "bitor": + return "arith.ori" + if op == "bitxor": + return "arith.xori" + if op == "lshift": + return "arith.shli" + if op == "rshift": + sign = integer_signedness(dtype) + return "arith.shrui" if sign == "unsigned" else "arith.shrsi" + raise NotImplementedError(f"unsupported binary op '{op}' for type {ty!r}") + + def _render_type(self, ty: SemanticType) -> str: + if isinstance(ty, _RenderedTextualType): + return ty.text + if isinstance(ty, SemanticIndexType): + return "index" + if isinstance(ty, SemanticScalarType): + return ty.dtype.name + if isinstance(ty, SemanticPtrType): + return f"!pto.ptr<{ty.element_dtype.name}, {ty.memory_space}>" + if isinstance(ty, SemanticTensorViewType): + return self._render_tensor_view_type( + element_dtype=ty.element_dtype.name, + shape=("?",) * ty.rank, + ) + if isinstance(ty, SemanticPartitionTensorViewType): + return self._render_partition_tensor_view_type( + element_dtype=ty.element_dtype.name, + shape=("?",) * ty.rank, + ) + if isinstance(ty, SemanticTileType): + return self._render_tile_buf_type(ty) + if isinstance(ty, SemanticAlignType): + return "!pto.align" + if isinstance(ty, SemanticMaskType): + return f"!pto.mask<{ty.granularity}>" + if isinstance(ty, SemanticVRegType): + return f"!pto.vreg<{ty.lanes}x{ty.element_dtype.name}>" + if isinstance(ty, SemanticVectorType): + dims = "x".join(str(dim) for dim in ty.shape) + return f"vector<{dims}x{ty.element_dtype.name}>" + raise NotImplementedError(f"unsupported semantic type {ty!r}") + + def _is_memref_like_type(self, ty: SemanticType) -> bool: + return isinstance(ty, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) or ( + isinstance(ty, _RenderedTextualType) and ty.text.startswith("memref<") + ) + + def _render_copy_buffer_type(self, ty: SemanticType) -> str: + if isinstance(ty, SemanticPtrType): + return self._render_type(ty) + if isinstance(ty, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return f"!pto.ptr<{ty.element_dtype.name}, gm>" + if isinstance(ty, SemanticTileType): + memory_space = ty.memory_space or "ub" + return f"!pto.ptr<{ty.element_dtype.name}, {memory_space}>" + return self._render_type(ty) + + def _render_memref_type( + self, + *, + element_dtype: str, + shape: tuple[int | str, ...], + memory_space: str, + ) -> str: + dims = "x".join(str(dim) for dim in shape) + return f"memref<{dims}x{element_dtype}, {self._render_memref_memory_space(memory_space)}>" + + def _render_tensor_view_type( + self, + *, + element_dtype: str, + shape: tuple[int | str, ...], + ) -> str: + dims = "x".join(str(dim) for dim in shape) + return f"!pto.tensor_view<{dims}x{element_dtype}>" + + def _render_partition_tensor_view_type( + self, + *, + element_dtype: str, + shape: tuple[int | str, ...], + ) -> str: + dims = "x".join(str(dim) for dim in shape) + return f"!pto.partition_tensor_view<{dims}x{element_dtype}>" + + def _render_memref_memory_space(self, memory_space: str) -> str: + if memory_space == "gm": + return "#pto.address_space" + if memory_space == "ub": + return "#pto.address_space" + if memory_space in {"mat", "left", "right", "acc", "bias"}: + return f"#pto.address_space<{memory_space}>" + raise NotImplementedError(f"unsupported memref memory space '{memory_space}' in TileLang DSL v1 lowering") + + def _render_tile_buf_type(self, ty: SemanticTileType) -> str: + if ty.shape is None: + raise NotImplementedError("tile_buf lowering requires statically specialized Tile shape") + if ty.rank not in (1, 2): + raise NotImplementedError("tile_buf lowering only supports rank-1 or rank-2 Tile values") + rows = ty.shape[0] + cols = 1 if ty.rank == 1 else ty.shape[1] + valid_shape = ty.valid_shape or ty.shape + v_row = valid_shape[0] + v_col = 1 if ty.rank == 1 else valid_shape[1] + config = ty.config or TileConfig() + return ( + f"!pto.tile_buf" + ) + + def _render_tile_buf_loc(self, memory_space: str) -> str: + if memory_space == "ub": + return "vec" + if memory_space == "gm": + return "gm" + if memory_space in {"mat", "left", "right", "acc", "bias"}: + return memory_space + raise NotImplementedError(f"unsupported tile_buf memory space '{memory_space}'") + + def _render_tile_buf_dim(self, dim: int | None) -> str: + return "?" if dim is None else str(dim) + + def _render_tile_buf_pad_value(self, pad_value: PadValue) -> str: + if pad_value.is_custom: + raise NotImplementedError( + "custom TileConfig.pad_value MLIR type rendering requires PTO tile_buf parser support for custom pad encodings" + ) + return str(pad_value.encoded) + + def _dtype_byte_width(self, dtype: ScalarType) -> int: + try: + return bytewidth(dtype) + except TypeError as exc: + raise NotImplementedError(f"unsupported DMA dtype '{dtype.name}' in TileLang DSL v1 lowering") from exc + + def _indent(self, indent: int) -> str: + return " " * indent + + def _new_temp(self) -> str: + name = f"%tmp_{self._temp_counter}" + self._temp_counter += 1 + return name + + +def lower_semantic_kernel(kernel: SemanticKernel) -> AuthoringModule: + """Lower the semantic model to the current authoring-form VPTO builder.""" + + return AuthoringModule(kernel=kernel) + + +__all__ = ["AuthoringModule", "lower_semantic_kernel"] diff --git a/tilelang-dsl/python/tilelang_dsl/semantic.py b/tilelang-dsl/python/tilelang_dsl/semantic.py new file mode 100644 index 000000000..9c06799b2 --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/semantic.py @@ -0,0 +1,7422 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Semantic model for TileLang DSL descriptor lowering.""" + +from __future__ import annotations + +import ast +import struct +from dataclasses import dataclass +from typing import Any + +from .frontend_ast import ( + FrontendAssignStmt, + FrontendAttributeExpr, + FrontendBinaryExpr, + FrontendCallExpr, + FrontendConstantExpr, + FrontendExprNode, + FrontendExprStmt, + FrontendForStmt, + FrontendIfStmt, + FrontendInlineProcNode, + FrontendKernelNode, + FrontendNameExpr, + FrontendNameTarget, + FrontendNoOpStmt, + FrontendReturnStmt, + FrontendSliceExpr, + FrontendSourceLocation, + FrontendStrictVecscopeStmt, + FrontendStmtNode, + FrontendSubscriptExpr, + FrontendSymbolExpr, + FrontendTargetNode, + FrontendTupleExpr, + FrontendTupleTarget, + FrontendVecscopeStmt, +) +from .support_matrix import ( + DEFERRED_PTO_SURFACES, + advanced_mode_message, + deferred_surface_message, + unsupported_feature_message, +) +from .types import ( + AlignType, + BarrierType, + BLayout, + CmpMode, + DeinterleaveDist, + Event, + InterleaveDist, + MaskType, + MaskPattern, + MemorySpace, + OrderMode, + PadMode, + PadValue, + PredicateDist, + PredicatePart, + Pipe, + PostUpdateMode, + PositionMode, + PointerType, + ScalarType, + SLayout, + TileConfig, + FractalMode, + VcvtPartMode, + VcvtRoundMode, + VcvtSatMode, + VLoadDist, + VRegType, + VStoreDist, + bf16, + bytewidth, + f16, + f32, + i1, + i8, + i16, + i32, + i64, + integer_bitwidth, + integer_signedness, + is_float_dtype, + is_integer_dtype, + si8, + si16, + si32, + si64, + align, + ui8, + ui16, + ui32, + ui64, + VectorType, +) + + +_DTYPE_SYMBOLS = { + "i1": i1, + "i8": i8, + "si8": si8, + "ui8": ui8, + "i16": i16, + "si16": si16, + "ui16": ui16, + "i32": i32, + "si32": si32, + "ui32": ui32, + "i64": i64, + "si64": si64, + "ui64": ui64, + "f16": f16, + "bf16": bf16, + "f32": f32, +} +_MASK_TYPE_SYMBOLS = { + "mask_b8": MaskType("b8"), + "mask_b16": MaskType("b16"), + "mask_b32": MaskType("b32"), +} +_PATTERN_SYMBOLS = {pattern.name: pattern for pattern in MaskPattern} +_PIPE_SYMBOLS = {pipe.name: pipe for pipe in Pipe} +_EVENT_SYMBOLS = {event.name: event for event in Event} +_BARRIER_TYPE_SYMBOLS = {barrier_type.name: barrier_type for barrier_type in BarrierType} +_MEMORY_SPACE_SYMBOLS = {memory_space.name: memory_space for memory_space in MemorySpace} +_PAD_MODE_SYMBOLS = {pad_mode.name: pad_mode for pad_mode in PadMode} +_B_LAYOUT_SYMBOLS = {layout.name: layout for layout in BLayout} +_S_LAYOUT_SYMBOLS = {layout.name: layout for layout in SLayout} +_PAD_VALUE_SYMBOLS = { + pad_value.name: pad_value + for pad_value in (PadValue.NULL, PadValue.ZERO, PadValue.MAX, PadValue.MIN) +} +_PREDICATE_DIST_SYMBOLS = {dist.name: dist for dist in PredicateDist} +_VLOAD_DIST_SYMBOLS = {dist.name: dist for dist in VLoadDist} +_VSTORE_DIST_SYMBOLS = {dist.name: dist for dist in VStoreDist} +_PREDICATE_PART_SYMBOLS = {part.name: part for part in PredicatePart} +_CMP_MODE_SYMBOLS = {mode.name: mode for mode in CmpMode} +_DEINTERLEAVE_DIST_SYMBOLS = dict(DeinterleaveDist.__members__) +_INTERLEAVE_DIST_SYMBOLS = dict(InterleaveDist.__members__) +_POSITION_MODE_SYMBOLS = {position_mode.name: position_mode for position_mode in PositionMode} +_ORDER_MODE_SYMBOLS = {order_mode.name: order_mode for order_mode in OrderMode} +_VCVT_ROUND_MODE_SYMBOLS = {mode.name: mode for mode in VcvtRoundMode} +_VCVT_SAT_MODE_SYMBOLS = {mode.name: mode for mode in VcvtSatMode} +_VCVT_PART_MODE_SYMBOLS = {mode.name: mode for mode in VcvtPartMode} +_POST_UPDATE_MODE_SYMBOLS = {mode.name: mode for mode in PostUpdateMode} +_FRACTAL_MODE_SYMBOLS = {mode.name: mode for mode in FractalMode} +_TILE_CONSTRUCTOR_ALLOWED_KEYWORDS = frozenset( + { + "valid_shape", + "blayout", + "slayout", + "fractal_size", + "pad_value", + "compact_mode", + "addr", + } +) +_VCVT_ATTR_CONTRACTS: dict[tuple[str, str], tuple[bool, bool, bool]] = { + # (src_kind, dst_kind): (requires_rnd, requires_sat, requires_part) + ("f32", "f16"): (True, True, True), + ("f32", "bf16"): (True, True, True), + ("f32", "s16"): (True, True, True), + ("f32", "s64"): (True, True, True), + ("f32", "s32"): (True, True, False), + ("f16", "f32"): (False, False, True), + ("f16", "s32"): (True, False, True), + ("f16", "s16"): (True, True, False), + ("f16", "s8"): (True, True, True), + ("f16", "u8"): (True, True, True), + ("bf16", "f16"): (True, True, False), + ("bf16", "f32"): (False, False, True), + ("bf16", "s32"): (True, True, True), + ("u8", "f16"): (False, False, True), + ("u8", "u16"): (False, False, True), + ("u8", "u32"): (False, False, True), + ("s8", "f16"): (False, False, True), + ("s8", "s16"): (False, False, True), + ("s8", "s32"): (False, False, True), + ("u16", "u8"): (False, True, True), + ("u16", "u32"): (False, False, True), + ("s16", "f16"): (True, False, False), + ("s16", "f32"): (False, False, True), + ("s16", "u32"): (False, False, True), + ("s16", "s32"): (False, False, True), + ("s16", "u8"): (False, True, True), + ("u32", "u8"): (False, True, True), + ("u32", "u16"): (False, True, True), + ("u32", "s16"): (False, True, True), + ("s32", "f32"): (True, False, False), + ("s32", "u8"): (False, True, True), + ("s32", "u16"): (False, True, True), + ("s32", "s16"): (False, True, True), + ("s32", "s64"): (False, False, True), + ("s64", "f32"): (True, False, True), + ("s64", "s32"): (False, True, True), +} + + +def _classify_vcvt_elem_kind(dtype: ScalarType) -> str | None: + if dtype == f16: + return "f16" + if dtype == bf16: + return "bf16" + if dtype == f32: + return "f32" + if not is_integer_dtype(dtype): + return None + width = integer_bitwidth(dtype) + sign = integer_signedness(dtype) + is_unsigned = sign == "unsigned" + if width == 8: + return "u8" if is_unsigned else "s8" + if width == 16: + return "u16" if is_unsigned else "s16" + if width == 32: + return "u32" if is_unsigned else "s32" + if width == 64: + return None if is_unsigned else "s64" + return None +_UNARY_VECTOR_OPS = { + "vabs", + "vrelu", + "vexp", + "vln", + "vsqrt", + "vrec", + "vnot", + "vcadd", + "vcmax", + "vbcnt", + "vneg", + "vcls", + "vcmin", + "vrsqrt", + "vmov", + "vsunpack", + "vzunpack", + "vusqz", + "vsqz", + "vtrc", + "vcgadd", + "vcgmax", + "vcgmin", + "vcpadd", + "vsort32", +} +_BINARY_VECTOR_OPS = { + "vadd", + "vsub", + "vmul", + "vdiv", + "vmod", + "vmax", + "vmin", + "vand", + "vor", + "vxor", + "vaddrelu", + "vaddreluconv", + "vsubrelu", + "vmulconv", + "vshl", + "vshr", + "vprelu", + "vpack", + "vperm", + "vmrgsort", +} +_CUBE_MATMUL_OPS = { + "mad", + "mad_acc", + "mad_bias", + "mad_mx", + "mad_mx_acc", + "mad_mx_bias", +} +_CUBE_TRANSFER_OPS = { + "cube_load", + "cube_store", + "cube_load_frac", + "bias_load", + "left_load", + "right_load", + "left_load_mx", + "right_load_mx", + "acc_store", + "acc_store_gm", + "acc_store_ub", +} +_CUBE_CALL_OPS = _CUBE_MATMUL_OPS | _CUBE_TRANSFER_OPS +_VECTOR_SCALAR_OPS = { + "vadds", + "vsubs", + "vmuls", + "vdivs", + "vmaxs", + "vmins", + "vlrelu", + "vshls", + "vshrs", + "vands", + "vors", + "vxors", +} +_VECTOR_IMMEDIATE_OPS = {"vshift", "vslide"} +_TERNARY_VECTOR_OPS = {"vaxpy", "vmula"} +_MULTI_RESULT_VECTOR_OPS = {"vmull", "vldsx2", "vldus", "pstu"} +_BROADCAST_VECTOR_OPS = {"vbr", "vdup", "vci"} +_VEXPDIF_OP_ALIASES = {"vexpdif", "vexpdiff"} +_LOW_LEVEL_DMA_UNARY_CONFIG_OPS = {"set_mov_pad_val"} +_LOW_LEVEL_DMA_CONFIG_OPS = { + "set_loop2_stride_outtoub", + "set_loop1_stride_outtoub", + "set_loop_size_outtoub", + "set_loop2_stride_ubtoout", + "set_loop1_stride_ubtoout", + "set_loop_size_ubtoout", +} +_LOW_LEVEL_DMA_COPY_OPS = { + "copy_gm_to_ubuf", + "copy_ubuf_to_gm", + "copy_ubuf_to_ubuf", +} + + +def _is_supported_mov_pad_scalar_dtype(dtype: ScalarType) -> bool: + if is_integer_dtype(dtype): + return integer_bitwidth(dtype) in {8, 16, 32} + return dtype.name in {"f16", "bf16", "f32"} + + +_UB_HELPER_OPS = {"vbitsort", "vmrgsort4"} +_TENSORVIEW_RANK = 5 + + +class SemanticType: + """Base class for semantic value types.""" + + +@dataclass(frozen=True) +class SemanticTensorViewType(SemanticType): + element_dtype: ScalarType + rank: int = _TENSORVIEW_RANK + + +@dataclass(frozen=True) +class SemanticPartitionTensorViewType(SemanticType): + element_dtype: ScalarType + rank: int = _TENSORVIEW_RANK + + +@dataclass(frozen=True) +class SemanticTensorSliceType(SemanticType): + element_dtype: ScalarType + rank: int + extents: tuple[int | None, ...] + physical_axes: tuple[int, ...] + + +@dataclass(frozen=True) +class SemanticTileType(SemanticType): + element_dtype: ScalarType + rank: int + shape: tuple[int, ...] | None + valid_shape: tuple[int | None, ...] | None + memory_space: str | None + config: TileConfig | None + + +@dataclass(frozen=True) +class SemanticTileConfigType(SemanticType): + element_dtype: ScalarType | None = None + + +@dataclass(frozen=True) +class SemanticScalarType(SemanticType): + dtype: ScalarType + + +@dataclass(frozen=True) +class SemanticPtrType(SemanticType): + element_dtype: ScalarType + memory_space: str + + +@dataclass(frozen=True) +class SemanticIndexType(SemanticType): + pass + + +@dataclass(frozen=True) +class SemanticShapeType(SemanticType): + rank: int + + +@dataclass(frozen=True) +class SemanticSliceType(SemanticType): + pass + + +@dataclass(frozen=True) +class SemanticTupleType(SemanticType): + elements: tuple[SemanticType, ...] + + +@dataclass(frozen=True) +class SemanticMetaType(SemanticType): + kind: str + + +@dataclass(frozen=True) +class SemanticPadValueType(SemanticType): + element_dtype: ScalarType | None = None + + +@dataclass(frozen=True) +class SemanticAlignType(SemanticType): + pass + + +@dataclass(frozen=True) +class SemanticMaskType(SemanticType): + granularity: str + + +@dataclass(frozen=True) +class SemanticVRegType(SemanticType): + element_dtype: ScalarType + lanes: int + + +@dataclass(frozen=True) +class SemanticVectorType(SemanticType): + element_dtype: ScalarType + shape: tuple[int, ...] + + +_I32_TYPE = SemanticScalarType(dtype=i32) + + +@dataclass(frozen=True) +class SemanticBinding: + name: str + ssa_name: str + type: SemanticType + origin: str + value: Any | None = None + + +@dataclass(frozen=True) +class SemanticTileBinding: + name: str + shape: tuple[int, ...] + valid_shape: tuple[int | None, ...] | None + memory_space: str + config: Any + + +class SemanticExpr: + """Base class for typed semantic expressions.""" + + +@dataclass(frozen=True) +class SemanticBindingRef(SemanticExpr): + binding: SemanticBinding + type: SemanticType + + +@dataclass(frozen=True) +class SemanticLiteralExpr(SemanticExpr): + value: Any + type: SemanticType + + +@dataclass(frozen=True) +class SemanticSymbolExpr(SemanticExpr): + namespace: str + name: str + value: Any + type: SemanticType + + +@dataclass(frozen=True) +class SemanticSliceExpr(SemanticExpr): + start: SemanticExpr | None + stop: SemanticExpr | None + step: SemanticExpr | None + type: SemanticSliceType + + +@dataclass(frozen=True) +class SemanticTensorSliceAxis: + start: SemanticExpr + stop: SemanticExpr + step: SemanticExpr + extent: int | None + + +@dataclass(frozen=True) +class SemanticTupleExpr(SemanticExpr): + elements: tuple[SemanticExpr, ...] + type: SemanticTupleType + + +@dataclass(frozen=True) +class SemanticAttributeAccess(SemanticExpr): + base: SemanticExpr + attr: str + type: SemanticType + + +@dataclass(frozen=True) +class SemanticSubscriptAccess(SemanticExpr): + base: SemanticExpr + index: SemanticExpr + type: SemanticType + + +@dataclass(frozen=True) +class SemanticTensorSliceExpr(SemanticExpr): + base: SemanticExpr + slices: tuple[SemanticTensorSliceAxis, ...] + type: SemanticTensorSliceType + + +@dataclass(frozen=True) +class SemanticBinaryExpr(SemanticExpr): + lhs: SemanticExpr + op: str + rhs: SemanticExpr + type: SemanticType + + +@dataclass(frozen=True) +class SemanticIndexCastExpr(SemanticExpr): + value: SemanticExpr + type: SemanticIndexType + + +@dataclass(frozen=True) +class SemanticCallExpr(SemanticExpr): + namespace: str | None + name: str + args: tuple[SemanticExpr, ...] + type: SemanticType | None + + +class SemanticStmt: + """Base class for semantic statements.""" + + +@dataclass(frozen=True) +class SemanticAssignStmt(SemanticStmt): + targets: tuple[SemanticBinding, ...] + value: SemanticExpr + annotation: Any | None = None + + +@dataclass(frozen=True) +class SemanticExprStmt(SemanticStmt): + expr: SemanticExpr + + +@dataclass(frozen=True) +class SemanticDmaOptions: + pad_mode: SemanticExpr | None = None + pad_value: SemanticExpr | None = None + left_padding: SemanticExpr | None = None + right_padding: SemanticExpr | None = None + init_out_buffer: SemanticExpr | None = None + + +@dataclass(frozen=True) +class SemanticDmaLoadStmt(SemanticStmt): + src: SemanticTensorSliceExpr + dst: SemanticExpr + options: SemanticDmaOptions = SemanticDmaOptions() + + +@dataclass(frozen=True) +class SemanticDmaStoreStmt(SemanticStmt): + src: SemanticExpr + dst: SemanticTensorSliceExpr + options: SemanticDmaOptions = SemanticDmaOptions() + + +@dataclass(frozen=True) +class SemanticVectorStoreStmt(SemanticStmt): + value: SemanticExpr + destination: SemanticExpr + indices: tuple[SemanticExpr, ...] + dist: SemanticExpr | None + mask: SemanticExpr + + +@dataclass(frozen=True) +class SemanticVectorPairStoreStmt(SemanticStmt): + low: SemanticExpr + high: SemanticExpr + destination: SemanticExpr + indices: tuple[SemanticExpr, ...] + dist: SemanticExpr + mask: SemanticExpr + + +@dataclass(frozen=True) +class SemanticVScatterStmt(SemanticStmt): + value: SemanticExpr + destination: SemanticExpr + offsets: SemanticExpr + mask: SemanticExpr + + +@dataclass(frozen=True) +class SemanticPredicateStoreStmt(SemanticStmt): + op_name: str + value: SemanticExpr + destination: SemanticExpr + indices: tuple[SemanticExpr, ...] + dist: SemanticExpr + + +@dataclass(frozen=True) +class SemanticAlignStoreStmt(SemanticStmt): + op_name: str + value: SemanticExpr + destination: SemanticExpr + indices: tuple[SemanticExpr, ...] = () + offset: SemanticExpr | None = None + + +@dataclass(frozen=True) +class SemanticScalarStoreStmt(SemanticStmt): + value: SemanticExpr + destination: SemanticExpr + offset: SemanticExpr + + +@dataclass(frozen=True) +class SemanticVecscopeStmt(SemanticStmt): + body: tuple[SemanticStmt, ...] + + +@dataclass(frozen=True) +class SemanticSetFlagStmt(SemanticStmt): + src_pipe: str + dst_pipe: str + event: str + + +@dataclass(frozen=True) +class SemanticWaitFlagStmt(SemanticStmt): + src_pipe: str + dst_pipe: str + event: str + + +@dataclass(frozen=True) +class SemanticPipeBarrierStmt(SemanticStmt): + pipe: str + + +@dataclass(frozen=True) +class SemanticGetBufStmt(SemanticStmt): + pipe: str + buf_id: SemanticExpr + mode: SemanticExpr + + +@dataclass(frozen=True) +class SemanticRlsBufStmt(SemanticStmt): + pipe: str + buf_id: SemanticExpr + mode: SemanticExpr + + +@dataclass(frozen=True) +class SemanticMemBarStmt(SemanticStmt): + barrier_type: str + + +@dataclass(frozen=True) +class SemanticSetCrossCoreStmt(SemanticStmt): + core_id: SemanticExpr + event_id: SemanticExpr + + +@dataclass(frozen=True) +class SemanticSetIntraBlockStmt(SemanticStmt): + block_id: SemanticExpr + event_id: SemanticExpr + + +@dataclass(frozen=True) +class SemanticSetIntraCoreStmt(SemanticStmt): + config: SemanticExpr + + +@dataclass(frozen=True) +class SemanticWaitFlagDevStmt(SemanticStmt): + core_id: SemanticExpr + event_id: SemanticExpr + + +@dataclass(frozen=True) +class SemanticWaitIntraCoreStmt(SemanticStmt): + block_id: SemanticExpr + event_id: SemanticExpr + + +@dataclass(frozen=True) +class SemanticDmaConfigStmt(SemanticStmt): + name: str + first: SemanticExpr + second: SemanticExpr + + +@dataclass(frozen=True) +class SemanticDmaUnaryConfigStmt(SemanticStmt): + name: str + value: SemanticExpr + + +@dataclass(frozen=True) +class SemanticLowLevelCopyStmt(SemanticStmt): + name: str + source: SemanticExpr + destination: SemanticExpr + operands: tuple[SemanticExpr, ...] + + +@dataclass(frozen=True) +class SemanticIfResult: + result_binding: SemanticBinding + then_binding: SemanticBinding + else_binding: SemanticBinding + + +@dataclass(frozen=True) +class SemanticIfStmt(SemanticStmt): + condition: SemanticExpr + then_body: tuple[SemanticStmt, ...] + else_body: tuple[SemanticStmt, ...] + results: tuple[SemanticIfResult, ...] + + +@dataclass(frozen=True) +class SemanticReturnStmt(SemanticStmt): + value: SemanticExpr | None + + +@dataclass(frozen=True) +class SemanticForStmt(SemanticStmt): + induction_variable: SemanticBinding + lower_bound: SemanticExpr + upper_bound: SemanticExpr + step: SemanticExpr + body: tuple[SemanticStmt, ...] + loop_carried: tuple[SemanticBinding, ...] + + +@dataclass(frozen=True) +class SemanticStrictVecscopeStmt(SemanticStmt): + captures: tuple[SemanticExpr, ...] + block_arguments: tuple[SemanticBinding, ...] + body: tuple[SemanticStmt, ...] + + +@dataclass(frozen=True) +class SemanticParameter: + binding: SemanticBinding + + @property + def name(self) -> str: + return self.binding.name + + @property + def kind(self) -> str: + return self.binding.origin + + @property + def type(self) -> SemanticType: + return self.binding.type + + @property + def ssa_name(self) -> str: + return self.binding.ssa_name + + +@dataclass(frozen=True) +class SemanticKernel: + target: str + op: str + symbol_name: str + kernel_family: str + verify_enabled: bool + advanced_enabled: bool + dtype_signature: tuple[Any, ...] | None + parameters: tuple[SemanticParameter, ...] + tile_bindings: tuple[SemanticTileBinding, ...] + body: tuple[SemanticStmt, ...] + inline_helpers: tuple["SemanticKernel", ...] = () + + +class _SemanticAnalyzer: + def __init__(self, node: FrontendKernelNode): + self.node = node + self._context_attrs = dict(node.context_attrs) + self._counter = 0 + self._tile_specializations = { + spec.name: spec for spec in node.tile_specializations + } + self._hidden_parameters: list[SemanticParameter] = [] + self._inline_proc_nodes: dict[str, FrontendInlineProcNode] = { + inline_proc.name: inline_proc for inline_proc in node.inline_procs + } + self._internal_inline_proc_nodes: dict[str, FrontendInlineProcNode] = { + inline_proc.name: inline_proc for inline_proc in node.internal_inline_procs + } + self._inline_proc_specializations: dict[ + tuple[str, tuple[tuple[SemanticType, object], ...]], SemanticKernel + ] = {} + self._inline_proc_return_types: dict[ + tuple[str, tuple[tuple[SemanticType, object], ...]], SemanticType | None + ] = {} + self._inline_proc_order: list[tuple[str, tuple[tuple[SemanticType, object], ...]]] = [] + self._inline_proc_active_stack: list[tuple[str, tuple[tuple[SemanticType, object], ...]]] = [] + + def _expr_source_location( + self, + expr: FrontendExprNode | SemanticExpr, + ) -> FrontendSourceLocation | None: + return getattr(expr, "source_location", None) + + def _attach_expr_source_location( + self, + semantic_expr: SemanticExpr, + frontend_expr: FrontendExprNode, + ) -> SemanticExpr: + source_location = self._expr_source_location(frontend_expr) + if source_location is not None: + object.__setattr__(semantic_expr, "source_location", source_location) + return semantic_expr + + def _format_source_message( + self, + message: str, + expr: FrontendExprNode | SemanticExpr | None = None, + ) -> str: + if expr is None: + return message + source_location = self._expr_source_location(expr) + if source_location is None: + return message + return ( + f"{source_location.path}:{source_location.line}:{source_location.column}: " + f"{message}" + ) + + def _raise_expr_type_error( + self, + message: str, + expr: FrontendExprNode | SemanticExpr | None = None, + ) -> None: + raise TypeError(self._format_source_message(message, expr)) + + def analyze(self) -> SemanticKernel: + env: dict[str, SemanticBinding] = {} + parameters = [] + for index, param in enumerate(self.node.parameters): + binding = SemanticBinding( + name=param.name, + ssa_name=f"%arg{index}", + type=self._parameter_type(param), + origin=param.kind, + ) + env[param.name] = binding + parameters.append(SemanticParameter(binding=binding)) + body, _ = self._analyze_kernel_body(env) + parameters.extend(self._hidden_parameters) + tile_bindings = tuple( + SemanticTileBinding( + name=spec.name, + shape=spec.shape, + valid_shape=spec.valid_shape, + memory_space=spec.memory_space, + config=spec.config, + ) + for spec in self.node.tile_specializations + ) + return SemanticKernel( + target=self.node.target, + op=self.node.op, + symbol_name=self.node.name, + kernel_family=self.node.kernel_family, + verify_enabled=self.node.verify_enabled, + advanced_enabled=self.node.advanced_enabled, + dtype_signature=self.node.dtype_signature, + parameters=tuple(parameters), + tile_bindings=tile_bindings, + body=body, + inline_helpers=tuple( + self._inline_proc_specializations[key] + for key in self._inline_proc_order + ), + ) + + def _analyze_kernel_body( + self, + env: dict[str, SemanticBinding], + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + return self._analyze_block( + self.node.body, + env, + allow_outer_lookup=True, + ) + + def _parameter_type(self, param: Any) -> SemanticType: + if param.kind == "tensorview": + return SemanticTensorViewType( + element_dtype=param.dtype, + rank=_TENSORVIEW_RANK, + ) + if param.kind == "partition_tensor_view": + return SemanticPartitionTensorViewType( + element_dtype=param.dtype, + rank=_TENSORVIEW_RANK, + ) + if param.kind == "tile": + spec = self._tile_specializations.get(param.name) + rank = 2 if spec is None else len(spec.shape) + shape = None if spec is None else spec.shape + valid_shape = None if spec is None else ( + spec.shape if spec.valid_shape is None else spec.valid_shape + ) + memory_space = None if spec is None else spec.memory_space + return SemanticTileType( + element_dtype=param.dtype, + rank=rank, + shape=shape, + valid_shape=valid_shape, + memory_space=memory_space, + config=None if spec is None else (spec.config or TileConfig()), + ) + if param.kind == "vector": + vector_type = param.annotation + return SemanticVectorType( + element_dtype=param.dtype, + shape=vector_type.shape, + ) + if param.kind == "ptr": + memory_space = param.annotation.memory_space.value + return SemanticPtrType( + element_dtype=param.dtype, + memory_space=memory_space, + ) + if param.kind == "mask": + return SemanticMaskType(granularity=param.dtype.granularity) + if param.kind == "scalar": + return SemanticScalarType(dtype=param.dtype) + raise ValueError(f"unsupported parameter kind {param.kind!r}") + + def _new_ssa_name(self, stem: str) -> str: + name = f"%{stem}_{self._counter}" + self._counter += 1 + return name + + def _tensor_shape_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__shape_{tensor_name}_{axis}" + + def _tensor_stride_binding_name(self, tensor_name: str, axis: int) -> str: + return f"__stride_{tensor_name}_{axis}" + + def _tile_valid_shape_binding_name(self, tile_name: str, axis: int) -> str: + return f"__valid_shape_{tile_name}_{axis}" + + def _ensure_hidden_parameter( + self, + hidden_name: str, + origin: str, + ) -> SemanticBinding: + for parameter in self._hidden_parameters: + if parameter.name == hidden_name: + return parameter.binding + binding = SemanticBinding( + name=hidden_name, + ssa_name=f"%arg{len(self.node.parameters) + len(self._hidden_parameters)}", + type=SemanticIndexType(), + origin=origin, + ) + self._hidden_parameters.append(SemanticParameter(binding=binding)) + return binding + + def _ensure_tensor_shape_parameter( + self, + tensor_binding: SemanticBinding, + axis: int, + ) -> SemanticBinding: + hidden_name = self._tensor_shape_binding_name(tensor_binding.name, axis) + return self._ensure_hidden_parameter(hidden_name, "tensorview_shape") + + def _ensure_tensor_stride_parameter( + self, + tensor_binding: SemanticBinding, + axis: int, + ) -> SemanticBinding: + hidden_name = self._tensor_stride_binding_name(tensor_binding.name, axis) + return self._ensure_hidden_parameter(hidden_name, "tensorview_stride") + + def _ensure_tile_valid_shape_parameter( + self, + tile_binding: SemanticBinding, + axis: int, + ) -> SemanticBinding: + hidden_name = self._tile_valid_shape_binding_name(tile_binding.name, axis) + return self._ensure_hidden_parameter(hidden_name, "tile_valid_shape") + + def _make_binding( + self, + name: str, + ty: SemanticType, + origin: str, + *, + value: Any | None = None, + ) -> SemanticBinding: + stem = name if name.isidentifier() else "v" + return SemanticBinding( + name=name, + ssa_name=self._new_ssa_name(stem), + type=ty, + origin=origin, + value=value, + ) + + def _analyze_block( + self, + statements: tuple[FrontendStmtNode, ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + current_env = dict(env) + semantic_statements = [] + for stmt in statements: + emitted_stmts, current_env = self._analyze_stmt_or_inline( + stmt, + current_env, + allow_outer_lookup=allow_outer_lookup, + ) + semantic_statements.extend(emitted_stmts) + return tuple(semantic_statements), current_env + + def _analyze_stmt_or_inline( + self, + stmt: FrontendStmtNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + if isinstance(stmt, FrontendNoOpStmt): + # Python `pass` lowers to a frontend no-op and does not materialize semantic IR. + return tuple(), dict(env) + if ( + isinstance(stmt, FrontendExprStmt) + and isinstance(stmt.expr, FrontendConstantExpr) + and isinstance(stmt.expr.value, str) + ): + # Treat Python docstring-style string expression statements as no-op. + return tuple(), dict(env) + if isinstance(stmt, FrontendIfStmt) and stmt.is_constexpr: + return self._analyze_constexpr_if( + stmt, + env, + allow_outer_lookup=allow_outer_lookup, + ) + semantic_stmt, updated_env = self._analyze_stmt( + stmt, + env, + allow_outer_lookup=allow_outer_lookup, + ) + return (semantic_stmt,), updated_env + + def _analyze_stmt( + self, + stmt: FrontendStmtNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + if isinstance(stmt, FrontendAssignStmt): + value = self._analyze_expr(stmt.value, env, allow_outer_lookup=allow_outer_lookup) + updated_env = dict(env) + targets = self._bind_assignment_target( + stmt.target, + value, + updated_env, + stmt.annotation, + ) + return ( + SemanticAssignStmt(targets=targets, value=value, annotation=stmt.annotation), + updated_env, + ) + if isinstance(stmt, FrontendExprStmt): + if self._is_dma_call(stmt.expr): + return self._analyze_dma_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_sync_call(stmt.expr): + return self._analyze_sync_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_low_level_dma_call(stmt.expr): + return self._analyze_low_level_dma_stmt( + stmt.expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + if self._is_vector_store_call(stmt.expr): + return self._analyze_vector_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + if self._is_scalar_store_call(stmt.expr): + return self._analyze_scalar_store_stmt(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + expr = self._analyze_expr(stmt.expr, env, allow_outer_lookup=allow_outer_lookup) + return SemanticExprStmt(expr=expr), dict(env) + if isinstance(stmt, FrontendReturnStmt): + value = None + if stmt.value is not None: + value = self._analyze_expr(stmt.value, env, allow_outer_lookup=allow_outer_lookup) + return SemanticReturnStmt(value=value), dict(env) + if isinstance(stmt, FrontendForStmt): + return self._analyze_for(stmt, env, allow_outer_lookup=allow_outer_lookup) + if isinstance(stmt, FrontendIfStmt): + return self._analyze_if(stmt, env, allow_outer_lookup=allow_outer_lookup) + if isinstance(stmt, FrontendVecscopeStmt): + return self._analyze_explicit_vecscope(stmt, env, allow_outer_lookup=allow_outer_lookup) + if isinstance(stmt, FrontendStrictVecscopeStmt): + return self._analyze_strict_vecscope(stmt, env) + raise ValueError(f"unsupported frontend statement {type(stmt).__name__}") + + def _inline_proc_specialization_key( + self, + name: str, + args: tuple[SemanticExpr, ...], + *, + internal: bool = False, + ) -> tuple[str, tuple[tuple[SemanticType, object], ...]]: + specialization_name = f"__internal__::{name}" if internal else name + return ( + specialization_name, + tuple( + (arg.type, self._inline_proc_static_specialization_token(arg)) + for arg in args + ), + ) + + def _inline_proc_static_specialization_token( + self, + expr: SemanticExpr, + ) -> object: + if isinstance(expr, SemanticLiteralExpr) and expr.value is None: + return ("none",) + + if isinstance(expr.type, SemanticMetaType) and expr.type.kind in { + "dtype", + "ptr_type", + "mask_type", + }: + value = self._try_static_value(expr) + if value is not None: + return ("meta", expr.type.kind, value) + + value = self._try_static_value(expr) + if isinstance(value, bool): + return ("bool", value) + if isinstance(value, int) and not isinstance(value, bool): + return ("int", value) + if value is None: + return ("dynamic",) + return ("dynamic",) + + def _inline_proc_bound_static_value( + self, + expr: SemanticExpr, + ) -> Any | None: + token = self._inline_proc_static_specialization_token(expr) + kind = token[0] + if kind == "meta": + return token[2] + if kind in {"bool", "int"}: + return token[1] + if kind == "none": + return None + return None + + def _inline_proc_symbol_name( + self, + name: str, + index: int, + ) -> str: + sanitized = "".join(char if char.isalnum() else "_" for char in name) + return f"__tl_inline_{sanitized}_{index}" + + def _collect_inline_helper_tile_bindings( + self, + parameters: tuple[SemanticParameter, ...], + ) -> tuple[SemanticTileBinding, ...]: + tile_bindings: list[SemanticTileBinding] = [] + for parameter in parameters: + if not isinstance(parameter.type, SemanticTileType): + continue + if parameter.type.shape is None: + continue + tile_bindings.append( + SemanticTileBinding( + name=parameter.name, + shape=parameter.type.shape, + valid_shape=parameter.type.valid_shape, + memory_space=parameter.type.memory_space or "ub", + config=parameter.type.config or TileConfig(), + ) + ) + return tuple(tile_bindings) + + def _materialize_inline_proc_specialization( + self, + name: str, + args: tuple[SemanticExpr, ...], + *, + internal: bool = False, + ) -> SemanticKernel: + inline_proc_nodes = ( + self._internal_inline_proc_nodes if internal else self._inline_proc_nodes + ) + inline_proc_node = inline_proc_nodes.get(name) + if inline_proc_node is None: + raise TypeError(f"inline_proc `{name}` is not registered in the current TileLang module") + + key = self._inline_proc_specialization_key(name, args, internal=internal) + existing = self._inline_proc_specializations.get(key) + if existing is not None: + return existing + if key in self._inline_proc_active_stack: + raise TypeError( + f"recursive inline_proc call `{name}` is not supported in TileLang DSL v1" + ) + + if len(inline_proc_node.parameters) != len(args): + raise TypeError( + f"inline_proc `{name}` expects {len(inline_proc_node.parameters)} arguments in TileLang DSL v1" + ) + + helper_env: dict[str, SemanticBinding] = {} + helper_parameters: list[SemanticParameter] = [] + for index, (param, arg_expr) in enumerate(zip(inline_proc_node.parameters, args)): + binding = SemanticBinding( + name=param.name, + ssa_name=f"%arg{index}", + type=arg_expr.type, + origin="inline_param", + value=self._inline_proc_bound_static_value(arg_expr), + ) + helper_env[param.name] = binding + helper_parameters.append(SemanticParameter(binding=binding)) + + saved_hidden_parameters = self._hidden_parameters + self._hidden_parameters = [] + self._inline_proc_active_stack.append(key) + try: + body, _ = self._analyze_block( + inline_proc_node.body, + helper_env, + allow_outer_lookup=False, + ) + finally: + self._inline_proc_active_stack.pop() + helper_hidden_parameters = tuple(self._hidden_parameters) + self._hidden_parameters = saved_hidden_parameters + + if helper_hidden_parameters: + raise TypeError( + f"inline_proc `{name}` currently does not support dynamic shape metadata captures in TileLang DSL v1" + ) + + return_type: SemanticType | None = None + if body and isinstance(body[-1], SemanticReturnStmt): + return_type = None if body[-1].value is None else body[-1].value.type + + helper_index = len(self._inline_proc_order) + helper_kernel = SemanticKernel( + target=self.node.target, + op=self.node.op, + symbol_name=self._inline_proc_symbol_name(name, helper_index), + kernel_family=self.node.kernel_family, + verify_enabled=False, + advanced_enabled=self.node.advanced_enabled, + dtype_signature=self.node.dtype_signature, + parameters=tuple(helper_parameters), + tile_bindings=self._collect_inline_helper_tile_bindings(tuple(helper_parameters)), + body=body, + inline_helpers=(), + ) + self._inline_proc_specializations[key] = helper_kernel + self._inline_proc_return_types[key] = return_type + self._inline_proc_order.append(key) + return helper_kernel + + def _analyze_inline_proc_call_expr( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + helper_kernel = self._materialize_inline_proc_specialization(name, args) + key = self._inline_proc_specialization_key(name, args) + return SemanticCallExpr( + namespace=None, + name=helper_kernel.symbol_name, + args=args, + type=self._inline_proc_return_types.get(key), + ) + + def _analyze_internal_inline_proc_call_expr( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + helper_kernel = self._materialize_inline_proc_specialization( + name, + args, + internal=True, + ) + key = self._inline_proc_specialization_key(name, args, internal=True) + return SemanticCallExpr( + namespace=None, + name=helper_kernel.symbol_name, + args=args, + type=self._inline_proc_return_types.get(key), + ) + + def _is_internal_inline_proc_context(self) -> bool: + return any(key[0].startswith("__internal__::") for key in self._inline_proc_active_stack) + + def _analyze_explicit_vecscope( + self, + stmt: FrontendVecscopeStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + body, updated_env = self._analyze_block( + stmt.body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + return SemanticVecscopeStmt(body=body), updated_env + + def _is_dma_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in {"dma_load", "dma_store"} + ) + + def _is_vector_store_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in {"psts", "pst", "psti", "vsst", "vsta", "vstas", "vstar", "vscatter", "vsts", "vstsx2"} + ) + + def _is_scalar_store_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name == "store_scalar" + ) + + def _is_sync_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in { + "set_flag", + "wait_flag", + "pipe_barrier", + "barrier", + "get_buf", + "rls_buf", + "mem_bar", + "set_cross_core", + "set_intra_block", + "set_intra_core", + "wait_flag_dev", + "wait_intra_core", + } + ) + + def _is_ub_helper_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in _UB_HELPER_OPS + ) + + def _is_low_level_dma_call(self, expr: FrontendExprNode) -> bool: + return ( + isinstance(expr, FrontendCallExpr) + and expr.namespace == "pto" + and expr.name in _LOW_LEVEL_DMA_UNARY_CONFIG_OPS | _LOW_LEVEL_DMA_CONFIG_OPS | _LOW_LEVEL_DMA_COPY_OPS + ) + + def _analyze_dma_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if expr.name == "dma_load": + if len(args) != 2: + raise TypeError("pto.dma_load expects exactly 2 positional arguments in TileLang DSL v1") + src = self._require_tensor_slice(args[0], "pto.dma_load source") + dst = self._require_tile_expr(args[1], "pto.dma_load destination") + options = self._analyze_dma_options( + expr.keywords, + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.dma_load", + ) + self._validate_dma_load_profile(src, dst, options) + return SemanticDmaLoadStmt(src=src, dst=dst, options=options), dict(env) + if expr.name == "dma_store": + if len(args) != 2: + raise TypeError("pto.dma_store expects exactly 2 positional arguments in TileLang DSL v1") + src = self._require_tile_expr(args[0], "pto.dma_store source") + dst = self._require_tensor_slice(args[1], "pto.dma_store destination") + options = self._analyze_dma_options( + expr.keywords, + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.dma_store", + ) + self._validate_dma_store_profile(src, dst, options) + return SemanticDmaStoreStmt(src=src, dst=dst, options=options), dict(env) + raise ValueError(f"unsupported DMA stmt pto.{expr.name}") + + def _analyze_dma_options( + self, + keywords: tuple[tuple[str, FrontendExprNode], ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> SemanticDmaOptions: + analyzed: dict[str, SemanticExpr] = {} + for name, keyword_expr in keywords: + analyzed[name] = self._analyze_expr( + keyword_expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + + pad_mode = analyzed.get("pad_mode") + if pad_mode is not None: + self._pad_mode_value(pad_mode, default=PadMode.PadNull) + + left_padding = analyzed.get("left_padding") + if left_padding is not None: + left_padding = self._require_index_typed_expr(left_padding) + + right_padding = analyzed.get("right_padding") + if right_padding is not None: + right_padding = self._require_index_typed_expr(right_padding) + + init_out_buffer = analyzed.get("init_out_buffer") + if init_out_buffer is not None: + self._require_i1_expr(init_out_buffer, f"{context} init_out_buffer") + + return SemanticDmaOptions( + pad_mode=pad_mode, + pad_value=analyzed.get("pad_value"), + left_padding=left_padding, + right_padding=right_padding, + init_out_buffer=init_out_buffer, + ) + + def _analyze_vector_store_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + if expr.name in {"psts", "pst", "psti"}: + canonical_name = "psts" if expr.name == "pst" else expr.name + if len(expr.args) in {2, 3} and isinstance(expr.args[1], FrontendSubscriptExpr): + raise TypeError( + f"pto.{expr.name} does not support Tile element-indexing syntax in TileLang DSL v1; " + f"use explicit pointer form `pto.{expr.name}(mask, buf, offset[, dist])`" + ) + + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + dist_expr: SemanticExpr | None = None + if len(args) == 3: + value, destination, offset = args + indices = (offset,) + elif len(args) == 4: + value, destination, offset, dist_expr = args + indices = (offset,) + else: + raise TypeError( + f"pto.{expr.name} expects 3 or 4 positional arguments in TileLang DSL v1: " + f"`pto.{expr.name}(mask, buf, offset[, dist])`" + ) + self._require_mask_expr(value, f"pto.{expr.name} value") + self._require_vector_pointer_expr(destination, f"pto.{expr.name} destination") + normalized_indices = [] + for index in indices: + if expr.name == "psti": + self._require_i32_like_expr(index, "pto.psti offset") + else: + index = self._require_index_typed_expr(index) + normalized_indices.append(index) + indices = tuple(normalized_indices) + dist = self._normalize_predicate_store_dist(dist_expr, f"pto.{expr.name} dist") + return ( + SemanticPredicateStoreStmt( + op_name=canonical_name, + value=value, + destination=destination, + indices=indices, + dist=dist, + ), + dict(env), + ) + + if expr.name in {"vsta", "vstas", "vstar"}: + offset: SemanticExpr | None = None + op_name = "vstas" if expr.name == "vsta" else expr.name + if expr.name == "vsta": + if len(expr.args) == 2: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsta destination", + ) + offset = SemanticLiteralExpr(value=0, type=SemanticScalarType(dtype=i32)) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 3: + raise TypeError("pto.vsta expects 2 or 3 positional arguments in TileLang DSL v1") + value, destination, offset = args + indices = () + elif expr.name == "vstas": + if len(expr.args) == 3 and isinstance(expr.args[1], FrontendSubscriptExpr): + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstas destination", + ) + offset = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 3: + raise TypeError("pto.vstas expects exactly 3 positional arguments in TileLang DSL v1") + value, destination, offset = args + indices = () + else: + if len(expr.args) == 2 and isinstance(expr.args[1], FrontendSubscriptExpr): + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstar destination", + ) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 2: + raise TypeError("pto.vstar expects exactly 2 positional arguments in TileLang DSL v1") + value, destination = args + indices = () + self._require_align_expr(value, f"pto.{expr.name} value") + self._require_vector_pointer_expr(destination, f"pto.{expr.name} destination") + indices = tuple(self._require_index_typed_expr(index) for index in indices) + if offset is not None: + self._require_i32_like_expr(offset, f"pto.{expr.name} offset") + return ( + SemanticAlignStoreStmt( + op_name=op_name, + value=value, + destination=destination, + indices=indices, + offset=offset, + ), + dict(env), + ) + + if expr.name == "vscatter": + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 4: + raise TypeError("pto.vscatter expects exactly 4 positional arguments in TileLang DSL v1") + value, destination, offsets, mask = args + value_type = self._require_vreg_expr(value, "pto.vscatter value") + self._require_vector_pointer_expr(destination, "pto.vscatter destination") + offsets_type = self._require_vreg_expr(offsets, "pto.vscatter offsets") + if not is_integer_dtype(offsets_type.element_dtype): + raise TypeError("pto.vscatter offsets must use an integer vector type in TileLang DSL v1") + if integer_bitwidth(offsets_type.element_dtype) != 32: + raise TypeError("pto.vscatter currently requires i32 offset vectors in TileLang DSL v1") + if value_type.lanes != offsets_type.lanes: + raise TypeError("pto.vscatter value and offsets must use the same lane count in TileLang DSL v1") + self._require_matching_vector_pointer(value_type, destination.type, "pto.vscatter") + self._require_mask_for_vreg(mask, value_type, "pto.vscatter") + return ( + SemanticVScatterStmt( + value=value, + destination=destination, + offsets=offsets, + mask=mask, + ), + dict(env), + ) + + if expr.name == "vsst": + if len(expr.args) == 3: + scalar = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsst destination", + ) + mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 4: + raise TypeError("pto.vsst expects 3 or 4 positional arguments in TileLang DSL v1") + scalar, destination, offset, mask = args + indices = (offset,) + scalar_type = self._require_scalar_expr(scalar, "pto.vsst scalar") + self._require_vector_pointer_expr(destination, "pto.vsst destination") + indices = tuple(self._require_index_typed_expr(index) for index in indices) + destination_dtype = destination.type.element_dtype + if scalar_type.dtype != destination_dtype: + raise TypeError("pto.vsst scalar dtype must match destination element dtype in TileLang DSL v1") + value = SemanticCallExpr( + namespace="pto", + name="vbr", + args=(scalar,), + type=self._vreg_type_for_dtype(destination_dtype), + ) + self._require_mask_for_vreg(mask, value.type, "pto.vsst") + self._require_matching_vector_pointer(value.type, destination.type, "pto.vsst") + return ( + SemanticVectorStoreStmt( + value=value, + destination=destination, + indices=indices, + dist=None, + mask=mask, + ), + dict(env), + ) + + if expr.name == "vsts": + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + unexpected_keywords = sorted(set(analyzed_keywords) - {"dist"}) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vsts only accepts keyword attr `dist`; " + f"got unsupported keyword(s): {keyword_text}" + ) + dist = self._normalize_vsts_dist(analyzed_keywords.get("dist"), "pto.vsts dist") + if len(expr.args) == 3: + value = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[1], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vsts destination", + ) + mask = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 4: + raise TypeError("pto.vsts expects 3 or 4 positional arguments in TileLang DSL v1") + value, destination, offset, mask = args + indices = (offset,) + self._require_vreg_expr(value, "pto.vsts value") + self._require_vector_pointer_expr(destination, "pto.vsts destination") + indices = tuple(self._require_index_typed_expr(index) for index in indices) + self._require_mask_for_vsts(mask, value.type, dist, "pto.vsts") + self._require_matching_vector_pointer(value.type, destination.type, "pto.vsts") + return ( + SemanticVectorStoreStmt( + value=value, + destination=destination, + indices=indices, + dist=dist, + mask=mask, + ), + dict(env), + ) + + if len(expr.args) == 5: + low = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + high = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + destination, indices = self._analyze_tile_vector_access( + expr.args[2], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vstsx2 destination", + ) + dist = self._analyze_expr(expr.args[3], env, allow_outer_lookup=allow_outer_lookup) + mask = self._analyze_expr(expr.args[4], env, allow_outer_lookup=allow_outer_lookup) + else: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 6: + raise TypeError("pto.vstsx2 expects 5 or 6 positional arguments in TileLang DSL v1") + low, high, destination, offset, dist, mask = args + indices = (offset,) + low_type = self._require_vreg_expr(low, "pto.vstsx2 low") + high_type = self._require_vreg_expr(high, "pto.vstsx2 high") + if low_type != high_type: + raise TypeError("pto.vstsx2 requires low/high vectors to use the same vector type") + self._require_vector_pointer_expr(destination, "pto.vstsx2 destination") + indices = tuple(self._require_index_typed_expr(index) for index in indices) + dist = self._normalize_vstsx2_dist(dist) + self._require_mask_for_vreg(mask, low_type, "pto.vstsx2") + self._require_matching_vector_pointer(low_type, destination.type, "pto.vstsx2") + return ( + SemanticVectorPairStoreStmt( + low=low, + high=high, + destination=destination, + indices=indices, + dist=dist, + mask=mask, + ), + dict(env), + ) + + def _analyze_scalar_store_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if len(args) != 3: + raise TypeError("pto.store_scalar expects exactly 3 positional arguments in TileLang DSL v1") + if isinstance(args[0].type, SemanticPtrType): + destination = self._require_pointer_expr(args[0], "pto.store_scalar destination") + offset = args[1] + value = args[2] + else: + value = args[0] + destination = self._require_pointer_expr(args[1], "pto.store_scalar destination") + offset = args[2] + offset = self._require_index_typed_expr(offset) + value_type = self._require_scalar_expr(value, "pto.store_scalar value") + if value_type.dtype != destination.type.element_dtype: + raise TypeError("pto.store_scalar value dtype must match destination pointer element dtype") + return ( + SemanticScalarStoreStmt(value=value, destination=destination, offset=offset), + dict(env), + ) + + def _analyze_sync_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + if expr.name in {"set_flag", "wait_flag"}: + if len(args) != 3: + raise TypeError(f"pto.{expr.name} expects exactly 3 positional arguments in TileLang DSL v1") + src_pipe = self._require_sync_pipe(args[0], f"pto.{expr.name} source pipe") + dst_pipe = self._require_sync_pipe(args[1], f"pto.{expr.name} destination pipe") + event = self._require_sync_event(args[2], f"pto.{expr.name} event") + if expr.name == "set_flag": + return SemanticSetFlagStmt(src_pipe=src_pipe, dst_pipe=dst_pipe, event=event), dict(env) + return SemanticWaitFlagStmt(src_pipe=src_pipe, dst_pipe=dst_pipe, event=event), dict(env) + if expr.name in {"get_buf", "rls_buf"}: + if len(args) not in {2, 3}: + raise TypeError(f"pto.{expr.name} expects 2 or 3 positional arguments in TileLang DSL v1") + pipe = self._require_sync_pipe(args[0], f"pto.{expr.name} pipe") + self._require_i64_like_expr(args[1], f"pto.{expr.name} buf_id") + mode = args[2] if len(args) == 3 else SemanticLiteralExpr(value=0, type=SemanticScalarType(dtype=i64)) + self._require_i64_like_expr(mode, f"pto.{expr.name} mode") + if expr.name == "get_buf": + return SemanticGetBufStmt(pipe=pipe, buf_id=args[1], mode=mode), dict(env) + return SemanticRlsBufStmt(pipe=pipe, buf_id=args[1], mode=mode), dict(env) + if expr.name == "mem_bar": + if len(args) != 1: + raise TypeError("pto.mem_bar expects exactly 1 positional argument in TileLang DSL v1") + barrier_type = self._require_barrier_type(args[0], "pto.mem_bar barrier_type") + return SemanticMemBarStmt(barrier_type=barrier_type), dict(env) + if expr.name in {"set_cross_core", "set_intra_block", "wait_flag_dev", "wait_intra_core"}: + if len(args) != 2: + raise TypeError(f"pto.{expr.name} expects exactly 2 positional arguments in TileLang DSL v1") + identifier = self._require_scalar_or_index_expr(args[0], f"pto.{expr.name} first operand") + self._require_i64_like_expr(identifier, f"pto.{expr.name} first operand") + event_id = self._normalize_event_id_expr(args[1], f"pto.{expr.name} event_id") + if expr.name == "set_cross_core": + return SemanticSetCrossCoreStmt(core_id=identifier, event_id=event_id), dict(env) + if expr.name == "set_intra_block": + return SemanticSetIntraBlockStmt(block_id=identifier, event_id=event_id), dict(env) + if expr.name == "wait_flag_dev": + return SemanticWaitFlagDevStmt(core_id=identifier, event_id=event_id), dict(env) + return SemanticWaitIntraCoreStmt(block_id=identifier, event_id=event_id), dict(env) + if expr.name == "set_intra_core": + if len(args) != 1: + raise TypeError("pto.set_intra_core expects exactly 1 positional argument in TileLang DSL v1") + config = self._require_scalar_or_index_expr(args[0], "pto.set_intra_core config") + self._require_i32_like_expr(config, "pto.set_intra_core config") + return SemanticSetIntraCoreStmt(config=config), dict(env) + if expr.name in {"pipe_barrier", "barrier"}: + if len(args) != 1: + raise TypeError(f"pto.{expr.name} expects exactly 1 positional argument in TileLang DSL v1") + pipe = self._require_sync_pipe(args[0], f"pto.{expr.name} pipe") + return SemanticPipeBarrierStmt(pipe=pipe), dict(env) + raise ValueError(f"unsupported sync stmt pto.{expr.name}") + + def _analyze_low_level_dma_stmt( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + args = self._analyze_low_level_dma_operands( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + if expr.name in _LOW_LEVEL_DMA_UNARY_CONFIG_OPS: + if len(args) != 1: + raise TypeError(f"pto.{expr.name} expects exactly 1 positional argument in TileLang DSL") + scalar = self._require_scalar_expr(args[0], f"pto.{expr.name} pad_value") + if not _is_supported_mov_pad_scalar_dtype(scalar.dtype): + raise TypeError( + "pto.set_mov_pad_val pad_value must be an 8/16/32-bit integer or f16/bf16/f32 in TileLang DSL v1" + ) + return ( + SemanticDmaUnaryConfigStmt( + name=expr.name, + value=args[0], + ), + dict(env), + ) + if expr.name in _LOW_LEVEL_DMA_CONFIG_OPS: + if len(args) != 2: + raise TypeError(f"pto.{expr.name} expects exactly 2 positional arguments in TileLang DSL") + self._require_i64_like_expr(args[0], f"pto.{expr.name} first operand") + self._require_i64_like_expr(args[1], f"pto.{expr.name} second operand") + return ( + SemanticDmaConfigStmt( + name=expr.name, + first=args[0], + second=args[1], + ), + dict(env), + ) + if expr.name == "copy_gm_to_ubuf": + if len(args) != 11: + raise TypeError("pto.copy_gm_to_ubuf expects exactly 11 positional arguments in TileLang DSL") + source = self._require_pointer_expr(args[0], "pto.copy_gm_to_ubuf source", memory_space="gm") + destination = self._require_pointer_expr(args[1], "pto.copy_gm_to_ubuf destination", memory_space="ub") + for operand, label in zip( + args[2:7] + args[8:], + ( + "sid", + "n_burst", + "len_burst", + "left_padding_count", + "right_padding_count", + "l2_cache_ctl", + "gm_stride", + "ub_stride", + ), + ): + self._require_i64_like_expr(operand, f"pto.copy_gm_to_ubuf {label}") + self._require_i1_expr(args[7], "pto.copy_gm_to_ubuf data_select_bit") + return ( + SemanticLowLevelCopyStmt( + name=expr.name, + source=source, + destination=destination, + operands=args[2:], + ), + dict(env), + ) + if expr.name == "copy_ubuf_to_gm": + if len(args) != 8: + raise TypeError("pto.copy_ubuf_to_gm expects exactly 8 positional arguments in TileLang DSL") + source = self._require_pointer_expr(args[0], "pto.copy_ubuf_to_gm source", memory_space="ub") + destination = self._require_pointer_expr(args[1], "pto.copy_ubuf_to_gm destination", memory_space="gm") + for operand, label in zip( + args[2:], + ( + "sid", + "n_burst", + "len_burst", + "reserved", + "burst_dst_stride", + "burst_src_stride", + ), + ): + self._require_i64_like_expr(operand, f"pto.copy_ubuf_to_gm {label}") + return ( + SemanticLowLevelCopyStmt( + name=expr.name, + source=source, + destination=destination, + operands=args[2:], + ), + dict(env), + ) + if expr.name == "copy_ubuf_to_ubuf": + if len(args) != 7: + raise TypeError("pto.copy_ubuf_to_ubuf expects exactly 7 positional arguments in TileLang DSL") + source = self._require_pointer_expr(args[0], "pto.copy_ubuf_to_ubuf source", memory_space="ub") + destination = self._require_pointer_expr(args[1], "pto.copy_ubuf_to_ubuf destination", memory_space="ub") + for operand, label in zip( + args[2:], + ("sid", "n_burst", "len_burst", "src_stride", "dst_stride"), + ): + self._require_i64_like_expr(operand, f"pto.copy_ubuf_to_ubuf {label}") + return ( + SemanticLowLevelCopyStmt( + name=expr.name, + source=source, + destination=destination, + operands=args[2:], + ), + dict(env), + ) + raise ValueError(f"unsupported low-level DMA stmt pto.{expr.name}") + + def _analyze_low_level_dma_operands( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticExpr, ...]: + if expr.args and expr.keywords: + raise TypeError( + f"pto.{expr.name} does not support mixing positional and keyword operands in TileLang DSL v1" + ) + if not expr.keywords: + return tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + + analyzed_keywords: dict[str, SemanticExpr] = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + + def index_literal(value: int) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=value, type=SemanticIndexType()) + + def bool_literal(value: bool) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=i1)) + + if expr.name == "set_mov_pad_val": + return (analyzed_keywords["pad_value"],) + if expr.name in { + "set_loop2_stride_outtoub", + "set_loop1_stride_outtoub", + "set_loop2_stride_ubtoout", + "set_loop1_stride_ubtoout", + }: + return ( + analyzed_keywords["src_stride"], + analyzed_keywords["dst_stride"], + ) + if expr.name in {"set_loop_size_outtoub", "set_loop_size_ubtoout"}: + return ( + analyzed_keywords["loop1"], + analyzed_keywords["loop2"], + ) + if expr.name == "copy_gm_to_ubuf": + if "data_select_bit" in analyzed_keywords and "enable_ub_pad" in analyzed_keywords: + raise TypeError( + "pto.copy_gm_to_ubuf keyword form accepts either `data_select_bit` or `enable_ub_pad`, not both" + ) + return ( + analyzed_keywords["src"], + analyzed_keywords["dst"], + analyzed_keywords.get("sid", index_literal(0)), + analyzed_keywords["n_burst"], + analyzed_keywords["len_burst"], + analyzed_keywords.get("left_padding_count", index_literal(0)), + analyzed_keywords.get("right_padding_count", index_literal(0)), + analyzed_keywords.get( + "data_select_bit", + analyzed_keywords.get("enable_ub_pad", bool_literal(False)), + ), + analyzed_keywords.get("l2_cache_ctl", index_literal(0)), + analyzed_keywords["gm_stride"], + analyzed_keywords["ub_stride"], + ) + if expr.name == "copy_ubuf_to_gm": + if "burst_dst_stride" in analyzed_keywords and "gm_stride" in analyzed_keywords: + raise TypeError( + "pto.copy_ubuf_to_gm keyword form accepts either `burst_dst_stride` or `gm_stride`, not both" + ) + if "burst_src_stride" in analyzed_keywords and "ub_stride" in analyzed_keywords: + raise TypeError( + "pto.copy_ubuf_to_gm keyword form accepts either `burst_src_stride` or `ub_stride`, not both" + ) + return ( + analyzed_keywords["src"], + analyzed_keywords["dst"], + analyzed_keywords.get("sid", index_literal(0)), + analyzed_keywords["n_burst"], + analyzed_keywords["len_burst"], + analyzed_keywords.get("reserved", index_literal(0)), + analyzed_keywords.get( + "burst_dst_stride", + analyzed_keywords["gm_stride"], + ), + analyzed_keywords.get( + "burst_src_stride", + analyzed_keywords["ub_stride"], + ), + ) + raise TypeError( + f"pto.{expr.name} keyword form is not implemented in TileLang DSL v1" + ) + + def _require_tensor_slice( + self, + expr: SemanticExpr, + context: str, + ) -> SemanticTensorSliceExpr: + if not isinstance(expr, SemanticTensorSliceExpr): + raise TypeError(f"{context} must be a TensorView slice in TileLang DSL v1") + return expr + + def _require_tile_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if not isinstance(expr.type, SemanticTileType): + raise TypeError(f"{context} must be a Tile value in TileLang DSL v1") + if expr.type.rank != 2: + raise TypeError(f"{context} currently only supports rank-2 Tile values in TileLang DSL v1") + if expr.type.shape is None: + raise TypeError(f"{context} requires a statically specialized Tile shape in TileLang DSL v1") + if expr.type.memory_space != "ub": + raise TypeError(f"{context} currently only supports MemorySpace.UB Tile values in TileLang DSL v1") + return expr + + def _require_pointer_expr( + self, + expr: SemanticExpr, + context: str, + *, + memory_space: str | None = None, + ) -> SemanticExpr: + if not isinstance(expr.type, SemanticPtrType): + raise TypeError(f"{context} must be a pointer value in TileLang DSL") + if memory_space is not None and expr.type.memory_space != memory_space: + raise TypeError(f"{context} requires MemorySpace.{memory_space.upper()} pointers in TileLang DSL") + return expr + + def _require_vector_pointer_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if isinstance(expr.type, SemanticTileType): + return self._require_tile_expr(expr, context) + return self._require_pointer_expr(expr, context, memory_space="ub") + + def _validate_dma_common_types( + self, + tensor_slice_type: SemanticTensorSliceType, + tile_type: SemanticTileType, + op_name: str, + ) -> None: + if tensor_slice_type.rank != 2: + raise TypeError(f"{op_name} currently only supports rank-2 TensorView slices in TileLang DSL v1") + if tile_type.rank != 2 or tile_type.shape is None: + raise TypeError(f"{op_name} requires a statically specialized rank-2 Tile in TileLang DSL v1") + if tensor_slice_type.element_dtype != tile_type.element_dtype: + raise TypeError(f"{op_name} requires matching TensorView/Tile element dtypes in TileLang DSL v1") + + def _validate_dma_load_profile( + self, + src: SemanticTensorSliceExpr, + dst: SemanticExpr, + options: SemanticDmaOptions, + ) -> None: + assert isinstance(dst.type, SemanticTileType) + self._validate_dma_common_types(src.type, dst.type, "pto.dma_load") + self._validate_dma_slice_profile(src, "pto.dma_load") + + pad_mode = self._pad_mode_value(options.pad_mode, default=PadMode.PadNull) + left_padding = self._require_static_non_negative_index_value( + options.left_padding, + context="pto.dma_load left_padding", + default=0, + ) + right_padding = self._require_static_non_negative_index_value( + options.right_padding, + context="pto.dma_load right_padding", + default=0, + ) + self._require_static_bool_value( + options.init_out_buffer, + context="pto.dma_load init_out_buffer", + default=False, + ) + self._validate_dma_load_option_profile(options, pad_mode) + + valid_shape = self._resolved_tile_valid_shape(dst.type) + expected_extents = ( + valid_shape[0], + self._trimmed_tile_axis_extent( + valid_shape[1], + left_padding, + right_padding, + op_name="pto.dma_load", + axis=1, + window_label="destination Tile valid window", + ), + ) + self._validate_dma_extent_match( + actual_extents=src.type.extents, + expected_extents=expected_extents, + op_name="pto.dma_load", + actual_label="source slice", + expected_label="destination Tile valid window", + left_padding=left_padding, + right_padding=right_padding, + ) + + def _validate_dma_store_profile( + self, + src: SemanticExpr, + dst: SemanticTensorSliceExpr, + options: SemanticDmaOptions, + ) -> None: + assert isinstance(src.type, SemanticTileType) + self._validate_dma_common_types(dst.type, src.type, "pto.dma_store") + self._validate_dma_slice_profile(dst, "pto.dma_store") + + pad_mode = self._pad_mode_value(options.pad_mode, default=PadMode.PadNull) + left_padding = self._require_static_non_negative_index_value( + options.left_padding, + context="pto.dma_store left_padding", + default=0, + ) + right_padding = self._require_static_non_negative_index_value( + options.right_padding, + context="pto.dma_store right_padding", + default=0, + ) + self._validate_dma_store_option_profile(options, pad_mode) + + valid_shape = self._resolved_tile_valid_shape(src.type) + expected_extents = ( + valid_shape[0], + self._trimmed_tile_axis_extent( + valid_shape[1], + left_padding, + right_padding, + op_name="pto.dma_store", + axis=1, + window_label="source Tile interior window", + ), + ) + self._validate_dma_extent_match( + actual_extents=dst.type.extents, + expected_extents=expected_extents, + op_name="pto.dma_store", + actual_label="destination slice", + expected_label="source Tile interior window", + left_padding=left_padding, + right_padding=right_padding, + ) + + def _validate_dma_slice_profile( + self, + tensor_slice: SemanticTensorSliceExpr, + op_name: str, + ) -> None: + for axis, slice_axis in enumerate(tensor_slice.slices): + step = self._static_index_value(slice_axis.step, default=1) + if step is None: + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires a static positive " + f"slice step on axis {axis}" + ) + if step <= 0: + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires a positive " + f"slice step on axis {axis}, got {step!r}" + ) + if axis == 1 and step != 1: + raise TypeError( + f"{op_name} stable frontend-only DMA profile only supports step == 1 " + "on TensorView slice axis 1" + ) + + def _validate_dma_load_option_profile( + self, + options: SemanticDmaOptions, + pad_mode: PadMode, + ) -> None: + if pad_mode == PadMode.PadValue and options.pad_value is None: + raise TypeError( + "pto.dma_load stable frontend-only DMA profile requires `pad_value` when " + "`pad_mode=PadMode.PadValue`" + ) + if pad_mode != PadMode.PadValue and options.pad_value is not None: + raise TypeError( + "pto.dma_load stable frontend-only DMA profile only accepts `pad_value` " + "when `pad_mode=PadMode.PadValue`" + ) + + def _validate_dma_store_option_profile( + self, + options: SemanticDmaOptions, + pad_mode: PadMode, + ) -> None: + if options.pad_value is not None: + raise TypeError( + "pto.dma_store stable frontend-only DMA profile does not support `pad_value`; " + "GM-side fill is unsupported" + ) + if pad_mode != PadMode.PadNull: + raise TypeError( + "pto.dma_store stable frontend-only DMA profile only supports " + "`pad_mode=PadMode.PadNull`; non-PadNull store padding would require GM-side fill" + ) + + def _resolved_tile_valid_shape( + self, + tile_type: SemanticTileType, + ) -> tuple[int | None, ...]: + assert tile_type.shape is not None + return tile_type.shape if tile_type.valid_shape is None else tile_type.valid_shape + + def _trimmed_tile_axis_extent( + self, + base_extent: int | None, + left_padding: int, + right_padding: int, + *, + op_name: str, + axis: int, + window_label: str, + ) -> int | None: + if base_extent is None: + return None + trimmed_extent = base_extent - left_padding - right_padding + if trimmed_extent <= 0: + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires {window_label} axis {axis}=" + f"{base_extent!r} to remain positive after left_padding={left_padding} " + f"and right_padding={right_padding}" + ) + return trimmed_extent + + def _validate_dma_extent_match( + self, + *, + actual_extents: tuple[int | None, ...], + expected_extents: tuple[int | None, ...], + op_name: str, + actual_label: str, + expected_label: str, + left_padding: int, + right_padding: int, + ) -> None: + for axis, (actual_extent, expected_extent) in enumerate(zip(actual_extents, expected_extents)): + if actual_extent is None or expected_extent is None: + continue + if actual_extent != expected_extent: + padding_suffix = "" + if axis == 1 and (left_padding != 0 or right_padding != 0): + padding_suffix = ( + f" after left_padding={left_padding} and right_padding={right_padding}" + ) + raise TypeError( + f"{op_name} stable frontend-only DMA profile requires {actual_label} extent " + f"axis {axis}={actual_extent!r} to match {expected_label} axis {axis}=" + f"{expected_extent!r}{padding_suffix}" + ) + + def _bind_assignment_target( + self, + target: FrontendTargetNode, + value: SemanticExpr, + env: dict[str, SemanticBinding], + annotation: Any | None, + ) -> tuple[SemanticBinding, ...]: + if isinstance(target, FrontendNameTarget): + if isinstance(value.type, SemanticTupleType): + raise ValueError("multi-result call assignment requires tuple binding in TileLang DSL v1") + inferred_type: SemanticType = value.type + if isinstance(value.type, SemanticTensorSliceType): + # Tensor slicing materializes a logical partition descriptor value in IR. + inferred_type = SemanticPartitionTensorViewType( + element_dtype=value.type.element_dtype, + rank=value.type.rank, + ) + annotated_type = self._annotation_type(annotation, inferred_type, env) + binding = self._make_binding( + target.name, + annotated_type if annotated_type is not None else inferred_type, + "ssa", + value=self._binding_value_for_expr(value), + ) + env[target.name] = binding + return (binding,) + if isinstance(target, FrontendTupleTarget): + if isinstance(value.type, SemanticTupleType): + element_types = value.type.elements + elif isinstance(value.type, SemanticShapeType): + element_types = tuple(SemanticIndexType() for _ in range(value.type.rank)) + else: + raise ValueError("tuple assignment expects a tuple-typed value") + if annotation is not None: + raise TypeError("annotated tuple assignment is not supported in TileLang DSL v1") + if len(target.elements) != len(element_types): + raise ValueError("tuple assignment arity must match the tuple value") + tuple_values: tuple[SemanticExpr, ...] + if isinstance(value, SemanticTupleExpr): + tuple_values = value.elements + elif isinstance(value, SemanticAttributeAccess) and isinstance(value.type, SemanticShapeType): + if isinstance(value.base, SemanticBindingRef): + if isinstance(value.base.type, SemanticTileType) and value.attr == "valid_shape": + valid_shape = value.base.type.valid_shape + if valid_shape is not None: + for axis, dim in enumerate(valid_shape): + if dim is None: + self._ensure_tile_valid_shape_parameter(value.base.binding, axis) + tuple_values = tuple( + SemanticSubscriptAccess( + base=value, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + for axis in range(value.type.rank) + ) + elif isinstance(value, SemanticCallExpr): + if len(value.args) == len(element_types): + tuple_values = value.args + else: + tuple_values = tuple( + SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types + ) + else: + tuple_values = tuple( + SemanticLiteralExpr(value=None, type=element_type) for element_type in element_types + ) + bindings = [] + for element, element_type, element_value in zip(target.elements, element_types, tuple_values): + binding = self._make_binding( + element.name, + element_type, + "ssa", + value=self._binding_value_for_expr(element_value), + ) + env[element.name] = binding + bindings.append(binding) + return tuple(bindings) + raise ValueError(f"unsupported frontend assignment target {type(target).__name__}") + + def _binding_value_for_expr(self, expr: SemanticExpr) -> Any | None: + return self._try_static_value(expr) + + def _annotation_type( + self, + annotation: Any | None, + inferred_type: SemanticType | None, + env: dict[str, SemanticBinding], + ) -> SemanticType | None: + if annotation is None: + return inferred_type + annotation_expr = self._analyze_annotation_expr(annotation, env) + if isinstance(annotation_expr.type, SemanticMetaType): + if annotation_expr.type.kind == "dtype" and isinstance(inferred_type, SemanticScalarType): + dtype = self._require_dtype_symbol(annotation_expr, "annotated scalar type") + if inferred_type.dtype != dtype: + raise TypeError( + f"annotated scalar type `{dtype!r}` does not match inferred {inferred_type.dtype!r}" + ) + return inferred_type + if annotation_expr.type.kind == "ptr_type" and isinstance(inferred_type, SemanticPtrType): + ptr_type = self._require_ptr_type_expr(annotation_expr, "annotated pointer type") + if inferred_type.element_dtype != ptr_type.element_dtype: + raise TypeError( + f"annotated pointer type `{ptr_type!r}` does not match inferred pointer element type {inferred_type.element_dtype!r}" + ) + if inferred_type.memory_space != ptr_type.memory_space.value: + raise TypeError( + f"annotated pointer type `{ptr_type!r}` does not match inferred pointer memory space `{inferred_type.memory_space}`" + ) + return inferred_type + if annotation_expr.type.kind == "vreg_type" and isinstance(inferred_type, SemanticVRegType): + vreg_type = self._require_vreg_type_expr(annotation_expr, "annotated vector type") + if inferred_type.element_dtype != vreg_type.element_dtype or inferred_type.lanes != vreg_type.lanes: + raise TypeError( + f"annotated vector type `{vreg_type!r}` does not match inferred !pto.vreg<{inferred_type.lanes}x{inferred_type.element_dtype.name}>" + ) + return inferred_type + if annotation_expr.type.kind == "vector_type" and isinstance(inferred_type, SemanticVectorType): + vector_type = self._require_vector_type_expr(annotation_expr, "annotated builtin vector type") + if ( + inferred_type.element_dtype != vector_type.element_dtype + or inferred_type.shape != vector_type.shape + ): + shape_text = "x".join(str(dim) for dim in inferred_type.shape) + raise TypeError( + f"annotated builtin vector type `{vector_type!r}` does not match inferred !pto.vector<{shape_text}x{inferred_type.element_dtype.name}>" + ) + return inferred_type + if annotation_expr.type.kind == "mask_type" and isinstance(inferred_type, SemanticMaskType): + mask_type = self._require_mask_type_expr(annotation_expr, "annotated mask type") + if inferred_type.granularity != mask_type.granularity: + raise TypeError( + f"annotated mask type `{mask_type!r}` does not match inferred !pto.mask<{inferred_type.granularity}>" + ) + return inferred_type + if annotation_expr.type.kind == "align_type" and isinstance(inferred_type, SemanticAlignType): + return inferred_type + if ( + annotation_expr.type.kind == "partition_tensor_view_type" + and isinstance(inferred_type, SemanticPartitionTensorViewType) + ): + return inferred_type + raise TypeError("unsupported annotated assignment type in TileLang DSL v1") + + def _analyze_annotation_expr( + self, + annotation: ast.AST, + env: dict[str, SemanticBinding], + ) -> SemanticExpr: + frontend_expr = self._build_frontend_annotation_expr(annotation) + return self._analyze_expr(frontend_expr, env, allow_outer_lookup=True) + + def _build_frontend_annotation_expr(self, node: ast.AST) -> FrontendExprNode: + if isinstance(node, ast.Name): + return FrontendNameExpr(name=node.id) + if isinstance(node, ast.Constant): + return FrontendConstantExpr(value=node.value) + if isinstance(node, ast.Attribute): + path = self._annotation_attribute_path(node) + if path is not None and path[0] in {"pto", "PAT", "PIPE", "EVENT"} and len(path) >= 2: + return FrontendSymbolExpr(namespace=".".join(path[:-1]), name=path[-1]) + return FrontendAttributeExpr( + base=self._build_frontend_annotation_expr(node.value), + attr=node.attr, + ) + if isinstance(node, ast.Call): + if any(keyword.arg is None for keyword in node.keywords): + raise TypeError("annotated assignment type does not support keyword unpacking in TileLang DSL v1") + if node.keywords: + raise TypeError("annotated assignment type does not support keyword arguments in TileLang DSL v1") + if isinstance(node.func, ast.Name): + return FrontendCallExpr( + namespace=None, + name=node.func.id, + args=tuple(self._build_frontend_annotation_expr(arg) for arg in node.args), + keywords=(), + ) + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + return FrontendCallExpr( + namespace=node.func.value.id, + name=node.func.attr, + args=tuple(self._build_frontend_annotation_expr(arg) for arg in node.args), + keywords=(), + ) + raise TypeError("unsupported annotated assignment type in TileLang DSL v1") + + def _annotation_attribute_path(self, node: ast.AST) -> tuple[str, ...] | None: + if isinstance(node, ast.Name): + return (node.id,) + if isinstance(node, ast.Attribute): + base_path = self._annotation_attribute_path(node.value) + if base_path is None: + return None + return base_path + (node.attr,) + return None + + def _analyze_for( + self, + stmt: FrontendForStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + lower_bound = self._analyze_expr(stmt.lower_bound, env, allow_outer_lookup=allow_outer_lookup) + upper_bound = self._analyze_expr(stmt.upper_bound, env, allow_outer_lookup=allow_outer_lookup) + step = self._analyze_expr(stmt.step, env, allow_outer_lookup=allow_outer_lookup) + for expr in (lower_bound, upper_bound, step): + self._require_loop_bound_type(expr.type) + + body_env = dict(env) + induction_variable = self._make_binding(stmt.target, SemanticIndexType(), "loop_iv") + body_env[stmt.target] = induction_variable + body, final_body_env = self._analyze_block( + stmt.body, + body_env, + allow_outer_lookup=allow_outer_lookup, + ) + + updated_env = dict(env) + loop_carried = [] + for name, outer_binding in env.items(): + final_binding = final_body_env.get(name) + if final_binding is None or final_binding is outer_binding: + continue + merged_type = self._merge_loop_carried_types(outer_binding.type, final_binding.type) + if merged_type is None: + raise TypeError( + f"loop-carried binding '{name}' changes type from {outer_binding.type!r} to {final_binding.type!r}" + ) + merged = self._make_binding(name, merged_type, "loop_result") + updated_env[name] = merged + loop_carried.append(merged) + + return ( + SemanticForStmt( + induction_variable=induction_variable, + lower_bound=lower_bound, + upper_bound=upper_bound, + step=step, + body=body, + loop_carried=tuple(loop_carried), + ), + updated_env, + ) + + def _analyze_if( + self, + stmt: FrontendIfStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + condition = self._analyze_expr(stmt.condition, env, allow_outer_lookup=allow_outer_lookup) + self._require_condition_type(condition.type) + if self._contains_meta_condition_operand(condition): + raise TypeError( + "if condition comparing meta values requires wrapping the condition with pto.constexpr(...) " + "in TileLang DSL v1" + ) + + then_body, then_env = self._analyze_block( + stmt.then_body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + else_body, else_env = self._analyze_block( + stmt.else_body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + + updated_env = dict(env) + merged_results: list[SemanticIfResult] = [] + for name, outer_binding in env.items(): + then_binding = then_env.get(name, outer_binding) + else_binding = else_env.get(name, outer_binding) + if then_binding is outer_binding and else_binding is outer_binding: + continue + if then_binding.type != else_binding.type: + raise TypeError( + f"if/else merge for '{name}' changes type between branches: " + f"{then_binding.type!r} vs {else_binding.type!r}" + ) + merged_binding = self._make_binding(name, then_binding.type, "if_result") + updated_env[name] = merged_binding + merged_results.append( + SemanticIfResult( + result_binding=merged_binding, + then_binding=then_binding, + else_binding=else_binding, + ) + ) + + return ( + SemanticIfStmt( + condition=condition, + then_body=then_body, + else_body=else_body, + results=tuple(merged_results), + ), + updated_env, + ) + + def _contains_meta_condition_operand(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticBinaryExpr): + if expr.op in {"eq", "ne"} and ( + isinstance(expr.lhs.type, SemanticMetaType) or isinstance(expr.rhs.type, SemanticMetaType) + ): + return True + if expr.op in {"and", "or"}: + return self._contains_meta_condition_operand(expr.lhs) or self._contains_meta_condition_operand(expr.rhs) + return False + + def _analyze_constexpr_if( + self, + stmt: FrontendIfStmt, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> tuple[tuple[SemanticStmt, ...], dict[str, SemanticBinding]]: + condition = self._analyze_expr(stmt.condition, env, allow_outer_lookup=allow_outer_lookup) + self._require_condition_type(condition.type) + static_value = self._require_constexpr_condition_bool( + condition, + context="if pto.constexpr(...) condition", + ) + selected_body = stmt.then_body if static_value else stmt.else_body + return self._analyze_block( + selected_body, + dict(env), + allow_outer_lookup=allow_outer_lookup, + ) + + def _analyze_strict_vecscope( + self, + stmt: FrontendStrictVecscopeStmt, + env: dict[str, SemanticBinding], + ) -> tuple[SemanticStmt, dict[str, SemanticBinding]]: + if not self.node.advanced_enabled: + raise TypeError(advanced_mode_message("strict_vecscope")) + if len(stmt.captures) != len(stmt.block_arguments): + raise ValueError("strict_vecscope capture arity must match block arguments") + + captures = tuple( + self._analyze_expr(expr, env, allow_outer_lookup=True) + for expr in stmt.captures + ) + scope_env: dict[str, SemanticBinding] = {} + block_arguments = [] + for name, capture in zip(stmt.block_arguments, captures): + if capture.type is None: + raise TypeError( + f"strict_vecscope block argument '{name}' type could not be inferred" + ) + block_binding = self._make_binding(name, capture.type, "strict_vecscope_arg") + scope_env[name] = block_binding + block_arguments.append(block_binding) + body, _ = self._analyze_block( + stmt.body, + scope_env, + allow_outer_lookup=False, + ) + return ( + SemanticStrictVecscopeStmt( + captures=captures, + block_arguments=tuple(block_arguments), + body=body, + ), + dict(env), + ) + + def _analyze_expr( + self, + expr: FrontendExprNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + if isinstance(expr, FrontendNameExpr): + binding = env.get(expr.name) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.name}'") + raise ValueError( + f"implicit capture of '{expr.name}' is not allowed in pto.strict_vecscope" + ) + return self._attach_expr_source_location( + SemanticBindingRef(binding=binding, type=binding.type), + expr, + ) + if isinstance(expr, FrontendConstantExpr): + if isinstance(expr.value, bool): + return self._attach_expr_source_location( + SemanticLiteralExpr(value=expr.value, type=SemanticScalarType(dtype=i1)), + expr, + ) + if isinstance(expr.value, int): + return self._attach_expr_source_location( + SemanticLiteralExpr(value=expr.value, type=SemanticIndexType()), + expr, + ) + if isinstance(expr.value, float): + return self._attach_expr_source_location( + SemanticLiteralExpr( + value=expr.value, + type=SemanticScalarType(dtype=f32), + ), + expr, + ) + if isinstance(expr.value, str): + return self._attach_expr_source_location( + SemanticLiteralExpr( + value=expr.value, + type=SemanticMetaType(kind="string"), + ), + expr, + ) + if expr.value is None: + return self._attach_expr_source_location( + SemanticLiteralExpr(value=None, type=SemanticIndexType()), + expr, + ) + raise TypeError(f"unsupported constant {expr.value!r} in TileLang DSL v1") + if isinstance(expr, FrontendSymbolExpr): + return self._attach_expr_source_location( + self._analyze_symbol_expr(expr), + expr, + ) + if isinstance(expr, FrontendSliceExpr): + start = None if expr.start is None else self._analyze_expr(expr.start, env, allow_outer_lookup=allow_outer_lookup) + stop = None if expr.stop is None else self._analyze_expr(expr.stop, env, allow_outer_lookup=allow_outer_lookup) + step = None if expr.step is None else self._analyze_expr(expr.step, env, allow_outer_lookup=allow_outer_lookup) + if start is not None: + start = self._require_index_typed_expr(start) + if stop is not None: + stop = self._require_index_typed_expr(stop) + if step is not None: + step = self._require_index_typed_expr(step) + return self._attach_expr_source_location( + SemanticSliceExpr( + start=start, + stop=stop, + step=step, + type=SemanticSliceType(), + ), + expr, + ) + if isinstance(expr, FrontendTupleExpr): + elements = tuple( + self._analyze_expr(element, env, allow_outer_lookup=allow_outer_lookup) + for element in expr.elements + ) + return self._attach_expr_source_location( + SemanticTupleExpr( + elements=elements, + type=SemanticTupleType(elements=tuple(element.type for element in elements)), + ), + expr, + ) + if isinstance(expr, FrontendAttributeExpr): + base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + if expr.attr == "element_type": + return self._attach_expr_source_location(self._element_type_expr(base), expr) + if expr.attr == "rank": + return self._attach_expr_source_location(self._rank_expr(base), expr) + if expr.attr == "memory_space": + return self._attach_expr_source_location(self._memory_space_expr(base), expr) + if expr.attr == "pad_value" and isinstance(base.type, SemanticTileType): + return self._attach_expr_source_location(self._tile_pad_value_expr(base), expr) + if expr.attr == "config": + return self._attach_expr_source_location(self._tile_config_expr(base), expr) + if expr.attr == "valid_shape": + return self._attach_expr_source_location(self._valid_shape_expr(base), expr) + if expr.attr == "strides": + return self._attach_expr_source_location(self._strides_expr(base), expr) + if isinstance(base.type, SemanticTileConfigType): + return self._attach_expr_source_location(self._tile_config_attr_expr(base, expr.attr), expr) + attr_type = self._attribute_type(base, expr.attr) + return self._attach_expr_source_location( + SemanticAttributeAccess(base=base, attr=expr.attr, type=attr_type), + expr, + ) + if isinstance(expr, FrontendSubscriptExpr): + base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + index = self._analyze_expr(expr.index, env, allow_outer_lookup=allow_outer_lookup) + if isinstance(base.type, (SemanticShapeType, SemanticTupleType)): + index = self._require_index_typed_expr(index) + result_type = self._subscript_type(base, index) + if isinstance(result_type, SemanticTensorSliceType): + slices = self._normalize_tensor_slice(index, base.type.rank) + return self._attach_expr_source_location( + SemanticTensorSliceExpr(base=base, slices=slices, type=result_type), + expr, + ) + return self._attach_expr_source_location( + SemanticSubscriptAccess(base=base, index=index, type=result_type), + expr, + ) + if isinstance(expr, FrontendBinaryExpr): + lhs = self._analyze_expr(expr.lhs, env, allow_outer_lookup=allow_outer_lookup) + rhs = self._analyze_expr(expr.rhs, env, allow_outer_lookup=allow_outer_lookup) + lhs, rhs = self._retarget_literals_for_binary_op(lhs, rhs, expr.op) + result_type = self._binary_type(lhs, rhs, expr.op) + return self._attach_expr_source_location( + SemanticBinaryExpr(lhs=lhs, op=expr.op, rhs=rhs, type=result_type), + expr, + ) + if isinstance(expr, FrontendCallExpr): + if expr.namespace is None: + binding = env.get(expr.name) + if ( + binding is not None + and isinstance(binding.type, SemanticMetaType) + and binding.type.kind == "dtype" + and isinstance(binding.value, ScalarType) + ): + if expr.keywords: + raise TypeError( + f"`{expr.name}` does not support keyword arguments in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_scalar_constructor_for_dtype( + binding.value, + args, + surface_name=expr.name, + ) + if expr.namespace is None and expr.name in self._inline_proc_nodes: + if expr.keywords: + raise TypeError( + f"inline_proc call `{expr.name}` reached semantic analysis with unresolved keywords in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_inline_proc_call_expr(expr.name, args) + if expr.namespace is None and expr.name == "eval": + if expr.keywords: + raise TypeError("method call `eval` does not support keyword arguments in TileLang DSL v1") + if not expr.args: + raise TypeError("`eval()` expects a receiver in TileLang DSL v1") + base = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args[1:] + ) + return self._analyze_eval_method(base, args) + if expr.namespace is None and expr.name == "astype": + if expr.keywords: + raise TypeError("method call `astype` does not support keyword arguments in TileLang DSL v1") + if not expr.args: + raise TypeError("`astype()` expects a receiver in TileLang DSL v1") + base = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args[1:] + ) + return self._analyze_astype_method(base, args) + if expr.namespace not in {None, "pto"} and expr.name == "eval": + if expr.keywords: + raise TypeError("method call `eval` does not support keyword arguments in TileLang DSL v1") + binding = env.get(expr.namespace) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.namespace}'") + raise ValueError( + f"implicit capture of '{expr.namespace}' is not allowed in pto.strict_vecscope" + ) + base = SemanticBindingRef(binding=binding, type=binding.type) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_eval_method(base, args) + if expr.namespace not in {None, "pto"} and expr.name == "as_ptr": + if expr.keywords: + raise TypeError("method call `as_ptr` does not support keyword arguments in TileLang DSL v1") + binding = env.get(expr.namespace) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.namespace}'") + raise ValueError( + f"implicit capture of '{expr.namespace}' is not allowed in pto.strict_vecscope" + ) + base = SemanticBindingRef(binding=binding, type=binding.type) + return self._analyze_as_ptr_method(base) + if expr.namespace not in {None, "pto"} and expr.name == "astype": + if expr.keywords: + raise TypeError("method call `astype` does not support keyword arguments in TileLang DSL v1") + binding = env.get(expr.namespace) + if binding is None: + if allow_outer_lookup: + raise ValueError(f"unknown name '{expr.namespace}'") + raise ValueError( + f"implicit capture of '{expr.namespace}' is not allowed in pto.strict_vecscope" + ) + base = SemanticBindingRef(binding=binding, type=binding.type) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_astype_method(base, args) + if expr.namespace == "pto" and expr.name == "vlds": + return self._analyze_vlds_frontend_call( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + if ( + expr.namespace == "pto" + and expr.name == "vldas" + and len(expr.args) == 1 + and isinstance(expr.args[0], FrontendSubscriptExpr) + ): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldas source", + ) + return self._analyze_vldas((base, *indices)) + if ( + expr.namespace == "pto" + and expr.name == "vldus" + and len(expr.args) == 2 + and isinstance(expr.args[0], FrontendSubscriptExpr) + ): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldus source", + ) + align_expr = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vldus((base, *indices, align_expr)) + if expr.namespace == "pto" and expr.name == "vldsx2" and len(expr.args) == 2: + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vldsx2 source", + ) + dist = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_vldsx2((base, *indices, dist)) + if expr.namespace == "pto" and expr.name == "vcvt": + return self._analyze_vcvt_frontend_call( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + if expr.namespace == "pto" and expr.name == "vtrc": + return self._analyze_vtrc_frontend_call( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + if expr.namespace == "pto" and expr.name == "Tile": + return self._analyze_tile_frontend_call( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + if expr.namespace == "pto" and expr.name in _CUBE_CALL_OPS: + return self._analyze_cube_frontend_call_expr( + expr, + env, + allow_outer_lookup=allow_outer_lookup, + ) + if expr.keywords: + raise TypeError( + f"call surface `{expr.namespace + '.' if expr.namespace else ''}{expr.name}` " + "carries keyword arguments, but semantic keyword handling is not implemented " + "in TileLang DSL v1 yet" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_call_expr(expr.namespace, expr.name, args) + raise ValueError(f"unsupported frontend expression {type(expr).__name__}") + + def _analyze_symbol_expr(self, expr: FrontendSymbolExpr) -> SemanticExpr: + if expr.namespace == "pto": + dtype = _DTYPE_SYMBOLS.get(expr.name) + if dtype is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dtype, + type=SemanticMetaType(kind="dtype"), + ) + mask_type = _MASK_TYPE_SYMBOLS.get(expr.name) + if mask_type is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=mask_type, + type=SemanticMetaType(kind="mask_type"), + ) + if expr.name == "align": + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=align, + type=SemanticMetaType(kind="align_type"), + ) + if expr.name == "PartitionTensorView": + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=expr.name, + type=SemanticMetaType(kind="partition_tensor_view_type"), + ) + if expr.namespace in {"PAT", "pto.PAT", "pto.MaskPattern"}: + pattern = _PATTERN_SYMBOLS.get(expr.name) + if pattern is None and expr.name.startswith("PAT_"): + canonical = expr.name[len("PAT_") :] + pattern = _PATTERN_SYMBOLS.get(canonical) + if pattern is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pattern, + type=SemanticMetaType(kind="mask_pattern"), + ) + if expr.namespace in {"PIPE", "pto.PIPE", "Pipe", "pto.Pipe"}: + pipe = _PIPE_SYMBOLS.get(expr.name) + if pipe is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pipe, + type=SemanticMetaType(kind="pipe"), + ) + if expr.namespace in {"EVENT", "pto.EVENT", "Event", "pto.Event"}: + event = _EVENT_SYMBOLS.get(expr.name) + if event is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=event, + type=SemanticMetaType(kind="event"), + ) + if expr.namespace in {"BarrierType", "pto.BarrierType"}: + barrier_type = _BARRIER_TYPE_SYMBOLS.get(expr.name) + if barrier_type is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=barrier_type, + type=SemanticMetaType(kind="barrier_type"), + ) + if expr.namespace in {"MemorySpace", "pto.MemorySpace"}: + memory_space = _MEMORY_SPACE_SYMBOLS.get(expr.name) + if memory_space is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=memory_space, + type=SemanticMetaType(kind="memory_space"), + ) + if expr.namespace in {"PadMode", "pto.PadMode"}: + pad_mode = _PAD_MODE_SYMBOLS.get(expr.name) + if pad_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pad_mode, + type=SemanticMetaType(kind="pad_mode"), + ) + if expr.namespace in {"BLayout", "pto.BLayout"}: + b_layout = _B_LAYOUT_SYMBOLS.get(expr.name) + if b_layout is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=b_layout, + type=SemanticMetaType(kind="b_layout"), + ) + if expr.namespace in {"SLayout", "pto.SLayout"}: + s_layout = _S_LAYOUT_SYMBOLS.get(expr.name) + if s_layout is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=s_layout, + type=SemanticMetaType(kind="s_layout"), + ) + if expr.namespace in {"PadValue", "pto.PadValue"}: + pad_value = _PAD_VALUE_SYMBOLS.get(expr.name) + if pad_value is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=pad_value, + type=SemanticPadValueType(), + ) + if expr.namespace in {"PredicateDist", "pto.PredicateDist"}: + predicate_dist = _PREDICATE_DIST_SYMBOLS.get(expr.name) + if predicate_dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=predicate_dist, + type=SemanticMetaType(kind="predicate_dist"), + ) + if expr.namespace in {"VLoadDist", "pto.VLoadDist"}: + vload_dist = _VLOAD_DIST_SYMBOLS.get(expr.name) + if vload_dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=vload_dist, + type=SemanticMetaType(kind="vload_dist"), + ) + if expr.namespace in {"VStoreDist", "pto.VStoreDist"}: + vstore_dist = _VSTORE_DIST_SYMBOLS.get(expr.name) + if vstore_dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=vstore_dist, + type=SemanticMetaType(kind="vstore_dist"), + ) + if expr.namespace in {"PredicatePart", "pto.PredicatePart"}: + predicate_part = _PREDICATE_PART_SYMBOLS.get(expr.name) + if predicate_part is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=predicate_part, + type=SemanticMetaType(kind="predicate_part"), + ) + if expr.namespace in {"CmpMode", "pto.CmpMode"}: + cmp_mode = _CMP_MODE_SYMBOLS.get(expr.name) + if cmp_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=cmp_mode, + type=SemanticMetaType(kind="cmp_mode"), + ) + if expr.namespace in {"DeinterleaveDist", "pto.DeinterleaveDist"}: + dist = _DEINTERLEAVE_DIST_SYMBOLS.get(expr.name) + if dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dist, + type=SemanticMetaType(kind="deinterleave_dist"), + ) + if expr.namespace in {"InterleaveDist", "pto.InterleaveDist"}: + dist = _INTERLEAVE_DIST_SYMBOLS.get(expr.name) + if dist is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=dist, + type=SemanticMetaType(kind="interleave_dist"), + ) + if expr.namespace in {"PositionMode", "pto.PositionMode"}: + position_mode = _POSITION_MODE_SYMBOLS.get(expr.name) + if position_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=position_mode, + type=SemanticMetaType(kind="position_mode"), + ) + if expr.namespace in {"OrderMode", "pto.OrderMode"}: + order_mode = _ORDER_MODE_SYMBOLS.get(expr.name) + if order_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=order_mode, + type=SemanticMetaType(kind="order_mode"), + ) + if expr.namespace in {"VcvtRoundMode", "pto.VcvtRoundMode"}: + round_mode = _VCVT_ROUND_MODE_SYMBOLS.get(expr.name) + if round_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=round_mode, + type=SemanticMetaType(kind="vcvt_round_mode"), + ) + if expr.namespace in {"VcvtSatMode", "pto.VcvtSatMode"}: + sat_mode = _VCVT_SAT_MODE_SYMBOLS.get(expr.name) + if sat_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=sat_mode, + type=SemanticMetaType(kind="vcvt_sat_mode"), + ) + if expr.namespace in {"VcvtPartMode", "pto.VcvtPartMode"}: + part_mode = _VCVT_PART_MODE_SYMBOLS.get(expr.name) + if part_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=part_mode, + type=SemanticMetaType(kind="vcvt_part_mode"), + ) + if expr.namespace in {"PostUpdateMode", "pto.PostUpdateMode"}: + post_update_mode = _POST_UPDATE_MODE_SYMBOLS.get(expr.name) + if post_update_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=post_update_mode, + type=SemanticMetaType(kind="post_update_mode"), + ) + if expr.namespace in {"FractalMode", "pto.FractalMode"}: + fractal_mode = _FRACTAL_MODE_SYMBOLS.get(expr.name) + if fractal_mode is not None: + return SemanticSymbolExpr( + namespace=expr.namespace, + name=expr.name, + value=fractal_mode, + type=SemanticMetaType(kind="cube_mode"), + ) + raise TypeError( + f"symbol `{expr.namespace}.{expr.name}` is not supported in TileLang DSL v1" + ) + + def _attribute_type(self, base: SemanticExpr, attr: str) -> SemanticType: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)) and attr == "shape": + return SemanticShapeType(rank=base_type.rank) + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)) and attr == "strides": + return SemanticShapeType(rank=base_type.rank) + if isinstance(base_type, SemanticTileType) and attr == "shape": + return SemanticShapeType(rank=base_type.rank) + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)) and attr == "valid_shape": + return SemanticShapeType(rank=base_type.rank) + raise TypeError(f"unsupported attribute access '{attr}' in TileLang DSL v1") + + def _element_type_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): + return SemanticSymbolExpr( + namespace="pto", + name=base_type.element_dtype.name, + value=base_type.element_dtype, + type=SemanticMetaType(kind="dtype"), + ) + raise TypeError("unsupported attribute access 'element_type' in TileLang DSL v1") + + def _rank_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): + return SemanticLiteralExpr(value=base_type.rank, type=SemanticIndexType()) + raise TypeError("unsupported attribute access 'rank' in TileLang DSL v1") + + def _memory_space_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return SemanticSymbolExpr( + namespace="pto", + name=MemorySpace.GM.name, + value=MemorySpace.GM, + type=SemanticMetaType(kind="memory_space"), + ) + if isinstance(base_type, SemanticTileType): + memory_space = MemorySpace.UB if base_type.memory_space is None else MemorySpace(base_type.memory_space) + return SemanticSymbolExpr( + namespace="pto", + name=memory_space.name, + value=memory_space, + type=SemanticMetaType(kind="memory_space"), + ) + raise TypeError("unsupported attribute access 'memory_space' in TileLang DSL v1") + + def _tile_config_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, SemanticTileType): + return SemanticLiteralExpr( + value=base_type.config or TileConfig(), + type=SemanticTileConfigType(element_dtype=base_type.element_dtype), + ) + raise TypeError("unsupported attribute access 'config' in TileLang DSL v1") + + def _tile_pad_value_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if not isinstance(base_type, SemanticTileType): + raise TypeError("unsupported attribute access 'pad_value' in TileLang DSL v1") + config = base_type.config or TileConfig() + return SemanticSymbolExpr( + namespace="pto", + name=config.pad_value.name, + value=config.pad_value, + type=SemanticPadValueType(element_dtype=base_type.element_dtype), + ) + + def _pad_value_eval_expr( + self, + base: SemanticExpr, + dtype_expr: SemanticExpr | None = None, + ) -> SemanticExpr: + if not isinstance(base.type, SemanticPadValueType): + raise TypeError("`eval()` expects a PadValue descriptor in TileLang DSL v1") + element_dtype = base.type.element_dtype + if dtype_expr is not None: + explicit_dtype = self._try_static_value(dtype_expr) + if not isinstance(explicit_dtype, ScalarType): + raise TypeError("PadValue.eval(dtype) expects a TileLang scalar dtype symbol in TileLang DSL v1") + element_dtype = explicit_dtype + if element_dtype is None: + raise TypeError( + "PadValue.eval() requires either a Tile-bound pad descriptor or an explicit dtype argument " + "in TileLang DSL v1" + ) + pad_value = self._try_static_value(base) + if not isinstance(pad_value, PadValue): + raise TypeError("PadValue.eval() expects a statically known PadValue enum in TileLang DSL v1") + pad_scalar = pad_value.eval(element_dtype) + if pad_scalar is None: + raise TypeError( + "PadValue.NULL.eval() is invalid in TileLang DSL v1; " + "guard it with `pto.constexpr(tile.pad_value != pto.PadValue.NULL)` before calling `.eval()`" + ) + return SemanticLiteralExpr( + value=pad_scalar, + type=SemanticScalarType(dtype=element_dtype), + ) + + def _analyze_eval_method( + self, + base: SemanticExpr, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) > 1: + raise TypeError("`eval()` accepts at most one positional dtype argument in TileLang DSL v1") + return self._pad_value_eval_expr(base, args[0] if args else None) + + def _tile_config_attr_expr(self, base: SemanticExpr, attr: str) -> SemanticExpr: + config = self._try_static_value(base) + if not isinstance(config, TileConfig): + raise TypeError("Tile config metadata must be statically known in TileLang DSL v1") + if attr == "b_layout": + return SemanticSymbolExpr( + namespace="pto", + name=config.b_layout.name, + value=config.b_layout, + type=SemanticMetaType(kind="b_layout"), + ) + if attr == "s_layout": + return SemanticSymbolExpr( + namespace="pto", + name=config.s_layout.name, + value=config.s_layout, + type=SemanticMetaType(kind="s_layout"), + ) + if attr == "s_fractal_size": + return SemanticLiteralExpr( + value=config.s_fractal_size, + type=SemanticScalarType(dtype=i32), + ) + if attr == "pad_value": + if not isinstance(base.type, SemanticTileConfigType): + raise TypeError( + "TileConfig.pad_value expects a TileConfig value in TileLang DSL v1" + ) + return SemanticSymbolExpr( + namespace="pto", + name=config.pad_value.name, + value=config.pad_value, + type=SemanticPadValueType(element_dtype=base.type.element_dtype), + ) + raise TypeError(f"unsupported TileConfig attribute access '{attr}' in TileLang DSL v1") + + def _analyze_as_ptr_method(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return SemanticCallExpr( + namespace="pto", + name="tensor_view_as_ptr", + args=(base,), + type=SemanticPtrType( + element_dtype=base_type.element_dtype, + memory_space="gm", + ), + ) + if isinstance(base_type, SemanticTileType): + return SemanticCallExpr( + namespace="pto", + name="tile_as_ptr", + args=(base,), + type=SemanticPtrType( + element_dtype=base_type.element_dtype, + memory_space=base_type.memory_space or "ub", + ), + ) + raise TypeError("`as_ptr()` expects a TensorView/PartitionTensorView or Tile value in TileLang DSL v1") + + def _analyze_astype_method(self, base: SemanticExpr, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 1: + raise TypeError("`astype()` expects exactly 1 positional argument (target dtype) in TileLang DSL v1") + if isinstance(base.type, SemanticVRegType): + target_dtype = self._require_dtype_symbol(args[0], "astype target dtype") + return SemanticCallExpr( + namespace="pto", + name="vbitcast", + args=(base, args[0]), + type=self._vreg_type_for_dtype(target_dtype), + ) + if isinstance(base.type, SemanticMaskType): + target_mask_type = self._require_mask_type_expr(args[0], "astype target dtype") + return SemanticCallExpr( + namespace="pto", + name="pbitcast", + args=(base, args[0]), + type=SemanticMaskType(granularity=target_mask_type.granularity), + ) + raise TypeError("`astype()` expects a vector register or mask value in TileLang DSL v1") + + def _valid_shape_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if not isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType, SemanticTileType)): + raise TypeError("unsupported attribute access 'valid_shape' in TileLang DSL v1") + shape_access = SemanticAttributeAccess( + base=base, + attr="valid_shape", + type=SemanticShapeType(rank=base_type.rank), + ) + elements = [] + for axis in range(base_type.rank): + if ( + isinstance(base, SemanticBindingRef) + and isinstance(base.type, SemanticTileType) + and base.type.valid_shape is not None + and base.type.valid_shape[axis] is None + ): + self._ensure_tile_valid_shape_parameter(base.binding, axis) + elements.append( + SemanticSubscriptAccess( + base=shape_access, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + ) + return SemanticTupleExpr( + elements=tuple(elements), + type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in elements)), + ) + + def _strides_expr(self, base: SemanticExpr) -> SemanticExpr: + base_type = base.type + if not isinstance(base_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + raise TypeError("unsupported attribute access 'strides' in TileLang DSL v1") + stride_access = SemanticAttributeAccess( + base=base, + attr="strides", + type=SemanticShapeType(rank=base_type.rank), + ) + elements = [] + for axis in range(base_type.rank): + elements.append( + SemanticSubscriptAccess( + base=stride_access, + index=SemanticLiteralExpr(value=axis, type=SemanticIndexType()), + type=SemanticIndexType(), + ) + ) + return SemanticTupleExpr( + elements=tuple(elements), + type=SemanticTupleType(elements=tuple(SemanticIndexType() for _ in elements)), + ) + + def _subscript_type(self, base: SemanticExpr, index: SemanticExpr) -> SemanticType: + if isinstance(base.type, SemanticShapeType): + if not isinstance(index.type, SemanticIndexType): + raise TypeError("shape subscript index must be an index value in TileLang DSL v1") + if not isinstance(index, SemanticLiteralExpr) or not isinstance(index.value, int): + raise TypeError( + "shape/stride/valid_shape subscript index must be an integer literal in TileLang DSL v1" + ) + if index.value < 0 or index.value >= base.type.rank: + raise TypeError( + f"shape subscript index {index.value} is out of bounds for rank {base.type.rank}" + ) + return SemanticIndexType() + if isinstance(base.type, SemanticTupleType): + if not isinstance(index.type, SemanticIndexType): + raise TypeError("tuple subscript index must be an index value in TileLang DSL v1") + if not isinstance(base, SemanticTupleExpr): + raise TypeError( + "tuple subscripting currently requires a shape-like tuple expression in TileLang DSL v1" + ) + if not base.type.elements: + raise TypeError("cannot subscript an empty tuple in TileLang DSL v1") + if not isinstance(index, SemanticLiteralExpr) or not isinstance(index.value, int): + raise TypeError("tuple subscript index must be an integer literal in TileLang DSL v1") + + if index.value < 0 or index.value >= len(base.type.elements): + raise TypeError( + f"tuple subscript index {index.value} is out of bounds for tuple length {len(base.type.elements)}" + ) + return base.type.elements[index.value] + if isinstance(base.type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + if not isinstance(index, SemanticTupleExpr): + raise TypeError("TensorView slicing expects a tuple of slices in TileLang DSL v1") + return self._tensor_slice_type(base.type, index) + raise TypeError("unsupported subscript base in TileLang DSL v1") + + def _analyze_tile_vector_access( + self, + expr: FrontendExprNode, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> tuple[SemanticExpr, tuple[SemanticExpr, ...]]: + if not isinstance(expr, FrontendSubscriptExpr): + raise TypeError( + f"{context} expects Tile element-indexing syntax in TileLang DSL v1" + ) + base = self._analyze_expr(expr.base, env, allow_outer_lookup=allow_outer_lookup) + tile = self._require_tile_expr(base, context) + indices = self._tile_vector_indices( + expr.index, + tile.type, + env, + allow_outer_lookup=allow_outer_lookup, + context=context, + ) + return base, indices + + def _tile_vector_indices( + self, + index_expr: FrontendExprNode, + tile_type: SemanticTileType, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + ) -> tuple[SemanticExpr, ...]: + if tile_type.rank == 1: + if not isinstance(index_expr, FrontendSliceExpr): + raise TypeError(f"{context} expects Tile[start:] syntax for rank-1 Tile values") + if index_expr.stop is not None: + raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") + if index_expr.step is not None: + raise TypeError(f"{context} does not support stepped Tile vector slices in TileLang DSL advanced mode") + if index_expr.start is None: + return (SemanticLiteralExpr(value=0, type=SemanticIndexType()),) + start = self._analyze_expr(index_expr.start, env, allow_outer_lookup=allow_outer_lookup) + start = self._require_index_typed_expr(start) + return (start,) + + if tile_type.rank != 2 or tile_type.shape is None: + raise TypeError(f"{context} currently only supports statically specialized rank-1 or rank-2 Tiles") + if not isinstance(index_expr, FrontendTupleExpr) or len(index_expr.elements) != 2: + raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") + + row_expr, col_expr = index_expr.elements + if not isinstance(col_expr, FrontendSliceExpr): + raise TypeError(f"{context} expects Tile[row, col:] syntax for rank-2 Tile values") + if col_expr.stop is not None: + raise TypeError(f"{context} does not support explicit slice stop in TileLang DSL advanced mode") + if col_expr.step is not None: + raise TypeError(f"{context} does not support stepped Tile vector slices in TileLang DSL advanced mode") + + row = self._analyze_expr(row_expr, env, allow_outer_lookup=allow_outer_lookup) + row = self._require_index_typed_expr(row) + if col_expr.start is None: + col = SemanticLiteralExpr(value=0, type=SemanticIndexType()) + else: + col = self._analyze_expr(col_expr.start, env, allow_outer_lookup=allow_outer_lookup) + col = self._require_index_typed_expr(col) + return (row, col) + + def _tensor_slice_type( + self, + tensor_type: SemanticTensorViewType | SemanticPartitionTensorViewType, + index: SemanticTupleExpr, + ) -> SemanticTensorSliceType: + if not 1 <= len(index.elements) <= tensor_type.rank: + raise TypeError( + f"TensorView slice rank {len(index.elements)} must be between 1 and " + f"{tensor_type.rank} in TileLang DSL v1" + ) + axis_offset = tensor_type.rank - len(index.elements) + extents = [] + for axis, element in enumerate(index.elements): + if not isinstance(element, SemanticSliceExpr): + raise TypeError( + f"TensorView slicing axis {axis} must use a Python slice in TileLang DSL v1" + ) + self._require_optional_index_typed_expr(element.start) + self._require_optional_index_typed_expr(element.stop) + self._require_optional_index_typed_expr(element.step) + + if element.stop is None: + raise TypeError("TensorView slicing requires explicit stop bounds in TileLang DSL v1") + extents.append(self._normalized_tensor_slice_extent(element)) + return SemanticTensorSliceType( + element_dtype=tensor_type.element_dtype, + rank=len(index.elements), + extents=tuple(extents), + physical_axes=tuple(range(axis_offset, tensor_type.rank)), + ) + + def _normalize_tensor_slice( + self, + index: SemanticExpr, + rank: int, + ) -> tuple[SemanticTensorSliceAxis, ...]: + if not isinstance(index, SemanticTupleExpr): + raise TypeError("TensorView slicing expects a tuple index in TileLang DSL v1") + if not 1 <= len(index.elements) <= rank: + raise TypeError( + f"TensorView slicing expects between 1 and {rank} slice elements in TileLang DSL v1" + ) + slices = [] + for element in index.elements: + if not isinstance(element, SemanticSliceExpr): + raise TypeError("TensorView slicing only supports slice syntax in TileLang DSL v1") + if element.stop is None: + raise TypeError("TensorView slicing requires explicit stop bounds in TileLang DSL v1") + start = self._normalize_optional_index_expr(element.start, default=0) + stop = element.stop + step = self._normalize_optional_index_expr(element.step, default=1) + slices.append( + SemanticTensorSliceAxis( + start=start, + stop=stop, + step=step, + extent=self._normalized_tensor_slice_extent(element), + ) + ) + return tuple(slices) + + def _binary_type( + self, + lhs: SemanticExpr, + rhs: SemanticExpr, + op: str, + ) -> SemanticType: + mixed_index_scalar_type = self._mixed_index_integer_scalar_type(lhs.type, rhs.type) + if op in {"add", "sub", "mul", "mod", "floordiv", "bitand", "bitor", "bitxor", "lshift", "rshift"}: + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + if op in {"add", "sub", "mul", "mod", "floordiv"}: + return SemanticIndexType() + if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + dtype = lhs.type.dtype + if op in {"add", "sub", "mul"} and (is_integer_dtype(dtype) or is_float_dtype(dtype)): + return SemanticScalarType(dtype=dtype) + if op in {"mod", "floordiv"} and is_integer_dtype(dtype): + return SemanticScalarType(dtype=dtype) + if op in {"bitand", "bitor", "bitxor", "lshift", "rshift"} and is_integer_dtype(dtype): + return SemanticScalarType(dtype=dtype) + if mixed_index_scalar_type is not None and op in {"add", "sub", "mul", "mod", "floordiv"}: + return mixed_index_scalar_type + raise TypeError( + "binary expressions currently require matching index operands, " + "matching scalar operands (add/sub/mul for integer/float; " + "mod/floordiv/bitwise/shift for integer), or index operands " + "mixed with integer scalars for add/sub/mul/mod/floordiv" + ) + if op in {"eq", "ne"}: + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + return SemanticScalarType(dtype=i1) + if mixed_index_scalar_type is not None: + return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticPadValueType) and isinstance(rhs.type, SemanticPadValueType): + return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticMetaType) and lhs.type == rhs.type: + return SemanticScalarType(dtype=i1) + raise TypeError( + "comparison expressions currently require matching scalar/meta types, " + "index-typed operands, or index operands mixed with integer scalars" + ) + if op in {"gt", "lt", "ge", "le"}: + if isinstance(lhs.type, SemanticIndexType) and isinstance(rhs.type, SemanticIndexType): + return SemanticScalarType(dtype=i1) + if isinstance(lhs.type, SemanticScalarType) and lhs.type == rhs.type: + return SemanticScalarType(dtype=i1) + if mixed_index_scalar_type is not None: + return SemanticScalarType(dtype=i1) + raise TypeError( + "ordered comparison expressions currently require matching scalar types, " + "index-typed operands, or index operands mixed with integer scalars" + ) + if op in {"and", "or"}: + self._require_condition_type(lhs.type) + self._require_condition_type(rhs.type) + return SemanticScalarType(dtype=i1) + raise TypeError(f"unsupported binary operator '{op}' in TileLang DSL v1") + + def _retarget_literals_for_binary_op( + self, + lhs: SemanticExpr, + rhs: SemanticExpr, + op: str, + ) -> tuple[SemanticExpr, SemanticExpr]: + if isinstance(lhs.type, SemanticScalarType): + rhs = self._retarget_literal_to_scalar_type_for_binary_op(rhs, lhs.type.dtype, op) + if isinstance(rhs.type, SemanticScalarType): + lhs = self._retarget_literal_to_scalar_type_for_binary_op(lhs, rhs.type.dtype, op) + return lhs, rhs + + def _retarget_literal_to_scalar_type_for_binary_op( + self, + expr: SemanticExpr, + target_dtype: ScalarType, + op: str, + ) -> SemanticExpr: + if not isinstance(expr, SemanticLiteralExpr): + return expr + if not self._binary_op_supports_scalar_dtype(op, target_dtype): + return expr + if is_integer_dtype(target_dtype): + if not isinstance(expr.type, SemanticIndexType): + return expr + if not isinstance(expr.value, int) or isinstance(expr.value, bool): + return expr + checked = self._check_integer_literal_range(expr.value, target_dtype, f"{target_dtype!r} literal") + retargeted = SemanticLiteralExpr( + value=checked, + type=SemanticScalarType(dtype=target_dtype), + ) + source_location = self._expr_source_location(expr) + if source_location is not None: + object.__setattr__(retargeted, "source_location", source_location) + return retargeted + if is_float_dtype(target_dtype): + if isinstance(expr.value, bool) or not isinstance(expr.value, (int, float)): + return expr + if isinstance(expr.type, SemanticScalarType) and not is_float_dtype(expr.type.dtype): + return expr + if not isinstance(expr.type, (SemanticIndexType, SemanticScalarType)): + return expr + retargeted = SemanticLiteralExpr( + value=float(expr.value), + type=SemanticScalarType(dtype=target_dtype), + ) + source_location = self._expr_source_location(expr) + if source_location is not None: + object.__setattr__(retargeted, "source_location", source_location) + return retargeted + return expr + + def _binary_op_supports_scalar_dtype(self, op: str, dtype: ScalarType) -> bool: + if is_integer_dtype(dtype): + if op in {"add", "sub", "mul", "eq", "ne", "gt", "lt", "ge", "le"}: + return True + if op in {"mod", "floordiv", "bitand", "bitor", "bitxor", "lshift", "rshift"}: + return True + return False + if is_float_dtype(dtype): + return op in {"add", "sub", "mul", "eq", "ne", "gt", "lt", "ge", "le"} + return False + + def _mixed_index_integer_scalar_type( + self, + lhs_type: SemanticType, + rhs_type: SemanticType, + ) -> SemanticScalarType | None: + scalar_type: SemanticScalarType | None = None + if isinstance(lhs_type, SemanticIndexType) and isinstance(rhs_type, SemanticScalarType): + scalar_type = rhs_type + elif isinstance(rhs_type, SemanticIndexType) and isinstance(lhs_type, SemanticScalarType): + scalar_type = lhs_type + if scalar_type is None or not is_integer_dtype(scalar_type.dtype): + return None + if integer_bitwidth(scalar_type.dtype) not in {8, 16, 32, 64}: + return None + return scalar_type + + def _analyze_call_expr( + self, + namespace: str | None, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if namespace is None and name == "range": + return SemanticCallExpr(namespace=namespace, name=name, args=args, type=None) + if namespace is None: + if name in self._inline_proc_nodes: + return self._analyze_inline_proc_call_expr(name, args) + if name in self._internal_inline_proc_nodes and self._is_internal_inline_proc_context(): + return self._analyze_internal_inline_proc_call_expr(name, args) + raise TypeError( + f"call surface `{name}` is not supported in TileLang DSL v1" + ) + if namespace != "pto": + raise TypeError( + f"call surface `{namespace + '.' if namespace else ''}{name}` is not supported in TileLang DSL v1 yet" + ) + if name in DEFERRED_PTO_SURFACES: + raise TypeError(deferred_surface_message(name)) + if name in _DTYPE_SYMBOLS: + return self._analyze_scalar_constructor(name, args) + if name == "Tile": + raise TypeError( + "pto.Tile(...) requires dedicated keyword-aware semantic handling in TileLang DSL v1" + ) + if name == "ptr": + return self._analyze_ptr_type(args) + if name == "vreg": + return self._analyze_vreg_type(args) + if name == "vector": + return self._analyze_vector_type(args) + if name == "castptr": + return self._analyze_castptr(args) + if name == "addptr": + return self._analyze_addptr(args) + if name == "bytewidth": + return self._analyze_bytewidth(args) + if name in {"get_lanes", "elements_per_vreg"}: + return self._analyze_get_lanes(args, call_name=name) + if name == "get_op_attr": + return self._analyze_get_op_attr(args) + if name == "constexpr": + raise TypeError( + "pto.constexpr(...) is only supported as an if-condition wrapper in TileLang DSL v1" + ) + if name == "make_mask": + return self._analyze_make_mask(args) + if name in { + "get_block_idx", + "get_subblock_idx", + "get_block_num", + "get_subblock_num", + }: + return self._analyze_runtime_block_query(name, args) + if name == "init_align": + return self._analyze_init_align(args) + if name == "vlds": + return self._analyze_vlds(args) + if name == "vldas": + return self._analyze_vldas(args) + if name == "vldus": + return self._analyze_vldus(args) + if name == "vldsx2": + return self._analyze_vldsx2(args) + if name in {"pset_b8", "pset_b16", "pset_b32", "pge_b8", "pge_b16", "pge_b32"}: + return self._analyze_predicate_pattern_op(name, args) + if name in {"plt_b8", "plt_b16", "plt_b32"}: + return self._analyze_predicate_tail_op(name, args) + if name in {"plds", "pld", "pldi"}: + return self._analyze_predicate_load_op(name, args) + if name == "pstu": + return self._analyze_pstu(args) + if name == "vstus": + return self._analyze_vstus(args) + if name == "vstur": + return self._analyze_vstur(args) + if name == "load_scalar": + return self._analyze_load_scalar(args) + if name in {"ppack", "punpack"}: + return self._analyze_mask_part_op(name, args) + if name in {"pnot", "psel", "pand", "por", "pxor"}: + return self._analyze_mask_logic_op(name, args) + if name in {"pdintlv_b8", "pdintlv_b16", "pdintlv_b32", "pintlv_b8", "pintlv_b16", "pintlv_b32"}: + return self._analyze_predicate_reorder_op(name, args) + if name in {"vcmp", "vcmps"}: + return self._analyze_compare_op(name, args) + if name in {"vsel", "vselr", "vselrv2"}: + return self._analyze_select_op(name, args) + if name in {"vaddc", "vsubc", "vaddcs", "vsubcs"}: + return self._analyze_carry_op(name, args) + if name in {"vintlv", "vdintlv", "vintlvv2", "vdintlvv2"}: + return self._analyze_rearrangement_op(name, args) + if name == "vpack": + return self._analyze_vpack_op(args) + if name == "vcvt": + return self._analyze_vcvt(args) + if name == "vbitcast": + return self._analyze_vbitcast(args) + if name == "pbitcast": + return self._analyze_pbitcast(args) + if name == "vtrc": + return self._analyze_vtrc(args) + if name == "vbitsort": + return self._analyze_vbitsort(args) + if name == "vmrgsort4": + return self._analyze_vmrgsort4(args) + if name == "get_vms4_sr": + return self._analyze_get_vms4_sr(args) + if name in _BROADCAST_VECTOR_OPS: + return self._analyze_broadcast_vector_op(name, args) + if name in _MULTI_RESULT_VECTOR_OPS: + return self._analyze_multi_result_vector_op(name, args) + if name in _VEXPDIF_OP_ALIASES: + return self._analyze_vexpdif_op(args) + if name in _UNARY_VECTOR_OPS: + return self._analyze_unary_vector_op(name, args) + if name in _BINARY_VECTOR_OPS: + return self._analyze_binary_vector_op(name, args) + if name in _VECTOR_SCALAR_OPS: + return self._analyze_vector_scalar_op(name, args) + if name in _VECTOR_IMMEDIATE_OPS: + return self._analyze_vector_immediate_op(name, args) + if name in _TERNARY_VECTOR_OPS: + return self._analyze_ternary_vector_op(name, args) + raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + + def _analyze_cube_frontend_call_expr( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + keywords = self._analyze_keyword_args( + expr.keywords, + env, + allow_outer_lookup=allow_outer_lookup, + context=f"pto.{expr.name}", + ) + + if expr.name in _CUBE_MATMUL_OPS: + return self._analyze_cube_mad_like_op(expr.name, args, keywords) + if expr.name in {"cube_load", "cube_store"}: + return self._analyze_cube_load_store(expr.name, args, keywords) + if expr.name == "cube_load_frac": + return self._analyze_cube_load_frac(args, keywords) + if expr.name == "bias_load": + return self._analyze_cube_bias_load(args, keywords) + if expr.name in {"left_load", "right_load", "left_load_mx", "right_load_mx"}: + return self._analyze_cube_stage_load(expr.name, args, keywords) + if expr.name in {"acc_store", "acc_store_gm", "acc_store_ub"}: + return self._analyze_cube_acc_store(expr.name, args, keywords) + raise TypeError(f"call surface `pto.{expr.name}` is not supported in TileLang DSL v1 yet") + + def _analyze_keyword_args( + self, + keywords: tuple[tuple[str, FrontendExprNode], ...], + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + context: str, + allowed_keywords: set[str] | None = None, + ) -> dict[str, SemanticExpr]: + analyzed: dict[str, SemanticExpr] = {} + seen: set[str] = set() + for keyword_name, keyword_value in keywords: + if keyword_name in seen: + raise TypeError(f"duplicate keyword `{keyword_name}` for {context} in TileLang DSL v1") + if allowed_keywords is not None and keyword_name not in allowed_keywords: + allowed_text = ", ".join(sorted(allowed_keywords)) + raise TypeError( + f"{context} only accepts keyword(s) {allowed_text} in TileLang DSL v1; " + f"got unsupported keyword `{keyword_name}`" + ) + analyzed[keyword_name] = self._analyze_expr( + keyword_value, + env, + allow_outer_lookup=allow_outer_lookup, + ) + seen.add(keyword_name) + return analyzed + + def _require_semantic_tuple_expr( + self, + expr: SemanticExpr, + context: str, + *, + exact_len: int | None = None, + min_len: int | None = None, + max_len: int | None = None, + ) -> tuple[SemanticExpr, ...]: + if not isinstance(expr, SemanticTupleExpr): + raise TypeError(f"{context} must be a tuple or list literal in TileLang DSL v1") + elements = expr.elements + if exact_len is not None and len(elements) != exact_len: + raise TypeError(f"{context} expects exactly {exact_len} elements in TileLang DSL v1") + if min_len is not None and len(elements) < min_len: + raise TypeError(f"{context} expects at least {min_len} elements in TileLang DSL v1") + if max_len is not None and len(elements) > max_len: + raise TypeError(f"{context} expects at most {max_len} elements in TileLang DSL v1") + return elements + + def _require_cube_pointer_expr( + self, + expr: SemanticExpr, + context: str, + *, + memory_space: str, + ) -> SemanticPtrType: + ptr = self._require_pointer_expr(expr, context, memory_space=memory_space) + return ptr.type + + def _require_matching_cube_pointer_element_dtypes( + self, + lhs: SemanticExpr, + rhs: SemanticExpr, + context: str, + ) -> None: + lhs_dtype = lhs.type.element_dtype + rhs_dtype = rhs.type.element_dtype + if lhs_dtype is None or rhs_dtype is None: + return + if lhs_dtype != rhs_dtype: + raise TypeError(f"{context} requires source/destination pointer element dtypes to match") + + def _require_cube_i64_tuple( + self, + expr: SemanticExpr, + context: str, + *, + exact_len: int | None = None, + min_len: int | None = None, + max_len: int | None = None, + ) -> SemanticTupleExpr: + elements = self._require_semantic_tuple_expr( + expr, + context, + exact_len=exact_len, + min_len=min_len, + max_len=max_len, + ) + for element in elements: + self._require_i64_like_expr(element, context) + return SemanticTupleExpr(elements=elements, type=SemanticTupleType(elements=tuple(element.type for element in elements))) + + def _require_cube_optional_none( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr | None: + if expr is None: + return self._missing_optional_meta_expr() + if isinstance(expr, SemanticLiteralExpr) and expr.value is None: + return self._missing_optional_meta_expr() + raise TypeError(f"{context} must be omitted or `None` in TileLang DSL v1") + + def _cube_keyword_or_default( + self, + keywords: dict[str, SemanticExpr], + name: str, + default: SemanticExpr, + ) -> SemanticExpr: + return keywords.get(name, default) + + def _is_none_literal_expr(self, expr: SemanticExpr | None) -> bool: + return isinstance(expr, SemanticLiteralExpr) and expr.value is None + + def _normalize_cube_mode( + self, + expr: SemanticExpr, + context: str, + allowed_modes: set[str], + ) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "cube_mode" + and isinstance(expr.value, FractalMode) + ): + mode = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "cube_mode" + and isinstance(expr.binding.value, FractalMode) + ): + mode = expr.binding.value.value + else: + raise TypeError(f"{context} must be a FractalMode enum in TileLang DSL v1") + if mode not in allowed_modes: + allowed_text = " or ".join(f'\"{value}\"' for value in sorted(allowed_modes)) + raise TypeError(f"{context} must be {allowed_text} in TileLang DSL v1") + return SemanticLiteralExpr(value=mode, type=SemanticMetaType(kind="string")) + + def _normalize_cube_loop_groups( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return self._missing_optional_meta_expr() + if isinstance(expr, SemanticLiteralExpr) and expr.value is None: + return self._missing_optional_meta_expr() + loops = self._require_semantic_tuple_expr(expr, context) + normalized_loops = [] + for index, loop_expr in enumerate(loops): + loop_context = f"{context}[{index}]" + normalized_loops.append(self._require_cube_i64_tuple(loop_expr, loop_context, exact_len=3)) + return SemanticTupleExpr(elements=tuple(normalized_loops), type=SemanticTupleType(elements=tuple(loop.type for loop in normalized_loops))) + + def _analyze_cube_mad_like_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + expected_argc = 7 if "bias" in name else 6 + if len(args) != expected_argc: + raise TypeError(f"pto.{name} expects exactly {expected_argc} positional arguments in TileLang DSL v1") + lhs = self._require_pointer_expr(args[0], f"pto.{name} lhs", memory_space="left") + rhs = self._require_pointer_expr(args[1], f"pto.{name} rhs", memory_space="right") + dst = self._require_pointer_expr(args[2], f"pto.{name} dst", memory_space="acc") + self._require_matching_cube_pointer_element_dtypes( + lhs, + rhs, + f"pto.{name}", + ) + if "bias" in name: + bias = self._require_pointer_expr(args[3], f"pto.{name} bias", memory_space="bias") + self._require_matching_cube_pointer_element_dtypes( + bias, + dst, + f"pto.{name}", + ) + m_index = 4 if "bias" in name else 3 + self._require_i64_like_expr(args[m_index], f"pto.{name} m") + self._require_i64_like_expr(args[m_index + 1], f"pto.{name} n") + self._require_i64_like_expr(args[m_index + 2], f"pto.{name} k") + allowed_keywords = {"unit_flag_ctrl", "disable_gemv"} + unsupported_keywords = sorted(set(keywords) - allowed_keywords) + if unsupported_keywords: + raise TypeError( + f"pto.{name} only accepts keyword(s) unit_flag_ctrl, disable_gemv in TileLang DSL v1; " + f"got unsupported keyword(s): {', '.join(unsupported_keywords)}" + ) + unit_flag_ctrl = self._require_scalar_or_index_expr( + self._cube_keyword_or_default(keywords, "unit_flag_ctrl", SemanticLiteralExpr(value=0, type=SemanticIndexType())), + f"pto.{name} unit_flag_ctrl", + ) + self._require_i64_like_expr(unit_flag_ctrl, f"pto.{name} unit_flag_ctrl") + disable_gemv_expr = self._cube_keyword_or_default( + keywords, + "disable_gemv", + SemanticLiteralExpr(value=False, type=SemanticScalarType(dtype=i1)), + ) + if not isinstance(disable_gemv_expr.type, SemanticScalarType) or disable_gemv_expr.type.dtype != i1: + raise TypeError(f"pto.{name} disable_gemv must be an i1/bool value in TileLang DSL v1") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args + (unit_flag_ctrl, disable_gemv_expr), + type=None, + ) + + def _analyze_cube_load_store( + self, + name: str, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + src = self._require_pointer_expr( + args[0], + f"pto.{name} source", + memory_space="gm" if name == "cube_load" else "mat", + ) + dst = self._require_pointer_expr( + args[1], + f"pto.{name} destination", + memory_space="mat" if name == "cube_load" else "ub", + ) + self._require_matching_cube_pointer_element_dtypes( + src, + dst, + f"pto.{name}", + ) + self._require_i64_like_expr(args[2], f"pto.{name} len_burst") + allowed_keywords = {"nburst", "loops"} + unsupported_keywords = sorted(set(keywords) - allowed_keywords) + if unsupported_keywords: + raise TypeError( + f"pto.{name} only accepts keyword(s) nburst, loops in TileLang DSL v1; " + f"got unsupported keyword(s): {', '.join(unsupported_keywords)}" + ) + nburst_expr = keywords.get("nburst", SemanticTupleExpr( + elements=( + SemanticLiteralExpr(value=1, type=SemanticIndexType()), + SemanticLiteralExpr(value=0, type=SemanticIndexType()), + SemanticLiteralExpr(value=0, type=SemanticIndexType()), + ), + type=SemanticTupleType(elements=(SemanticIndexType(), SemanticIndexType(), SemanticIndexType())), + )) + if "nburst" in keywords: + nburst_expr = self._require_cube_i64_tuple(keywords["nburst"], f"pto.{name} nburst", exact_len=3) + loops_expr = self._normalize_cube_loop_groups(keywords.get("loops"), f"pto.{name} loops") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(args[0], args[1], args[2], nburst_expr, loops_expr), + type=None, + ) + + def _analyze_cube_load_frac( + self, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError("pto.cube_load_frac expects exactly 3 positional arguments in TileLang DSL v1") + src = self._require_pointer_expr(args[0], "pto.cube_load_frac source", memory_space="gm") + dst = self._require_pointer_expr(args[1], "pto.cube_load_frac destination", memory_space="mat") + mode = self._normalize_cube_mode( + args[2], + "pto.cube_load_frac mode", + {FractalMode.ND2NZ.value, FractalMode.DN2NZ.value}, + ) + self._require_matching_cube_pointer_element_dtypes( + src, + dst, + "pto.cube_load_frac", + ) + allowed_keywords = {"shape", "src_layout", "dst_group", "ctrl"} + unsupported = ", ".join(sorted(set(keywords) - allowed_keywords)) + if unsupported: + raise TypeError( + "pto.cube_load_frac only accepts keyword(s) shape, src_layout, dst_group, ctrl " + f"in TileLang DSL v1; got unsupported keyword(s): {unsupported}" + ) + missing = sorted(allowed_keywords - set(keywords)) + if missing: + raise TypeError( + f"pto.cube_load_frac requires keyword(s) {', '.join(missing)} in TileLang DSL v1" + ) + shape = self._require_cube_i64_tuple(keywords["shape"], "pto.cube_load_frac shape", exact_len=2) + src_layout = self._require_cube_i64_tuple(keywords["src_layout"], "pto.cube_load_frac src_layout", min_len=1, max_len=2) + dst_group = self._require_cube_i64_tuple(keywords["dst_group"], "pto.cube_load_frac dst_group", exact_len=4) + ctrl = self._require_semantic_tuple_expr(keywords["ctrl"], "pto.cube_load_frac ctrl", exact_len=2) + self._require_i64_like_expr(ctrl[0], "pto.cube_load_frac ctrl") + if not ( + isinstance(ctrl[1].type, SemanticScalarType) + and ctrl[1].type.dtype == i1 + ): + raise TypeError("pto.cube_load_frac ctrl smallc0_en must be an i1/bool value in TileLang DSL v1") + ctrl_expr = SemanticTupleExpr( + elements=ctrl, + type=SemanticTupleType(elements=tuple(element.type for element in ctrl)), + ) + return SemanticCallExpr( + namespace="pto", + name="cube_load_frac", + args=(args[0], args[1], mode, shape, src_layout, dst_group, ctrl_expr), + type=None, + ) + + def _analyze_cube_bias_load( + self, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError("pto.bias_load expects exactly 3 positional arguments in TileLang DSL v1") + src = self._require_pointer_expr(args[0], "pto.bias_load source", memory_space="mat") + dst = self._require_pointer_expr(args[1], "pto.bias_load destination", memory_space="bias") + allowed_pairs = { + ("f32", "f32"), + ("i32", "i32"), + ("f16", "f32"), + ("bf16", "f32"), + } + if src.type.element_dtype is not None and dst.type.element_dtype is not None and ( + src.type.element_dtype.name, + dst.type.element_dtype.name, + ) not in allowed_pairs: + raise TypeError( + "pto.bias_load only supports f32->f32, i32->i32, f16->f32, and bf16->f32 in TileLang DSL v1" + ) + self._require_i64_like_expr(args[2], "pto.bias_load len_burst") + allowed_keywords = {"nburst"} + unsupported_keywords = sorted(set(keywords) - allowed_keywords) + if unsupported_keywords: + raise TypeError( + f"pto.bias_load only accepts keyword(s) nburst in TileLang DSL v1; " + f"got unsupported keyword(s): {', '.join(unsupported_keywords)}" + ) + nburst_expr = keywords.get("nburst", SemanticTupleExpr( + elements=( + SemanticLiteralExpr(value=1, type=SemanticIndexType()), + SemanticLiteralExpr(value=0, type=SemanticIndexType()), + SemanticLiteralExpr(value=0, type=SemanticIndexType()), + ), + type=SemanticTupleType(elements=(SemanticIndexType(), SemanticIndexType(), SemanticIndexType())), + )) + if "nburst" in keywords: + nburst_expr = self._require_cube_i64_tuple(keywords["nburst"], "pto.bias_load nburst", exact_len=3) + return SemanticCallExpr( + namespace="pto", + name="bias_load", + args=(args[0], args[1], args[2], nburst_expr), + type=None, + ) + + def _analyze_cube_stage_load( + self, + name: str, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + if keywords: + raise TypeError(f"pto.{name} does not accept keyword arguments in TileLang DSL v1") + if len(args) != 4: + raise TypeError(f"pto.{name} expects exactly 4 positional arguments in TileLang DSL v1") + src = self._require_pointer_expr(args[0], f"pto.{name} source", memory_space="mat") + dst_space = "left" if name.startswith("left") else "right" + dst = self._require_pointer_expr(args[1], f"pto.{name} destination", memory_space=dst_space) + self._require_matching_cube_pointer_element_dtypes( + src, + dst, + f"pto.{name}", + ) + self._require_i64_like_expr(args[2], f"pto.{name} first dimension") + self._require_i64_like_expr(args[3], f"pto.{name} second dimension") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=None) + + def _analyze_cube_acc_store( + self, + name: str, + args: tuple[SemanticExpr, ...], + keywords: dict[str, SemanticExpr], + ) -> SemanticExpr: + if len(args) != 6: + raise TypeError(f"pto.{name} expects exactly 6 positional arguments in TileLang DSL v1") + src = self._require_pointer_expr(args[0], f"pto.{name} source", memory_space="acc") + dst_space = "mat" if name == "acc_store" else "gm" if name == "acc_store_gm" else "ub" + dst = self._require_pointer_expr(args[1], f"pto.{name} destination", memory_space=dst_space) + for index, label in enumerate(("m", "n", "src_stride", "dst_stride"), start=2): + self._require_i64_like_expr(args[index], f"pto.{name} {label}") + + allowed_keywords = { + "mode", + "loop0_src_stride", + "split", + "loop3", + } + if name == "acc_store_gm": + allowed_keywords |= {"sid", "l2_cache_ctrl"} + if name == "acc_store_ub": + allowed_keywords = { + "mode", + "loop0_src_stride", + "channel_split_en", + "loop3", + "dual_dst_mode", + "sub_blockid", + } + unsupported = sorted(set(keywords) - allowed_keywords) + if unsupported: + raise TypeError( + f"pto.{name} only accepts keyword(s) {', '.join(sorted(allowed_keywords))} in TileLang DSL v1; " + f"got unsupported keyword(s): {', '.join(unsupported)}" + ) + + mode = self._normalize_cube_mode( + keywords.get( + "mode", + SemanticSymbolExpr( + namespace="pto.FractalMode", + name="NZ2ND", + value=FractalMode.NZ2ND, + type=SemanticMetaType(kind="cube_mode"), + ), + ), + f"pto.{name} mode", + {FractalMode.NZ2ND.value, FractalMode.NZ2DN.value, FractalMode.NZ2NZ.value}, + ) + mode_text = self._require_string_expr(mode, f"pto.{name} mode") + + loop0_src_stride = keywords.get("loop0_src_stride") + split_key = "channel_split_en" if name == "acc_store_ub" else "split" + split_expr = keywords.get(split_key) + loop3_expr = keywords.get("loop3") + + if self._is_none_literal_expr(loop0_src_stride): + loop0_src_stride = None + if self._is_none_literal_expr(split_expr): + split_expr = None + if self._is_none_literal_expr(loop3_expr): + loop3_expr = None + + if mode_text == "nz2nd": + if loop0_src_stride is not None: + raise TypeError(f"pto.{name} mode \"nz2nd\" does not accept loop0_src_stride in TileLang DSL v1") + if split_expr is not None: + raise TypeError(f"pto.{name} mode \"nz2nd\" does not accept {split_key} in TileLang DSL v1") + elif mode_text == "nz2dn": + if split_expr is not None: + raise TypeError(f"pto.{name} mode \"nz2dn\" does not accept {split_key} in TileLang DSL v1") + elif mode_text == "nz2nz": + if loop0_src_stride is not None: + raise TypeError(f"pto.{name} mode \"nz2nz\" does not accept loop0_src_stride in TileLang DSL v1") + if split_expr is None: + raise TypeError(f"pto.{name} mode \"nz2nz\" requires {split_key} in TileLang DSL v1") + if loop3_expr is not None: + raise TypeError(f"pto.{name} mode \"nz2nz\" does not accept loop3(...) in TileLang DSL v1") + + if loop0_src_stride is not None: + self._require_i64_like_expr(loop0_src_stride, f"pto.{name} loop0_src_stride") + if split_expr is not None: + self._require_i64_like_expr(split_expr, f"pto.{name} {split_key}") + if loop3_expr is not None: + loop3_expr = self._require_cube_i64_tuple(loop3_expr, f"pto.{name} loop3", exact_len=3) + + tail_args: list[SemanticExpr] = [] + if name == "acc_store_gm": + sid_expr = keywords.get("sid", SemanticLiteralExpr(value=0, type=SemanticIndexType())) + l2_cache_ctrl_expr = keywords.get("l2_cache_ctrl", SemanticLiteralExpr(value=0, type=SemanticIndexType())) + self._require_i64_like_expr(sid_expr, f"pto.{name} sid") + self._require_i64_like_expr(l2_cache_ctrl_expr, f"pto.{name} l2_cache_ctrl") + tail_args.extend([sid_expr, l2_cache_ctrl_expr]) + elif name == "acc_store_ub": + dual_dst_mode_expr = keywords.get("dual_dst_mode", SemanticLiteralExpr(value=0, type=SemanticIndexType())) + sub_blockid_expr = keywords.get("sub_blockid", SemanticLiteralExpr(value=0, type=SemanticIndexType())) + self._require_i64_like_expr(dual_dst_mode_expr, f"pto.{name} dual_dst_mode") + self._require_i64_like_expr(sub_blockid_expr, f"pto.{name} sub_blockid") + tail_args.extend([dual_dst_mode_expr, sub_blockid_expr]) + tail_args.append(mode) + if loop0_src_stride is not None: + tail_args.append(loop0_src_stride) + if split_expr is not None: + tail_args.append(split_expr) + if loop3_expr is not None: + tail_args.append(loop3_expr) + + return SemanticCallExpr(namespace="pto", name=name, args=tuple(args) + tuple(tail_args), type=None) + + def _analyze_make_mask(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.make_mask expects exactly 2 positional arguments in TileLang DSL v1") + dtype_expr, value_expr = args + dtype = self._require_dtype_symbol(dtype_expr, "pto.make_mask element type") + if isinstance(value_expr, SemanticSymbolExpr) and value_expr.type.kind == "mask_pattern": + return SemanticCallExpr( + namespace="pto", + name="make_mask", + args=args, + type=SemanticMaskType(granularity=self._mask_granularity_for_dtype(dtype)), + ) + self._require_tail_remaining_expr(value_expr, "pto.make_mask tail remaining") + return SemanticCallExpr( + namespace="pto", + name="make_mask", + args=args, + type=SemanticTupleType( + elements=( + SemanticMaskType(granularity=self._mask_granularity_for_dtype(dtype)), + _I32_TYPE, + ) + ), + ) + + def _analyze_predicate_pattern_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + pattern = args[0] + if not ( + isinstance(pattern, SemanticSymbolExpr) + and isinstance(pattern.type, SemanticMetaType) + and pattern.type.kind == "mask_pattern" + and isinstance(pattern.value, MaskPattern) + ): + raise TypeError(f"pto.{name} pattern must be a MaskPattern symbol such as `pto.PAT.ALL`") + granularity = name.rsplit("_", 1)[-1] + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticMaskType(granularity=granularity), + ) + + def _analyze_predicate_tail_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"pto.{name} expects exactly 1 positional argument in TileLang DSL v1") + self._require_tail_remaining_expr(args[0], f"pto.{name} scalar") + granularity = name.rsplit("_", 1)[-1] + mask_type = SemanticMaskType(granularity=granularity) + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(mask_type, _I32_TYPE)), + ) + + def _literal_expr_from_context_value(self, value: object, context: str) -> SemanticExpr: + if isinstance(value, bool): + return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=i1)) + if isinstance(value, int) and not isinstance(value, bool): + return SemanticLiteralExpr(value=value, type=SemanticIndexType()) + if isinstance(value, float): + return SemanticLiteralExpr(value=value, type=SemanticScalarType(dtype=f32)) + if isinstance(value, str): + return SemanticLiteralExpr(value=value, type=SemanticMetaType(kind="string")) + if isinstance(value, ScalarType): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="dtype"), + ) + if isinstance(value, MemorySpace): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="memory_space"), + ) + if isinstance(value, CmpMode): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="cmp_mode"), + ) + if isinstance(value, PredicatePart): + return SemanticSymbolExpr( + namespace="pto", + name=value.name, + value=value, + type=SemanticMetaType(kind="predicate_part"), + ) + raise TypeError( + f"{context} resolved to unsupported static value {value!r} in TileLang DSL v1" + ) + + def _analyze_get_op_attr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {1, 2}: + raise TypeError( + "pto.get_op_attr expects 1 or 2 positional arguments `(name, default?)` in TileLang DSL v1" + ) + attr_name = self._require_string_expr(args[0], "pto.get_op_attr name") + if attr_name in self._context_attrs: + return self._literal_expr_from_context_value( + self._context_attrs[attr_name], + f"pto.get_op_attr({attr_name!r})", + ) + if len(args) == 2: + return args[1] + raise TypeError( + f"pto.get_op_attr could not resolve attribute {attr_name!r} and no default was provided" + ) + + def _analyze_scalar_constructor( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + return self._analyze_scalar_constructor_for_dtype( + _DTYPE_SYMBOLS[name], + args, + surface_name=f"pto.{name}", + ) + + def _analyze_scalar_constructor_for_dtype( + self, + target_dtype: ScalarType, + args: tuple[SemanticExpr, ...], + *, + surface_name: str, + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError(f"{surface_name} expects exactly 1 positional argument in TileLang DSL v1") + + if ( + target_dtype.name in {"f16", "bf16", "f32"} + and isinstance(args[0], SemanticLiteralExpr) + and isinstance(args[0].type, SemanticMetaType) + and args[0].type.kind == "string" + ): + parsed = self._parse_float_literal_string(args[0].value, target_dtype, f"{surface_name} value") + return SemanticLiteralExpr( + value=parsed, + type=SemanticScalarType(dtype=target_dtype), + ) + if ( + is_integer_dtype(target_dtype) + and isinstance(args[0], SemanticLiteralExpr) + and isinstance(args[0].type, SemanticMetaType) + and args[0].type.kind == "string" + ): + parsed = self._parse_integer_literal_string( + args[0].value, + target_dtype, + f"{surface_name} value", + ) + return SemanticLiteralExpr( + value=parsed, + type=SemanticScalarType(dtype=target_dtype), + ) + + value = self._require_scalar_or_index_expr(args[0], f"{surface_name} value") + + if isinstance(value.type, SemanticScalarType) and value.type.dtype == target_dtype: + return value + + if isinstance(value, SemanticLiteralExpr): + literal_value = value.value + if target_dtype == i1: + if isinstance(literal_value, bool): + return SemanticLiteralExpr(value=literal_value, type=SemanticScalarType(dtype=i1)) + if isinstance(literal_value, int): + return SemanticLiteralExpr(value=bool(literal_value), type=SemanticScalarType(dtype=i1)) + if isinstance(literal_value, float): + return SemanticLiteralExpr(value=bool(literal_value), type=SemanticScalarType(dtype=i1)) + elif is_integer_dtype(target_dtype): + if isinstance(literal_value, bool): + casted = int(literal_value) + elif isinstance(literal_value, (int, float)): + casted = int(literal_value) + else: + casted = None + if casted is not None: + checked = self._check_integer_literal_range( + casted, + target_dtype, + f"{surface_name} value", + ) + return SemanticLiteralExpr(value=checked, type=SemanticScalarType(dtype=target_dtype)) + else: + if isinstance(literal_value, (bool, int, float)): + return SemanticLiteralExpr( + value=float(literal_value), + type=SemanticScalarType(dtype=target_dtype), + ) + + return SemanticCallExpr( + namespace="pto", + name=target_dtype.name, + args=(value,), + type=SemanticScalarType(dtype=target_dtype), + ) + + def _parse_float_literal_string( + self, + literal: str, + target_dtype: ScalarType, + context: str, + ) -> float: + text = literal.strip().lower() + if text in {"inf", "+inf", "infinity", "+infinity"}: + return float("inf") + if text in {"-inf", "-infinity"}: + return float("-inf") + if text in {"nan", "+nan", "-nan"}: + return float("nan") + + if text.startswith("0x"): + try: + bit_pattern = int(text, 16) + except ValueError as exc: + raise TypeError( + f"{context} string literal {literal!r} is not a valid hex bit-pattern" + ) from exc + return self._float_from_bit_pattern(bit_pattern, target_dtype, context=context) + + try: + return float(text) + except ValueError as exc: + raise TypeError( + f"{context} string literal {literal!r} is not a valid float literal" + ) from exc + + def _parse_integer_literal_string( + self, + literal: str, + target_dtype: ScalarType, + context: str, + ) -> int: + text = literal.strip().lower() + bits = integer_bitwidth(target_dtype) + signedness = integer_signedness(target_dtype) + assert bits is not None + signless_or_signed = signedness != "unsigned" + if not text.startswith("0x"): + raise TypeError( + f"{context} string literals must use hex bit-pattern form like \"0xFF\" in TileLang DSL v1" + ) + try: + parsed = int(text, 16) + except ValueError as exc: + raise TypeError( + f"{context} string literal {literal!r} is not a valid hex bit-pattern" + ) from exc + if parsed >= (1 << bits): + raise TypeError( + f"{context} bit-pattern literal {literal!r} exceeds {bits}-bit width for {target_dtype.name}" + ) + if signless_or_signed and parsed >= (1 << (bits - 1)): + parsed -= 1 << bits + return self._check_integer_literal_range(parsed, target_dtype, context) + + def _check_integer_literal_range( + self, + value: int, + target_dtype: ScalarType, + context: str, + ) -> int: + bits = integer_bitwidth(target_dtype) + signedness = integer_signedness(target_dtype) + assert bits is not None + if signedness == "unsigned": + min_value = 0 + max_value = (1 << bits) - 1 + else: + min_value = -(1 << (bits - 1)) + max_value = (1 << (bits - 1)) - 1 + if value < min_value or value > max_value: + raise TypeError( + f"{context} {value} is out of range for {target_dtype.name} in TileLang DSL v1" + ) + return value + + def _float_from_bit_pattern( + self, + bit_pattern: int, + target_dtype: ScalarType, + *, + context: str, + ) -> float: + if target_dtype.name == "f16": + if bit_pattern < 0 or bit_pattern > 0xFFFF: + raise TypeError(f"{context} f16 bit-pattern must be in [0x0, 0xFFFF]") + return float(struct.unpack(">e", struct.pack(">H", bit_pattern))[0]) + if target_dtype.name == "bf16": + if bit_pattern < 0 or bit_pattern > 0xFFFF: + raise TypeError(f"{context} bf16 bit-pattern must be in [0x0, 0xFFFF]") + widened = bit_pattern << 16 + return float(struct.unpack(">f", struct.pack(">I", widened))[0]) + if target_dtype.name == "f32": + if bit_pattern < 0 or bit_pattern > 0xFFFFFFFF: + raise TypeError(f"{context} f32 bit-pattern must be in [0x0, 0xFFFFFFFF]") + return float(struct.unpack(">f", struct.pack(">I", bit_pattern))[0]) + raise TypeError(f"{context} bit-pattern literals are not supported for dtype {target_dtype.name}") + + def _analyze_ptr_type(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.ptr expects exactly 2 positional arguments in TileLang DSL") + dtype = self._require_dtype_symbol(args[0], "pto.ptr element type") + memory_space = self._require_memory_space_symbol(args[1], "pto.ptr memory space") + return SemanticLiteralExpr( + value=PointerType(element_dtype=dtype, memory_space=memory_space), + type=SemanticMetaType(kind="ptr_type"), + ) + + def _analyze_vreg_type(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 1: + raise TypeError("pto.vreg expects exactly 1 positional argument in TileLang DSL v1") + dtype = self._require_dtype_symbol(args[0], "pto.vreg element type") + vreg_type = self._vreg_type_for_dtype(dtype) + return SemanticLiteralExpr( + value=VRegType(element_dtype=dtype, lanes=vreg_type.lanes), + type=SemanticMetaType(kind="vreg_type"), + ) + + def _analyze_vector_type(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vector expects exactly 2 positional arguments in TileLang DSL v1") + dtype = self._require_dtype_symbol(args[0], "pto.vector element type") + shape = self._require_vector_shape_expr(args[1], "pto.vector shape") + return SemanticLiteralExpr( + value=VectorType(element_dtype=dtype, shape=shape), + type=SemanticMetaType(kind="vector_type"), + ) + + def _analyze_castptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.castptr expects exactly 2 positional arguments in TileLang DSL") + value, target = args + target_type = self._require_cast_target_type(target) + if isinstance(target_type, SemanticPtrType): + self._require_castptr_input(value, target_type) + else: + self._require_pointer_expr(value, "pto.castptr input") + return SemanticCallExpr(namespace="pto", name="castptr", args=args, type=target_type) + + def _analyze_addptr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.addptr expects exactly 2 positional arguments in TileLang DSL") + pointer, offset = args + ptr = self._require_pointer_expr(pointer, "pto.addptr pointer") + offset = self._require_index_typed_expr(offset) + return SemanticCallExpr(namespace="pto", name="addptr", args=(ptr, offset), type=ptr.type) + + def _analyze_tile_frontend_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + unexpected_keywords = sorted(set(analyzed_keywords) - _TILE_CONSTRUCTOR_ALLOWED_KEYWORDS) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.Tile only accepts keyword args " + "`valid_shape`, `blayout`, `slayout`, `fractal_size`, `pad_value`, `compact_mode`, and `addr`; " + f"got unsupported keyword(s): {keyword_text}" + ) + if len(expr.args) != 3: + raise TypeError( + "pto.Tile expects exactly 3 positional arguments `(shape, dtype, memory_space)` in TileLang DSL v1" + ) + + shape_expr = self._analyze_expr(expr.args[0], env, allow_outer_lookup=allow_outer_lookup) + dtype_expr = self._analyze_expr(expr.args[1], env, allow_outer_lookup=allow_outer_lookup) + memory_space_expr = self._analyze_expr(expr.args[2], env, allow_outer_lookup=allow_outer_lookup) + return self._analyze_tile_constructor( + shape_expr, + dtype_expr, + memory_space_expr, + valid_shape_expr=analyzed_keywords.get("valid_shape"), + b_layout_expr=analyzed_keywords.get("blayout"), + s_layout_expr=analyzed_keywords.get("slayout"), + fractal_size_expr=analyzed_keywords.get("fractal_size"), + pad_value_expr=analyzed_keywords.get("pad_value"), + compact_mode_expr=analyzed_keywords.get("compact_mode"), + addr_expr=analyzed_keywords.get("addr"), + ) + + def _analyze_tile_constructor( + self, + shape_expr: SemanticExpr, + dtype_expr: SemanticExpr, + memory_space_expr: SemanticExpr, + *, + valid_shape_expr: SemanticExpr | None, + b_layout_expr: SemanticExpr | None, + s_layout_expr: SemanticExpr | None, + fractal_size_expr: SemanticExpr | None, + pad_value_expr: SemanticExpr | None, + compact_mode_expr: SemanticExpr | None, + addr_expr: SemanticExpr | None, + ) -> SemanticExpr: + if compact_mode_expr is not None: + raise TypeError("pto.Tile compact_mode is not supported in TileLang DSL v1 yet") + if addr_expr is not None: + self._require_i64_like_expr(addr_expr, "pto.Tile addr") + + shape = self._require_static_shape_tuple(shape_expr, "pto.Tile shape") + if not shape: + raise TypeError("pto.Tile shape must be non-empty in TileLang DSL v1") + if len(shape) not in {1, 2}: + raise TypeError("pto.Tile only supports rank-1 or rank-2 shapes in TileLang DSL v1") + dtype = self._require_dtype_symbol(dtype_expr, "pto.Tile dtype") + memory_space = self._require_memory_space_symbol(memory_space_expr, "pto.Tile memory_space") + valid_shape = self._normalize_tile_valid_shape_expr(valid_shape_expr, shape, "pto.Tile valid_shape") + config = self._build_tile_constructor_config( + memory_space, + b_layout_expr=b_layout_expr, + s_layout_expr=s_layout_expr, + fractal_size_expr=fractal_size_expr, + pad_value_expr=pad_value_expr, + ) + lowered_args: list[SemanticExpr] = [] + if valid_shape_expr is not None or addr_expr is not None: + lowered_args.append( + valid_shape_expr + if valid_shape_expr is not None + else SemanticLiteralExpr(value=None, type=SemanticMetaType(kind="none")) + ) + if addr_expr is not None: + lowered_args.append(addr_expr) + return SemanticCallExpr( + namespace="pto", + name="alloc_tile", + args=tuple(lowered_args), + type=SemanticTileType( + element_dtype=dtype, + rank=len(shape), + shape=shape, + valid_shape=valid_shape, + memory_space=memory_space.value, + config=config, + ), + ) + + def _require_static_shape_tuple( + self, + expr: SemanticExpr, + context: str, + ) -> tuple[int, ...]: + value = self._try_static_value(expr) + if not isinstance(value, tuple): + raise TypeError(f"{context} must be a statically known tuple/list of integers in TileLang DSL v1") + dims: list[int] = [] + for index, dim in enumerate(value): + if isinstance(dim, bool) or not isinstance(dim, int): + raise TypeError(f"{context}[{index}] must be a positive integer in TileLang DSL v1") + if dim <= 0: + raise TypeError(f"{context}[{index}] must be a positive integer in TileLang DSL v1") + dims.append(dim) + return tuple(dims) + + def _normalize_tile_valid_shape_expr( + self, + expr: SemanticExpr | None, + shape: tuple[int, ...], + context: str, + ) -> tuple[int | None, ...]: + if expr is None: + return shape + value = self._try_static_value(expr) + if not isinstance(value, tuple): + raise TypeError(f"{context} must be a statically known tuple/list in TileLang DSL v1") + if len(value) != len(shape): + raise TypeError(f"{context} rank must match tile shape rank in TileLang DSL v1") + dims: list[int | None] = [] + for index, (dim, bound) in enumerate(zip(value, shape)): + if dim is None: + dims.append(None) + continue + if isinstance(dim, bool) or not isinstance(dim, int): + raise TypeError(f"{context}[{index}] must be an integer or None in TileLang DSL v1") + if dim <= 0: + raise TypeError(f"{context}[{index}] must be positive when provided in TileLang DSL v1") + if dim > bound: + raise TypeError(f"{context}[{index}] must be <= shape[{index}] in TileLang DSL v1") + dims.append(dim) + return tuple(dims) + + def _build_tile_constructor_config( + self, + memory_space: MemorySpace, + *, + b_layout_expr: SemanticExpr | None, + s_layout_expr: SemanticExpr | None, + fractal_size_expr: SemanticExpr | None, + pad_value_expr: SemanticExpr | None, + ) -> TileConfig: + defaults = dict(TileConfig.for_memory_space(memory_space).fields) + if b_layout_expr is not None: + defaults["b_layout"] = self._require_b_layout_symbol(b_layout_expr, "pto.Tile blayout") + if s_layout_expr is not None: + defaults["s_layout"] = self._require_s_layout_symbol(s_layout_expr, "pto.Tile slayout") + if fractal_size_expr is not None: + fractal = self._try_static_value(fractal_size_expr) + if isinstance(fractal, bool) or not isinstance(fractal, int): + raise TypeError("pto.Tile fractal_size must be a static integer in TileLang DSL v1") + defaults["s_fractal_size"] = fractal + if pad_value_expr is not None: + pad_value = self._try_static_value(pad_value_expr) + if not isinstance(pad_value, PadValue): + raise TypeError("pto.Tile pad_value must be a PadValue symbol in TileLang DSL v1") + defaults["pad_value"] = pad_value + return TileConfig(tuple(sorted(defaults.items()))) + + def _require_b_layout_symbol(self, expr: SemanticExpr, context: str) -> BLayout: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "b_layout" + and isinstance(expr.value, BLayout) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "b_layout" + and isinstance(expr.binding.value, BLayout) + ): + return expr.binding.value + raise TypeError(f"{context} must be a BLayout symbol in TileLang DSL v1") + + def _require_s_layout_symbol(self, expr: SemanticExpr, context: str) -> SLayout: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "s_layout" + and isinstance(expr.value, SLayout) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "s_layout" + and isinstance(expr.binding.value, SLayout) + ): + return expr.binding.value + raise TypeError(f"{context} must be an SLayout symbol in TileLang DSL v1") + + def _analyze_get_lanes( + self, + args: tuple[SemanticExpr, ...], + *, + call_name: str = "get_lanes", + ) -> SemanticExpr: + if len(args) != 1: + raise TypeError( + f"pto.{call_name} expects exactly 1 positional argument in TileLang DSL v1" + ) + dtype = self._require_dtype_symbol(args[0], f"pto.{call_name} dtype") + return SemanticLiteralExpr(value=self._vreg_type_for_dtype(dtype).lanes, type=SemanticIndexType()) + + def _analyze_bytewidth(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 1: + raise TypeError("pto.bytewidth expects exactly 1 positional argument in TileLang DSL v1") + dtype = self._require_dtype_symbol(args[0], "pto.bytewidth dtype") + return SemanticLiteralExpr(value=bytewidth(dtype), type=SemanticIndexType()) + + def _analyze_init_align(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if args: + raise TypeError("pto.init_align does not accept positional arguments in TileLang DSL v1") + return SemanticCallExpr(namespace="pto", name="init_align", args=(), type=SemanticAlignType()) + + def _analyze_vlds( + self, + args: tuple[SemanticExpr, ...], + *, + dist: SemanticExpr | None = None, + ) -> SemanticExpr: + if len(args) < 2: + raise TypeError("pto.vlds expects at least 2 positional arguments in TileLang DSL v1") + source, *indices = args + source_type = source.type + if isinstance(source_type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vlds source") + else: + source = self._require_pointer_expr(source, "pto.vlds source", memory_space="ub") + indices = tuple(self._require_index_typed_expr(index) for index in indices) + lowered_args: tuple[SemanticExpr, ...] + if dist is not None: + lowered_args = (source, *indices, dist) + else: + lowered_args = (source, *indices) + return SemanticCallExpr( + namespace="pto", + name="vlds", + args=lowered_args, + type=self._vreg_type_for_dtype(source.type.element_dtype), + ) + + def _analyze_vlds_frontend_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + unexpected_keywords = sorted(set(analyzed_keywords) - {"dist"}) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vlds only accepts keyword attr `dist`; " + f"got unsupported keyword(s): {keyword_text}" + ) + dist = self._normalize_vlds_dist(analyzed_keywords.get("dist"), "pto.vlds dist") + if len(expr.args) == 1 and isinstance(expr.args[0], FrontendSubscriptExpr): + base, indices = self._analyze_tile_vector_access( + expr.args[0], + env, + allow_outer_lookup=allow_outer_lookup, + context="pto.vlds source", + ) + return self._analyze_vlds((base, *indices), dist=dist) + + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + return self._analyze_vlds(args, dist=dist) + + def _analyze_vldas(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {1, 2, 3}: + raise TypeError("pto.vldas expects 1 positional source or Tile[start:]/Tile[row, col:] in TileLang DSL v1") + source, *indices = args + source_type = source.type + if isinstance(source_type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldas source") + indices = tuple(self._require_index_typed_expr(index) for index in indices) + else: + if indices: + raise TypeError("pto.vldas pointer syntax does not accept explicit indices in TileLang DSL v1") + source = self._require_pointer_expr(source, "pto.vldas source", memory_space="ub") + return SemanticCallExpr( + namespace="pto", + name="vldas", + args=(source, *indices), + type=SemanticAlignType(), + ) + + def _analyze_vldus(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {2, 3, 4}: + raise TypeError("pto.vldus expects (source, align) or Tile element-indexing syntax in TileLang DSL v1") + source, *rest = args + align_expr = rest[-1] + index_args = rest[:-1] + source_type = source.type + if isinstance(source_type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldus source") + index_args = tuple(self._require_index_typed_expr(index) for index in index_args) + else: + if index_args: + raise TypeError("pto.vldus pointer syntax does not accept explicit indices in TileLang DSL v1") + source = self._require_pointer_expr(source, "pto.vldus source", memory_space="ub") + self._require_align_expr(align_expr, "pto.vldus align") + return SemanticCallExpr( + namespace="pto", + name="vldus", + args=(source, *index_args, align_expr), + type=SemanticTupleType(elements=(self._vreg_type_for_dtype(source.type.element_dtype), SemanticAlignType())), + ) + + def _analyze_vldsx2(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {3, 4}: + raise TypeError("pto.vldsx2 expects 3 or 4 positional arguments in TileLang DSL v1") + source, *rest = args + if len(rest) == 2: + index_args = rest[:1] + dist = rest[1] + else: + index_args = rest[:2] + dist = rest[2] + source_type = source.type + if isinstance(source_type, SemanticTileType): + source = self._require_tile_expr(source, "pto.vldsx2 source") + else: + source = self._require_pointer_expr(source, "pto.vldsx2 source", memory_space="ub") + index_args = tuple(self._require_index_typed_expr(index) for index in index_args) + dist = self._normalize_vldsx2_dist(dist) + vreg_type = self._vreg_type_for_dtype(source.type.element_dtype) + return SemanticCallExpr( + namespace="pto", + name="vldsx2", + args=(source, *index_args, dist), + type=SemanticTupleType(elements=(vreg_type, vreg_type)), + ) + + def _analyze_predicate_load_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + expects_i32_immediate = name == "pldi" + canonical_name = "plds" if name == "pld" else name + if len(args) not in {2, 3}: + raise TypeError( + f"pto.{name} expects 2 or 3 positional arguments in TileLang DSL v1: " + f"`pto.{name}(buf, offset[, dist])`" + ) + + source, offset = args[:2] + source = self._require_pointer_expr(source, f"pto.{name} source", memory_space="ub") + if expects_i32_immediate: + self._require_i32_like_expr(offset, "pto.pldi offset") + else: + offset = self._require_index_typed_expr(offset) + dist = self._normalize_predicate_load_dist( + args[2] if len(args) == 3 else None, + f"pto.{name} dist", + ) + + if source.type.element_dtype == ui8: + granularity = "b8" + elif source.type.element_dtype == ui16: + granularity = "b16" + elif source.type.element_dtype == ui32: + granularity = "b32" + else: + raise TypeError( + f"pto.{name} source must be !pto.ptr in TileLang DSL v1" + ) + + return SemanticCallExpr( + namespace="pto", + name=canonical_name, + args=(source, offset, dist), + type=SemanticMaskType(granularity=granularity), + ) + + def _analyze_pstu(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 3: + raise TypeError("pto.pstu expects exactly 3 positional arguments in TileLang DSL v1") + align_expr, value, base = args + self._require_align_expr(align_expr, "pto.pstu align_in") + mask_type = self._require_mask_expr(value, "pto.pstu value") + base = self._require_pointer_expr(base, "pto.pstu base", memory_space="ub") + if mask_type.granularity == "b16": + expected = ui16 + elif mask_type.granularity == "b32": + expected = ui32 + else: + raise TypeError("pto.pstu only supports !pto.mask and !pto.mask in TileLang DSL v1") + if base.type.element_dtype != expected: + raise TypeError( + f"pto.pstu requires !pto.ptr<{expected.name}, ub> for mask granularity {mask_type.granularity}" + ) + return SemanticCallExpr( + namespace="pto", + name="pstu", + args=(align_expr, value, base), + type=SemanticTupleType(elements=(SemanticAlignType(), base.type)), + ) + + def _analyze_vstus(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 4: + raise TypeError("pto.vstus expects exactly 4 positional arguments in TileLang DSL v1") + align_expr, offset, value, base = args + self._require_align_expr(align_expr, "pto.vstus align_in") + self._require_i32_like_expr(offset, "pto.vstus offset") + self._require_vreg_expr(value, "pto.vstus value") + base = self._require_pointer_expr(base, "pto.vstus base", memory_space="ub") + return SemanticCallExpr( + namespace="pto", + name="vstus", + args=(align_expr, offset, value, base), + type=SemanticAlignType(), + ) + + def _analyze_vstur(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) not in {3, 4}: + raise TypeError("pto.vstur expects 3 or 4 positional arguments in TileLang DSL v1") + align_expr, value, base = args[:3] + mode = self._normalize_post_update_mode(args[3] if len(args) == 4 else None, "pto.vstur mode") + self._require_align_expr(align_expr, "pto.vstur align_in") + self._require_vreg_expr(value, "pto.vstur value") + base = self._require_pointer_expr(base, "pto.vstur base", memory_space="ub") + return SemanticCallExpr( + namespace="pto", + name="vstur", + args=(align_expr, value, base, mode), + type=SemanticAlignType(), + ) + + def _analyze_load_scalar(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) == 2: + destination_dtype = None + pointer, offset = args + elif len(args) == 3: + destination_dtype = self._require_dtype_symbol(args[0], "pto.load_scalar result type") + pointer, offset = args[1:] + else: + raise TypeError("pto.load_scalar expects 2 or 3 positional arguments in TileLang DSL v1") + pointer = self._require_pointer_expr(pointer, "pto.load_scalar source") + offset = self._require_index_typed_expr(offset) + if destination_dtype is not None and destination_dtype != pointer.type.element_dtype: + raise TypeError("pto.load_scalar result type must match source pointer element dtype") + return SemanticCallExpr( + namespace="pto", + name="load_scalar", + args=(pointer, offset), + type=SemanticScalarType(dtype=pointer.type.element_dtype), + ) + + def _analyze_runtime_block_query( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if args: + raise TypeError(f"pto.{name} does not accept positional arguments in TileLang DSL v1") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(), + type=SemanticScalarType(dtype=i64), + ) + + def _analyze_broadcast_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "vbr": + if len(args) != 1: + raise TypeError("pto.vbr expects exactly 1 positional argument in TileLang DSL v1") + value = args[0] + vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vbr value") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vec_type) + + if name == "vdup": + if len(args) not in {2, 3}: + raise TypeError("pto.vdup expects 2 or 3 positional arguments in TileLang DSL v1") + value = args[0] + if isinstance(value.type, SemanticVRegType): + vec_type = value.type + mask = args[1] + self._require_mask_for_vreg(mask, vec_type, "pto.vdup") + position_arg = args[2] if len(args) == 3 else None + position = self._normalize_position_mode(position_arg, "pto.vdup position") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(value, mask, position), + type=vec_type, + ) + + if len(args) == 3: + raise TypeError( + "pto.vdup scalar input does not accept `position`; use `pto.vdup(input, mask)` " + "in TileLang DSL v1" + ) + vec_type = self._vreg_type_for_scalar_or_index(value, "pto.vdup input") + mask = args[1] + self._require_mask_for_vreg(mask, vec_type, "pto.vdup") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(value, mask), + type=vec_type, + ) + + if name == "vci": + if len(args) not in {1, 2}: + raise TypeError("pto.vci expects 1 or 2 positional arguments in TileLang DSL v1") + index = self._require_scalar_or_index_expr(args[0], "pto.vci index") + index_dtype = i32 if isinstance(index.type, SemanticIndexType) else index.type.dtype + if not (is_integer_dtype(index_dtype) and integer_bitwidth(index_dtype) in {8, 16, 32}): + raise TypeError("pto.vci index only supports 8/16/32-bit integer dtypes in TileLang DSL v1") + order_arg = args[1] if len(args) == 2 else None + order = self._normalize_order_mode(order_arg, "pto.vci order") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(index, order), + type=self._vreg_type_for_dtype(index_dtype), + ) + + raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + + def _analyze_unary_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name in {"vsunpack", "vzunpack"}: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") + value, part = args + vreg = self._require_vreg_expr(value, f"pto.{name} value") + self._require_i32_like_expr(part, f"pto.{name} part") + self._validate_unary_dtype(name, vreg.element_dtype) + result_dtype = self._unpack_result_dtype(name, vreg.element_dtype) + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticVRegType(element_dtype=result_dtype, lanes=vreg.lanes // 2), + ) + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") + value, mask = args + vreg = self._require_vreg_expr(value, f"pto.{name} value") + self._require_mask_for_vreg(mask, vreg, f"pto.{name}") + self._validate_unary_dtype(name, vreg.element_dtype) + result_type = vreg + if name == "vcadd": + result_type = self._vcadd_result_vreg_type(vreg) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=result_type) + + def _analyze_vexpdif_op( + self, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 4: + raise TypeError("pto.vexpdif expects exactly 4 positional arguments in TileLang DSL v1") + input_expr, max_expr, mask_expr, part_expr = args + input_type = self._require_vreg_expr(input_expr, "pto.vexpdif input") + max_type = self._require_vreg_expr(max_expr, "pto.vexpdif max") + if input_type != max_type: + raise TypeError("pto.vexpdif requires input/max vector types to match") + self._validate_vexpdif_dtype(input_type.element_dtype) + self._require_mask_for_vreg(mask_expr, input_type, "pto.vexpdif") + part = self._normalize_vexpdif_part(part_expr, "pto.vexpdif part") + return SemanticCallExpr( + namespace="pto", + name="vexpdif", + args=(input_expr, max_expr, mask_expr, part), + type=self._vexpdif_result_vreg_type(input_type), + ) + + def _analyze_binary_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + lhs_expr, rhs_expr, mask = args + lhs = self._require_vreg_expr(lhs_expr, f"pto.{name} lhs") + rhs = self._require_vreg_expr(rhs_expr, f"pto.{name} rhs") + if name == "vperm": + if not (is_integer_dtype(rhs.element_dtype) and integer_bitwidth(rhs.element_dtype) in {8, 16, 32}): + raise TypeError("pto.vperm indices vector only supports integer vector dtypes in TileLang DSL v1") + if lhs.lanes != rhs.lanes: + raise TypeError("pto.vperm requires data/indices vectors to use the same lane width") + elif lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + self._require_mask_for_vreg(mask, lhs, f"pto.{name}") + self._validate_binary_dtype(name, lhs.element_dtype) + if ( + name in {"vdiv", "vmod"} + and is_integer_dtype(lhs.element_dtype) + and integer_bitwidth(lhs.element_dtype) in {8, 16, 32} + ): + return self._analyze_internal_inline_proc_call_expr( + "_tl_soft_vdiv" if name == "vdiv" else "_tl_soft_vmod", + ( + lhs_expr, + rhs_expr, + mask, + self._dtype_symbol_expr(lhs.element_dtype), + ), + ) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) + + def _analyze_vector_scalar_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + vector_expr, scalar_expr, mask = args + vreg = self._require_vreg_expr(vector_expr, f"pto.{name} vector") + scalar = self._require_scalar_expr(scalar_expr, f"pto.{name} scalar") + if name in {"vshls", "vshrs"}: + if scalar.dtype != i16: + raise TypeError(f"pto.{name} scalar dtype must be i16") + elif scalar.dtype != vreg.element_dtype: + raise TypeError(f"pto.{name} scalar dtype must match vector element dtype") + self._require_mask_for_vreg(mask, vreg, f"pto.{name}") + self._validate_vector_scalar_dtype(name, vreg.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vreg) + + def _analyze_vector_immediate_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL v1") + vector = self._require_vreg_expr(args[0], f"pto.{name} vector") + immediate = self._require_scalar_or_index_expr(args[1], f"pto.{name} immediate") + if isinstance(immediate.type, SemanticScalarType) and not ( + is_integer_dtype(immediate.type.dtype) and integer_bitwidth(immediate.type.dtype) in {8, 16, 32} + ): + raise TypeError(f"pto.{name} immediate only supports 8/16/32-bit integer dtypes in TileLang DSL v1") + self._require_mask_for_vreg(args[2], vector, f"pto.{name}") + self._validate_vector_immediate_dtype(name, vector.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vector) + + def _analyze_ternary_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 4: + raise TypeError(f"pto.{name} expects exactly 4 positional arguments in TileLang DSL v1") + vec0 = self._require_vreg_expr(args[0], f"pto.{name} vec0") + vec1 = self._require_vreg_expr(args[1], f"pto.{name} vec1") + vec2 = self._require_vreg_expr(args[2], f"pto.{name} vec2") + if not (vec0 == vec1 == vec2): + raise TypeError(f"pto.{name} requires all vector operands to use the same vector type") + self._require_mask_for_vreg(args[3], vec0, f"pto.{name}") + self._validate_ternary_vector_dtype(name, vec0.element_dtype) + return SemanticCallExpr(namespace="pto", name=name, args=args, type=vec0) + + def _analyze_multi_result_vector_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name != "vmull": + raise TypeError(f"call surface `pto.{name}` is not supported in TileLang DSL v1 yet") + if len(args) != 3: + raise TypeError("pto.vmull expects exactly 3 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], "pto.vmull lhs") + rhs = self._require_vreg_expr(args[1], "pto.vmull rhs") + if lhs != rhs: + raise TypeError("pto.vmull requires lhs/rhs vector types to match") + self._require_mask_for_vreg(args[2], lhs, "pto.vmull") + self._validate_multi_result_vector_dtype(name, lhs.element_dtype) + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, lhs)), + ) + + def _analyze_mask_part_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") + mask = self._require_mask_expr(args[0], f"pto.{name} mask") + part = self._normalize_predicate_part(args[1], f"pto.{name} part") + result_granularity = mask.granularity + if name == "punpack": + if mask.granularity == "b8": + result_granularity = "b16" + elif mask.granularity == "b16": + result_granularity = "b32" + return SemanticCallExpr( + namespace="pto", + name=name, + args=(args[0], part), + type=SemanticMaskType(granularity=result_granularity), + ) + + def _analyze_mask_logic_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "pnot": + if len(args) != 2: + raise TypeError("pto.pnot expects exactly 2 positional arguments in TileLang DSL") + value = self._require_mask_expr(args[0], "pto.pnot input") + mask = self._require_mask_expr(args[1], "pto.pnot mask") + self._require_matching_mask_types(value, mask, "pto.pnot") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=value) + if name in {"pand", "por", "pxor"}: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_mask_expr(args[0], f"pto.{name} src0") + src1 = self._require_mask_expr(args[1], f"pto.{name} src1") + mask = self._require_mask_expr(args[2], f"pto.{name} mask") + self._require_matching_mask_types(src0, src1, f"pto.{name}") + self._require_matching_mask_types(src0, mask, f"pto.{name}") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + if len(args) != 3: + raise TypeError("pto.psel expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_mask_expr(args[0], "pto.psel src0") + src1 = self._require_mask_expr(args[1], "pto.psel src1") + mask = self._require_mask_expr(args[2], "pto.psel mask") + self._require_matching_mask_types(src0, src1, "pto.psel") + self._require_matching_mask_types(src0, mask, "pto.psel") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + + def _analyze_predicate_reorder_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL v1") + lhs = self._require_mask_expr(args[0], f"pto.{name} src0") + rhs = self._require_mask_expr(args[1], f"pto.{name} src1") + expected_granularity = { + "pdintlv_b8": "b8", + "pdintlv_b16": "b16", + "pdintlv_b32": "b32", + "pintlv_b8": "b8", + "pintlv_b16": "b16", + "pintlv_b32": "b32", + }[name] + if lhs.granularity != expected_granularity or rhs.granularity != expected_granularity: + raise TypeError(f"pto.{name} expects !pto.mask<{expected_granularity}> operands") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType( + elements=( + SemanticMaskType(granularity=expected_granularity), + SemanticMaskType(granularity=expected_granularity), + ) + ), + ) + + def _analyze_compare_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "vcmp": + if len(args) != 4: + raise TypeError("pto.vcmp expects exactly 4 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], "pto.vcmp lhs") + rhs = self._require_vreg_expr(args[1], "pto.vcmp rhs") + if lhs != rhs: + raise TypeError("pto.vcmp requires lhs/rhs vector types to match") + seed = self._require_mask_expr(args[2], "pto.vcmp seed mask") + self._require_mask_for_vreg(args[2], lhs, "pto.vcmp") + cmp_mode = self._normalize_cmp_mode(args[3], "pto.vcmp compare mode") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(args[0], args[1], args[2], cmp_mode), + type=SemanticMaskType(granularity=seed.granularity), + ) + + if len(args) != 4: + raise TypeError("pto.vcmps expects exactly 4 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vcmps vector") + scalar = self._require_scalar_expr(args[1], "pto.vcmps scalar") + if scalar.dtype != vector.element_dtype: + raise TypeError("pto.vcmps scalar dtype must match vector element dtype") + seed = self._require_mask_expr(args[2], "pto.vcmps seed mask") + self._require_mask_for_vreg(args[2], vector, "pto.vcmps") + cmp_mode = self._normalize_cmp_mode(args[3], "pto.vcmps compare mode") + return SemanticCallExpr( + namespace="pto", + name=name, + args=(args[0], args[1], args[2], cmp_mode), + type=SemanticMaskType(granularity=seed.granularity), + ) + + def _analyze_select_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name == "vsel": + if len(args) != 3: + raise TypeError("pto.vsel expects exactly 3 positional arguments in TileLang DSL") + src0 = self._require_vreg_expr(args[0], "pto.vsel src0") + src1 = self._require_vreg_expr(args[1], "pto.vsel src1") + if src0 != src1: + raise TypeError("pto.vsel requires src0/src1 vector types to match") + self._require_mask_for_vreg(args[2], src0, "pto.vsel") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") + src0 = self._require_vreg_expr(args[0], f"pto.{name} src0") + src1 = self._require_vreg_expr(args[1], f"pto.{name} src1") + if src0 != src1: + raise TypeError(f"pto.{name} requires src0/src1 vector types to match") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=src0) + + def _analyze_carry_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name in {"vaddc", "vsubc"}: + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + self._require_mask_for_vreg(args[2], lhs, f"pto.{name}") + carry_type = args[2].type + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, carry_type)), + ) + + if len(args) != 4: + raise TypeError(f"pto.{name} expects exactly 4 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + carry_in = self._require_mask_expr(args[2], f"pto.{name} carry_in") + self._require_mask_for_vreg(args[3], lhs, f"pto.{name}") + carry_mask = self._require_mask_expr(args[3], f"pto.{name} mask") + self._require_matching_mask_types(carry_in, carry_mask, f"pto.{name}") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, carry_in)), + ) + + def _analyze_rearrangement_op( + self, + name: str, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if name in {"vintlv", "vdintlv"}: + if len(args) != 2: + raise TypeError(f"pto.{name} expects exactly 2 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + return SemanticCallExpr( + namespace="pto", + name=name, + args=args, + type=SemanticTupleType(elements=(lhs, lhs)), + ) + + if len(args) != 3: + raise TypeError(f"pto.{name} expects exactly 3 positional arguments in TileLang DSL") + lhs = self._require_vreg_expr(args[0], f"pto.{name} lhs") + rhs = self._require_vreg_expr(args[1], f"pto.{name} rhs") + if lhs != rhs: + raise TypeError(f"pto.{name} requires lhs/rhs vector types to match") + self._require_string_expr(args[2], f"pto.{name} part") + return SemanticCallExpr(namespace="pto", name=name, args=args, type=lhs) + + def _missing_optional_meta_expr(self) -> SemanticLiteralExpr: + return SemanticLiteralExpr(value=None, type=SemanticMetaType(kind="none")) + + def _analyze_vcvt_frontend_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + if len(expr.args) != 3: + raise TypeError( + "pto.vcvt expects exactly 3 positional operands `(vec, to_type, mask)` " + "before optional keyword attrs in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + allowed_keywords = {"rnd", "sat", "part"} + unexpected_keywords = sorted(set(analyzed_keywords) - allowed_keywords) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vcvt only accepts keyword attrs `rnd`, `sat`, and `part`; " + f"got unsupported keyword(s): {keyword_text}" + ) + return self._analyze_vcvt( + args, + rnd=self._normalize_vcvt_round_mode(analyzed_keywords.get("rnd")), + sat=self._normalize_vcvt_sat_mode(analyzed_keywords.get("sat")), + part=self._normalize_vcvt_part_mode(analyzed_keywords.get("part")), + rnd_explicit="rnd" in analyzed_keywords, + sat_explicit="sat" in analyzed_keywords, + part_explicit="part" in analyzed_keywords, + ) + + def _analyze_vtrc_frontend_call( + self, + expr: FrontendCallExpr, + env: dict[str, SemanticBinding], + *, + allow_outer_lookup: bool, + ) -> SemanticExpr: + if len(expr.args) != 2: + raise TypeError( + "pto.vtrc expects exactly 2 positional operands `(vec, mask)` " + "before optional keyword attrs in TileLang DSL v1" + ) + args = tuple( + self._analyze_expr(arg, env, allow_outer_lookup=allow_outer_lookup) + for arg in expr.args + ) + analyzed_keywords = { + name: self._analyze_expr(value, env, allow_outer_lookup=allow_outer_lookup) + for name, value in expr.keywords + } + allowed_keywords = {"rnd"} + unexpected_keywords = sorted(set(analyzed_keywords) - allowed_keywords) + if unexpected_keywords: + keyword_text = ", ".join(unexpected_keywords) + raise TypeError( + "pto.vtrc only accepts keyword attr `rnd`; " + f"got unsupported keyword(s): {keyword_text}" + ) + return self._analyze_vtrc( + args, + rnd=self._normalize_vtrc_round_mode(analyzed_keywords.get("rnd")), + ) + + def _analyze_vcvt( + self, + args: tuple[SemanticExpr, ...], + *, + rnd: SemanticExpr | None = None, + sat: SemanticExpr | None = None, + part: SemanticExpr | None = None, + rnd_explicit: bool = False, + sat_explicit: bool = False, + part_explicit: bool = False, + ) -> SemanticExpr: + if len(args) != 3: + raise TypeError("pto.vcvt expects exactly 3 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vcvt vector") + target_dtype = self._require_dtype_symbol(args[1], "pto.vcvt to_type") + self._require_mask_for_vreg(args[2], vector, "pto.vcvt") + contract = self._lookup_vcvt_attr_contract(vector.element_dtype, target_dtype) + if contract is not None: + self._require_explicit_vcvt_attrs( + src_dtype=vector.element_dtype, + dst_dtype=target_dtype, + rnd_required=contract[0], + sat_required=contract[1], + part_required=contract[2], + rnd_explicit=rnd_explicit, + sat_explicit=sat_explicit, + part_explicit=part_explicit, + ) + return SemanticCallExpr( + namespace="pto", + name="vcvt", + args=( + args[0], + args[1], + args[2], + rnd if rnd is not None else self._missing_optional_meta_expr(), + sat if sat is not None else self._missing_optional_meta_expr(), + part if part is not None else self._missing_optional_meta_expr(), + ), + type=self._vreg_type_for_dtype(target_dtype), + ) + + def _analyze_vpack_op( + self, + args: tuple[SemanticExpr, ...], + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vpack expects exactly 2 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vpack vector") + part = self._normalize_predicate_part(args[1], "pto.vpack part") + self._validate_binary_dtype("vpack", vector.element_dtype) + result_dtype = self._pack_result_dtype(vector.element_dtype) + return SemanticCallExpr( + namespace="pto", + name="vpack", + args=(args[0], part), + type=SemanticVRegType(element_dtype=result_dtype, lanes=vector.lanes * 2), + ) + + def _analyze_vtrc( + self, + args: tuple[SemanticExpr, ...], + *, + rnd: SemanticExpr | None = None, + ) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vtrc expects exactly 2 positional arguments in TileLang DSL v1") + vector = self._require_vreg_expr(args[0], "pto.vtrc vector") + self._require_mask_for_vreg(args[1], vector, "pto.vtrc") + if vector.element_dtype not in {f16, bf16, f32}: + raise TypeError("pto.vtrc only supports f16/bf16/f32 vector element types in TileLang DSL v1") + return SemanticCallExpr( + namespace="pto", + name="vtrc", + args=( + args[0], + args[1], + rnd + if rnd is not None + else SemanticLiteralExpr(value="R", type=SemanticMetaType(kind="string")), + ), + type=vector, + ) + + def _analyze_vbitcast(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.vbitcast expects exactly 2 positional arguments in TileLang DSL") + vector = self._require_vreg_expr(args[0], "pto.vbitcast vector") + target_dtype = self._require_dtype_symbol(args[1], "pto.vbitcast to_type") + # No mask for vbitcast (pure type conversion) + return SemanticCallExpr( + namespace="pto", + name="vbitcast", + args=( + args[0], + args[1], + ), + type=self._vreg_type_for_dtype(target_dtype), + ) + + def _analyze_pbitcast(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 2: + raise TypeError("pto.pbitcast expects exactly 2 positional arguments in TileLang DSL") + self._require_mask_expr(args[0], "pto.pbitcast mask") + target_mask_type = self._require_mask_type_expr(args[1], "pto.pbitcast to_type") + return SemanticCallExpr( + namespace="pto", + name="pbitcast", + args=( + args[0], + args[1], + ), + type=SemanticMaskType(granularity=target_mask_type.granularity), + ) + + def _analyze_vbitsort(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 4: + raise TypeError("pto.vbitsort expects exactly 4 positional arguments in TileLang DSL v1") + destination = self._require_pointer_expr(args[0], "pto.vbitsort destination", memory_space="ub") + source = self._require_pointer_expr(args[1], "pto.vbitsort source", memory_space="ub") + indices = self._require_pointer_expr(args[2], "pto.vbitsort indices", memory_space="ub") + count = self._require_index_typed_expr(args[3]) + return SemanticCallExpr( + namespace="pto", + name="vbitsort", + args=(destination, source, indices, count), + type=None, + ) + + def _analyze_vmrgsort4(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if len(args) != 7: + raise TypeError("pto.vmrgsort4 expects exactly 7 positional arguments in TileLang DSL v1") + destination = self._require_pointer_expr(args[0], "pto.vmrgsort4 destination", memory_space="ub") + source0 = self._require_pointer_expr(args[1], "pto.vmrgsort4 src0", memory_space="ub") + source1 = self._require_pointer_expr(args[2], "pto.vmrgsort4 src1", memory_space="ub") + source2 = self._require_pointer_expr(args[3], "pto.vmrgsort4 src2", memory_space="ub") + source3 = self._require_pointer_expr(args[4], "pto.vmrgsort4 src3", memory_space="ub") + self._require_i64_like_expr(args[5], "pto.vmrgsort4 count") + self._require_i64_like_expr(args[6], "pto.vmrgsort4 config") + return SemanticCallExpr( + namespace="pto", + name="vmrgsort4", + args=(destination, source0, source1, source2, source3, args[5], args[6]), + type=None, + ) + + def _analyze_get_vms4_sr(self, args: tuple[SemanticExpr, ...]) -> SemanticExpr: + if args: + raise TypeError("pto.get_vms4_sr does not accept positional arguments in TileLang DSL v1") + count_type = SemanticScalarType(dtype=i16) + return SemanticCallExpr( + namespace="pto", + name="get_vms4_sr", + args=(), + type=SemanticTupleType(elements=(count_type, count_type, count_type, count_type)), + ) + + def _require_dtype_symbol(self, expr: SemanticExpr, context: str) -> ScalarType: + if not ( + isinstance(expr, SemanticSymbolExpr) + and expr.type.kind == "dtype" + and isinstance(expr.value, ScalarType) + ): + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a TileLang scalar dtype symbol in TileLang DSL v1") + return expr.value + + def _dtype_symbol_expr(self, dtype: ScalarType) -> SemanticSymbolExpr: + return SemanticSymbolExpr( + namespace="pto", + name=dtype.name, + value=dtype, + type=SemanticMetaType(kind="dtype"), + ) + + def _require_memory_space_symbol(self, expr: SemanticExpr, context: str) -> MemorySpace: + if ( + isinstance(expr, SemanticSymbolExpr) + and expr.type.kind == "memory_space" + and isinstance(expr.value, MemorySpace) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "memory_space" + and isinstance(expr.binding.value, MemorySpace) + ): + return expr.binding.value + raise TypeError(f"{context} must be a TileLang MemorySpace symbol") + + def _require_ptr_type_expr(self, expr: SemanticExpr, context: str) -> PointerType: + if ( + isinstance(expr, SemanticLiteralExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "ptr_type" + and isinstance(expr.value, PointerType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "ptr_type" + and isinstance(expr.binding.value, PointerType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a pointer type constructed with pto.ptr(...)") + + def _require_vreg_type_expr(self, expr: SemanticExpr, context: str) -> VRegType: + if ( + isinstance(expr, SemanticLiteralExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vreg_type" + and isinstance(expr.value, VRegType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vreg_type" + and isinstance(expr.binding.value, VRegType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a vector type constructed with pto.vreg(...)") + + def _require_vector_type_expr(self, expr: SemanticExpr, context: str) -> VectorType: + if ( + isinstance(expr, SemanticLiteralExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vector_type" + and isinstance(expr.value, VectorType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vector_type" + and isinstance(expr.binding.value, VectorType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a builtin vector type constructed with pto.vector(...)") + + def _require_vector_shape_expr(self, expr: SemanticExpr, context: str) -> tuple[int, ...]: + if not isinstance(expr, SemanticTupleExpr): + dim = self._static_index_value(expr, default=None) + if dim is None: + raise TypeError(f"{context} must be a static integer or tuple of static integers") + if dim <= 0: + raise TypeError(f"{context} shape entries must be positive") + return (dim,) + if isinstance(expr, SemanticTupleExpr): + shape: list[int] = [] + for element in expr.elements: + dim = self._static_index_value(element, default=None) + if dim is None: + raise TypeError(f"{context} tuple entries must be static integers") + if dim <= 0: + raise TypeError(f"{context} shape entries must be positive") + shape.append(dim) + if not shape: + raise TypeError(f"{context} must be a non-empty shape") + return tuple(shape) + raise TypeError(f"{context} must be a static integer or tuple of static integers") + + def _require_mask_type_expr(self, expr: SemanticExpr, context: str) -> MaskType: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "mask_type" + and isinstance(expr.value, MaskType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "mask_type" + and isinstance(expr.binding.value, MaskType) + ): + return expr.binding.value + raise TypeError(f"{context} must be a mask type such as pto.mask_b32") + + def _require_cast_target_type(self, expr: SemanticExpr) -> SemanticType: + if self._is_i64_dtype_expr(expr): + return SemanticScalarType(dtype=i64) + ptr_type = self._require_ptr_type_expr(expr, "pto.castptr target type") + return SemanticPtrType( + element_dtype=ptr_type.element_dtype, + memory_space=ptr_type.memory_space.value, + ) + + def _require_castptr_input(self, expr: SemanticExpr, target_type: SemanticPtrType) -> None: + if isinstance(expr.type, SemanticIndexType): + return + if isinstance(expr.type, SemanticScalarType) and expr.type.dtype == i64: + return + if isinstance(expr.type, SemanticPtrType): + if expr.type.memory_space != target_type.memory_space: + raise TypeError("pto.castptr pointer-to-pointer casts must stay within one PTO memory space") + return + raise TypeError("pto.castptr input must be an index/i64, pointer, or memref-backed address value") + + def _is_i64_dtype_expr(self, expr: SemanticExpr) -> bool: + if isinstance(expr, SemanticSymbolExpr): + return expr.type.kind == "dtype" and expr.value == i64 + if isinstance(expr, SemanticBindingRef): + return ( + isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and expr.binding.value == i64 + ) + return False + + def _require_vreg_expr(self, expr: SemanticExpr, context: str) -> SemanticVRegType: + if not isinstance(expr.type, SemanticVRegType): + raise TypeError(f"{context} must be a vector register value in TileLang DSL v1") + return expr.type + + def _require_scalar_expr(self, expr: SemanticExpr, context: str) -> SemanticScalarType: + if not isinstance(expr.type, SemanticScalarType): + raise TypeError(f"{context} must be a scalar value in TileLang DSL v1") + return expr.type + + def _require_scalar_or_index_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if isinstance(expr.type, (SemanticScalarType, SemanticIndexType)): + return expr + raise TypeError(f"{context} must be a scalar or index value in TileLang DSL v1") + + def _vreg_type_for_scalar_or_index(self, expr: SemanticExpr, context: str) -> SemanticVRegType: + value = self._require_scalar_or_index_expr(expr, context) + if isinstance(value.type, SemanticScalarType): + return self._vreg_type_for_dtype(value.type.dtype) + return self._vreg_type_for_dtype(i32) + + def _vcadd_result_vreg_type(self, vreg_type: SemanticVRegType) -> SemanticVRegType: + dtype = vreg_type.element_dtype + if not is_integer_dtype(dtype): + return vreg_type + signedness = integer_signedness(dtype) + bitwidth = integer_bitwidth(dtype) + if bitwidth == 8: + widened_dtype = ui16 if signedness == "unsigned" else i16 + return self._vreg_type_for_dtype(widened_dtype) + if bitwidth == 16: + widened_dtype = ui32 if signedness == "unsigned" else i32 + return self._vreg_type_for_dtype(widened_dtype) + return vreg_type + + def _vexpdif_result_vreg_type(self, vreg_type: SemanticVRegType) -> SemanticVRegType: + if vreg_type.element_dtype.name == "f32": + return vreg_type + return SemanticVRegType(element_dtype=f32, lanes=vreg_type.lanes // 2) + + def _normalize_position_mode( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value=PositionMode.LOWEST.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "position_mode" + and isinstance(expr.value, PositionMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "position_mode" + and isinstance(expr.binding.value, PositionMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + position = self._require_string_expr(expr, context) + if position == "POS_LOWEST": + position = PositionMode.LOWEST.value + if position not in {PositionMode.LOWEST.value, PositionMode.HIGHEST.value}: + raise TypeError( + "pto.vdup position must be `PositionMode.LOWEST` or `PositionMode.HIGHEST` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=position, type=SemanticMetaType(kind="string")) + + def _normalize_order_mode( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value=OrderMode.ASC.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "order_mode" + and isinstance(expr.value, OrderMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "order_mode" + and isinstance(expr.binding.value, OrderMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + order = self._require_string_expr(expr, context) + if order not in {OrderMode.ASC.value, OrderMode.DESC.value}: + raise TypeError( + "pto.vci currently only supports order `OrderMode.ASC` or `OrderMode.DESC` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=order, type=SemanticMetaType(kind="string")) + + def _normalize_vcvt_round_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_round_mode" + and isinstance(expr.value, VcvtRoundMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_round_mode" + and isinstance(expr.binding.value, VcvtRoundMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + round_mode = self._require_string_expr(expr, "pto.vcvt rnd") + if round_mode not in {mode.value for mode in VcvtRoundMode}: + raise TypeError( + "pto.vcvt rnd must be a VcvtRoundMode enum such as " + "`pto.VcvtRoundMode.R` or one of the canonical strings " + '`"R"`, `"A"`, `"F"`, `"C"`, `"Z"`, `"O"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=round_mode, type=SemanticMetaType(kind="string")) + + def _normalize_vtrc_round_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + normalized = self._normalize_vcvt_round_mode(expr) + if normalized is None: + return None + round_mode = self._require_string_expr(normalized, "pto.vtrc rnd") + if round_mode == VcvtRoundMode.O.value: + raise TypeError( + "pto.vtrc rnd must be one of " + '`"R"`, `"A"`, `"F"`, `"C"`, `"Z"` or a matching ' + "VcvtRoundMode enum in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=round_mode, type=SemanticMetaType(kind="string")) + + def _normalize_vcvt_sat_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_sat_mode" + and isinstance(expr.value, VcvtSatMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_sat_mode" + and isinstance(expr.binding.value, VcvtSatMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + sat_mode = self._require_string_expr(expr, "pto.vcvt sat") + if sat_mode not in {mode.value for mode in VcvtSatMode}: + raise TypeError( + "pto.vcvt sat must be a VcvtSatMode enum such as " + "`pto.VcvtSatMode.SAT` or `pto.VcvtSatMode.NOSAT`, or one of the " + 'canonical strings `"SAT"` / `"NOSAT"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=sat_mode, type=SemanticMetaType(kind="string")) + + def _normalize_vcvt_part_mode(self, expr: SemanticExpr | None) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.value, VcvtPartMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.binding.value, VcvtPartMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + part_mode = self._require_string_expr(expr, "pto.vcvt part") + if part_mode not in {mode.value for mode in VcvtPartMode}: + raise TypeError( + "pto.vcvt part must be a VcvtPartMode enum such as " + "`pto.VcvtPartMode.EVEN`, `pto.VcvtPartMode.ODD`, or " + "`pto.VcvtPartMode.P0`..`pto.VcvtPartMode.P3`, or one of the " + 'canonical strings `"EVEN"`, `"ODD"`, `"P0"`, `"P1"`, `"P2"`, or `"P3"` ' + "in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=part_mode, type=SemanticMetaType(kind="string")) + + def _lookup_vcvt_attr_contract( + self, src_dtype: ScalarType, dst_dtype: ScalarType + ) -> tuple[bool, bool, bool] | None: + src_kind = _classify_vcvt_elem_kind(src_dtype) + dst_kind = _classify_vcvt_elem_kind(dst_dtype) + if src_kind is None or dst_kind is None: + return None + return _VCVT_ATTR_CONTRACTS.get((src_kind, dst_kind)) + + def _require_explicit_vcvt_attrs( + self, + *, + src_dtype: ScalarType, + dst_dtype: ScalarType, + rnd_required: bool, + sat_required: bool, + part_required: bool, + rnd_explicit: bool, + sat_explicit: bool, + part_explicit: bool, + ) -> None: + pair = f"{src_dtype.name}->{dst_dtype.name}" + + def _check(attr_name: str, required: bool, explicit: bool) -> None: + if required and not explicit: + raise TypeError( + f"pto.vcvt {pair} requires explicit `{attr_name}=` in TileLang DSL v1" + ) + if not required and explicit: + raise TypeError( + f"pto.vcvt {pair} does not accept `{attr_name}=` for this type pair in TileLang DSL v1" + ) + + _check("rnd", rnd_required, rnd_explicit) + _check("sat", sat_required, sat_explicit) + _check("part", part_required, part_explicit) + + def _require_mask_expr(self, expr: SemanticExpr, context: str) -> SemanticMaskType: + if not isinstance(expr.type, SemanticMaskType): + raise TypeError(f"{context} must be a mask value in TileLang DSL") + return expr.type + + def _require_align_expr(self, expr: SemanticExpr, context: str) -> None: + if not isinstance(expr.type, SemanticAlignType): + raise TypeError(f"{context} must be a pto.align value in TileLang DSL v1") + + def _require_matching_mask_types( + self, + lhs: SemanticMaskType, + rhs: SemanticMaskType, + context: str, + ) -> None: + if lhs != rhs: + raise TypeError(f"{context} requires all mask operands to use the same mask granularity") + + def _require_string_expr(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "string" + and isinstance(expr.binding.value, str) + ): + return expr.binding.value + raise TypeError(f"{context} must be a string literal in TileLang DSL") + + def _normalize_vexpdif_part(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.value, VcvtPartMode) + ): + part = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vcvt_part_mode" + and isinstance(expr.binding.value, VcvtPartMode) + ): + part = expr.binding.value.value + else: + part = self._require_string_expr(expr, context) + if part not in {VcvtPartMode.EVEN.value, VcvtPartMode.ODD.value}: + raise TypeError( + "pto.vexpdif part must be `pto.VcvtPartMode.EVEN` or " + "`pto.VcvtPartMode.ODD`, or one of the canonical strings " + '`"EVEN"` / `"ODD"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=part, type=SemanticMetaType(kind="string")) + + def _normalize_cmp_mode(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "cmp_mode" + and isinstance(expr.value, CmpMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "cmp_mode" + and isinstance(expr.binding.value, CmpMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + cmp_mode = self._require_string_expr(expr, context) + if cmp_mode not in {mode.value for mode in CmpMode}: + raise TypeError( + f"{context} must be a CmpMode enum such as `pto.CmpMode.LT`, " + 'or one of the canonical strings `"eq"`, `"ne"`, `"lt"`, `"le"`, `"gt"`, `"ge"` ' + "in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=cmp_mode, type=SemanticMetaType(kind="string")) + + def _normalize_predicate_part(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_part" + and isinstance(expr.value, PredicatePart) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_part" + and isinstance(expr.binding.value, PredicatePart) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + part = self._require_string_expr(expr, context) + if part not in {token.value for token in PredicatePart}: + raise TypeError( + f"{context} must be a PredicatePart enum such as `pto.PredicatePart.LOWER`, " + 'or one of the canonical strings `"LOWER"`, `"HIGHER"` in TileLang DSL v1' + ) + return SemanticLiteralExpr(value=part, type=SemanticMetaType(kind="string")) + + def _normalize_post_update_mode( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value="NO_POST_UPDATE", type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "post_update_mode" + and isinstance(expr.value, PostUpdateMode) + ): + return SemanticLiteralExpr(value=expr.value.value, type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "post_update_mode" + and isinstance(expr.binding.value, PostUpdateMode) + ): + return SemanticLiteralExpr(value=expr.binding.value.value, type=SemanticMetaType(kind="string")) + raise TypeError( + "pto.vstur mode must be a PostUpdateMode enum such as " + "`pto.PostUpdateMode.NO_POST_UPDATE` or `pto.PostUpdateMode.POST_UPDATE` in TileLang DSL v1" + ) + + def _normalize_predicate_store_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value="NORM", type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.value, PredicateDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.binding.value, PredicateDist) + ): + dist = expr.binding.value.value + else: + raise TypeError( + "predicate store dist must be a PredicateDist enum such as " + "`pto.PredicateDist.NORM` or `pto.PredicateDist.PK` in TileLang DSL v1" + ) + if dist not in {"NORM", "PK"}: + raise TypeError( + "predicate store dist must be one of " + "`pto.PredicateDist.NORM` or `pto.PredicateDist.PK` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) + + def _normalize_predicate_load_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr: + if expr is None: + return SemanticLiteralExpr(value="NORM", type=SemanticMetaType(kind="string")) + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.value, PredicateDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "predicate_dist" + and isinstance(expr.binding.value, PredicateDist) + ): + dist = expr.binding.value.value + else: + raise TypeError( + "predicate load dist must be a PredicateDist enum such as " + "`pto.PredicateDist.NORM`, `pto.PredicateDist.US`, or `pto.PredicateDist.DS` in TileLang DSL v1" + ) + if dist not in {"NORM", "US", "DS"}: + raise TypeError( + "predicate load dist must be one of " + "`pto.PredicateDist.NORM`, `pto.PredicateDist.US`, or `pto.PredicateDist.DS` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) + + def _normalize_vlds_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vload_dist" + and isinstance(expr.value, VLoadDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vload_dist" + and isinstance(expr.binding.value, VLoadDist) + ): + dist = expr.binding.value.value + else: + raise TypeError( + "pto.vlds dist must be a VLoadDist enum such as " + "`pto.VLoadDist.NORM`, `pto.VLoadDist.UNPK_B16`, or " + "`pto.VLoadDist.BRC_B32` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) + + def _normalize_vsts_dist( + self, + expr: SemanticExpr | None, + context: str, + ) -> SemanticExpr | None: + if expr is None: + return None + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vstore_dist" + and isinstance(expr.value, VStoreDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "vstore_dist" + and isinstance(expr.binding.value, VStoreDist) + ): + dist = expr.binding.value.value + else: + raise TypeError( + "pto.vsts dist must be a VStoreDist enum such as " + "`pto.VStoreDist.NORM_B32`, `pto.VStoreDist.PK_B32`, or " + "`pto.VStoreDist.ONE_POINT_B8` in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=dist, type=SemanticMetaType(kind="string")) + + def _require_i1_expr(self, expr: SemanticExpr, context: str) -> None: + scalar = self._require_scalar_expr(expr, context) + if scalar.dtype != i1: + raise TypeError(f"{context} must be an i1 value in TileLang DSL") + + def _require_i32_like_expr(self, expr: SemanticExpr, context: str) -> None: + if isinstance(expr.type, SemanticIndexType): + return + scalar = self._require_scalar_expr(expr, context) + if scalar.dtype != i32: + raise TypeError(f"{context} must be an i32 or index value in TileLang DSL") + + def _require_i64_like_expr(self, expr: SemanticExpr, context: str) -> None: + if isinstance(expr.type, SemanticIndexType): + return + scalar = self._require_scalar_expr(expr, context) + if scalar.dtype != i64: + raise TypeError(f"{context} must be an i64 or index value in TileLang DSL") + + def _require_tail_remaining_expr(self, expr: SemanticExpr, context: str) -> None: + if isinstance(expr.type, SemanticIndexType): + return + if isinstance(expr.type, SemanticScalarType) and expr.type.dtype.name == "i32": + return + raise TypeError(f"{context} must be an i32 or index value in TileLang DSL v1") + + def _require_mask_for_vreg( + self, + mask_expr: SemanticExpr, + vreg_type: SemanticVRegType, + context: str, + ) -> None: + if not isinstance(mask_expr.type, SemanticMaskType): + raise TypeError(f"{context} requires a mask operand in TileLang DSL v1") + expected = self._mask_granularity_for_dtype(vreg_type.element_dtype) + if mask_expr.type.granularity != expected: + raise TypeError( + f"{context} requires mask granularity {expected} for vector dtype {vreg_type.element_dtype!r}" + ) + + def _require_mask_for_vsts( + self, + mask_expr: SemanticExpr, + vreg_type: SemanticVRegType, + dist_expr: SemanticExpr | None, + context: str, + ) -> None: + if not isinstance(mask_expr.type, SemanticMaskType): + raise TypeError(f"{context} requires a mask operand in TileLang DSL v1") + expected = self._mask_granularity_for_dtype(vreg_type.element_dtype) + if dist_expr is not None: + dist = self._require_string_expr(dist_expr, f"{context} dist") + if dist == "PK_B16": + expected = "b16" + elif dist == "PK_B32": + expected = "b32" + elif dist == "PK_B64": + expected = "b32" + elif dist == "MRG4CHN_B8": + expected = "b32" + elif dist in {"MRG2CHN_B8", "MRG2CHN_B16"}: + expected = "b16" if dist == "MRG2CHN_B8" else "b32" + if mask_expr.type.granularity != expected: + raise TypeError( + f"{context} requires mask granularity {expected} for store dist " + f"{self._require_string_expr(dist_expr, f'{context} dist') if dist_expr is not None else 'default'}" + ) + + def _require_matching_vector_pointer( + self, + vreg_type: SemanticVRegType, + pointer_type: SemanticType, + context: str, + ) -> None: + if isinstance(pointer_type, SemanticTileType): + if pointer_type.element_dtype != vreg_type.element_dtype: + raise TypeError(f"{context} requires destination Tile dtype to match vector dtype") + return + if isinstance(pointer_type, SemanticPtrType): + if pointer_type.memory_space != "ub": + raise TypeError(f"{context} requires a UB pointer destination in TileLang DSL") + if pointer_type.element_dtype != vreg_type.element_dtype: + raise TypeError(f"{context} requires destination pointer dtype to match vector dtype") + return + raise TypeError(f"{context} requires a Tile or pointer destination in TileLang DSL") + + def _normalize_vldsx2_dist(self, expr: SemanticExpr) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "deinterleave_dist" + and isinstance(expr.value, DeinterleaveDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "deinterleave_dist" + and isinstance(expr.binding.value, DeinterleaveDist) + ): + dist = expr.binding.value.value + else: + dist = self._require_string_expr(expr, "pto.vldsx2 dist") + legacy_map = { + "DINTLV_B8": "DINTLV", + "DINTLV_B16": "DINTLV", + "DINTLV_B32": "DINTLV", + "BD": "BDINTLV", + } + normalized = legacy_map.get(dist, dist) + if normalized not in {"DINTLV", "BDINTLV"}: + raise TypeError( + "pto.vldsx2 dist must be one of \"DINTLV\" or \"BDINTLV\" in TileLang DSL v1" + ) + return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + + def _normalize_vstsx2_dist(self, expr: SemanticExpr) -> SemanticExpr: + if ( + isinstance(expr, SemanticSymbolExpr) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "interleave_dist" + and isinstance(expr.value, InterleaveDist) + ): + dist = expr.value.value + elif ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "interleave_dist" + and isinstance(expr.binding.value, InterleaveDist) + ): + dist = expr.binding.value.value + else: + dist = self._require_string_expr(expr, "pto.vstsx2 dist") + legacy_map = { + "INTLV_B8": "INTLV", + "INTLV_B16": "INTLV", + "INTLV_B32": "INTLV", + } + normalized = legacy_map.get(dist, dist) + if normalized != "INTLV": + raise TypeError("pto.vstsx2 dist must be \"INTLV\" in TileLang DSL v1") + return SemanticLiteralExpr(value=normalized, type=SemanticMetaType(kind="string")) + + def _mask_granularity_for_dtype(self, dtype: ScalarType) -> str: + int_bits = integer_bitwidth(dtype) + if dtype.name == "f32" or int_bits in {32, 64}: + return "b32" + if dtype.name in {"f16", "bf16"} or int_bits == 16: + return "b16" + if int_bits == 8: + return "b8" + raise TypeError(f"dtype `{dtype.name}` is not supported by make_mask/vector lowering in TileLang DSL v1") + + def _vreg_type_for_dtype(self, dtype: ScalarType) -> SemanticVRegType: + width = bytewidth(dtype) + if width not in {1, 2, 4, 8}: + raise TypeError(f"dtype `{dtype.name}` is not supported by vlds/vsts in TileLang DSL v1") + return SemanticVRegType(element_dtype=dtype, lanes=256 // width) + + def _unpack_result_dtype(self, name: str, dtype: ScalarType) -> ScalarType: + if not is_integer_dtype(dtype): + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + width = integer_bitwidth(dtype) + if width not in {8, 16, 32}: + raise TypeError(f"pto.{name} only supports 8/16/32-bit integer vector dtypes in TileLang DSL v1") + + if name == "vzunpack": + mapping = { + "i8": ui16, + "si8": ui16, + "ui8": ui16, + "i16": ui32, + "si16": ui32, + "ui16": ui32, + "i32": ui64, + "si32": ui64, + "ui32": ui64, + } + return mapping[dtype.name] + + mapping = { + "i8": i16, + "si8": si16, + "i16": i32, + "si16": si32, + "i32": i64, + "si32": si64, + } + if dtype.name not in mapping: + raise TypeError(f"pto.{name} requires signed/signless integer vector dtypes in TileLang DSL v1") + return mapping[dtype.name] + + def _pack_result_dtype(self, dtype: ScalarType) -> ScalarType: + if not is_integer_dtype(dtype): + raise TypeError("pto.vpack only supports integer vector dtypes in TileLang DSL v1") + mapping = { + "i32": ui16, + "si32": ui16, + "ui32": ui16, + "i16": ui8, + "si16": ui8, + "ui16": ui8, + } + if dtype.name not in mapping: + raise TypeError("pto.vpack only supports 32->16 and 16->8 integer packing in TileLang DSL v1") + return mapping[dtype.name] + + def _validate_unary_dtype(self, name: str, dtype: ScalarType) -> None: + if name in {"vexp", "vln", "vsqrt", "vrec", "vrsqrt"} and dtype.name not in {"f16", "f32"}: + raise TypeError(f"pto.{name} only supports f16/f32 in TileLang DSL v1") + if name == "vrelu" and not ( + dtype.name in {"f16", "f32"} + or (is_integer_dtype(dtype) and integer_bitwidth(dtype) == 32) + ): + raise TypeError("pto.vrelu only supports i32/f16/f32 in TileLang DSL v1") + if name in {"vnot", "vbcnt", "vcls", "vsunpack", "vzunpack", "vusqz", "vsqz"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name in {"vabs", "vneg", "vmov", "vtrc", "vcadd", "vcmax", "vcmin"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + + def _validate_binary_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vdiv" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) + or dtype.name in {"f16", "f32"} + ): + raise TypeError( + "pto.vdiv only supports 8/16/32-bit integer families and f16/f32 in TileLang DSL v1" + ) + if name == "vmod" and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): + raise TypeError( + "pto.vmod only supports 8/16/32-bit integer families in TileLang DSL v1" + ) + if name == "vprelu" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vprelu only supports f16/f32 in TileLang DSL v1") + if name in {"vaddreluconv", "vmulconv"} and dtype.name not in {"f16", "bf16", "f32"}: + raise TypeError(f"pto.{name} only supports f16/bf16/f32 in TileLang DSL v1") + if name in {"vand", "vxor"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name == "vor" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) + or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError("pto.vor only supports integer vector dtypes and f16/bf16/f32 in TileLang DSL v1") + if name in {"vshl", "vshr"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name == "vmul" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {16, 32}) or dtype.name in {"f16", "f32"} + ): + raise TypeError("pto.vmul only supports 16/32-bit integer families and f16/f32 in TileLang DSL v1") + if name == "vperm" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError("pto.vperm does not support this data vector dtype in TileLang DSL v1") + if name in {"vadd", "vsub", "vmax", "vmin", "vaddrelu", "vsubrelu"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + if name in {"vpack", "vmrgsort"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + + def _validate_vexpdif_dtype(self, dtype: ScalarType) -> None: + if dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vexpdif only supports f16/f32 in TileLang DSL v1") + + def _validate_vector_scalar_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vdivs" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vdivs only supports f16/f32 in TileLang DSL v1") + if name == "vlrelu" and dtype.name not in {"f16", "f32"}: + raise TypeError("pto.vlrelu only supports f16/f32 in TileLang DSL v1") + if name in {"vshls", "vshrs"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name in {"vands", "vors", "vxors"} and not ( + is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32} + ): + raise TypeError(f"pto.{name} only supports integer vector dtypes in TileLang DSL v1") + if name in {"vadds", "vsubs", "vmuls", "vmaxs", "vmins"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError(f"pto.{name} does not support this dtype in TileLang DSL v1") + + def _validate_vector_immediate_dtype(self, name: str, dtype: ScalarType) -> None: + if name in {"vshift", "vslide"} and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {8, 16, 32}) or dtype.name in {"f16", "bf16", "f32"} + ): + raise TypeError(f"pto.{name} does not support this vector dtype in TileLang DSL v1") + + def _validate_ternary_vector_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vaxpy" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {16, 32}) or dtype.name in {"f16", "f32"} + ): + raise TypeError("pto.vaxpy only supports 16/32-bit integer families and f16/f32 in TileLang DSL v1") + if name == "vmula" and not ( + (is_integer_dtype(dtype) and integer_bitwidth(dtype) in {16, 32}) or dtype.name in {"f16", "f32"} + ): + raise TypeError("pto.vmula only supports 16/32-bit integer families and f16/f32 in TileLang DSL v1") + + def _validate_multi_result_vector_dtype(self, name: str, dtype: ScalarType) -> None: + if name == "vmull" and not (is_integer_dtype(dtype) and integer_bitwidth(dtype) == 32): + raise TypeError("pto.vmull only supports 32-bit integer vector families in TileLang DSL v1") + + def _require_sync_pipe(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "pipe": + return expr.value.value + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + return expr.value + raise TypeError(f"{context} must be a PIPE symbol or pipe string in TileLang DSL v1") + + def _require_sync_event(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "event": + return expr.value.value + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + return expr.value + raise TypeError(f"{context} must be an EVENT symbol or event string in TileLang DSL v1") + + def _require_barrier_type(self, expr: SemanticExpr, context: str) -> str: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "barrier_type": + return expr.value.value + if isinstance(expr, SemanticBindingRef) and isinstance(expr.type, SemanticMetaType): + if expr.type.kind == "barrier_type" and isinstance(expr.binding.value, BarrierType): + return expr.binding.value.value + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.type, SemanticMetaType) and expr.type.kind == "string": + if expr.value in {barrier_type.value for barrier_type in BarrierType}: + return expr.value + raise TypeError( + f"{context} must be a BarrierType symbol or canonical barrier string " + "(`VV_ALL`, `VST_VLD`, `VLD_VST`, `VST_VST`, `VS_ALL`, `VST_LD`, " + "`VLD_ST`, `VST_ST`, `SV_ALL`, `ST_VLD`, `LD_VST`, or `ST_VST`) " + "in TileLang DSL v1" + ) + + def _normalize_event_id_expr(self, expr: SemanticExpr, context: str) -> SemanticExpr: + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "event" and isinstance(expr.value, Event): + return SemanticLiteralExpr( + value=int(expr.value.name[2:]), + type=SemanticScalarType(dtype=i64), + ) + if isinstance(expr, SemanticBindingRef) and isinstance(expr.type, SemanticMetaType): + if expr.type.kind == "event" and isinstance(expr.binding.value, Event): + return SemanticLiteralExpr( + value=int(expr.binding.value.name[2:]), + type=SemanticScalarType(dtype=i64), + ) + self._require_i64_like_expr(expr, context) + return expr + + def _pad_mode_value( + self, + expr: SemanticExpr | None, + *, + default: PadMode, + ) -> PadMode: + if expr is None: + return default + if isinstance(expr, SemanticSymbolExpr) and expr.type.kind == "pad_mode": + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "pad_mode" + and isinstance(expr.binding.value, PadMode) + ): + return expr.binding.value + raise TypeError("DMA pad_mode must be a PadMode symbol in TileLang DSL v1") + + def _require_loop_bound_type(self, ty: SemanticType) -> None: + if isinstance(ty, (SemanticIndexType, SemanticScalarType)): + return + raise TypeError(f"loop bound must be scalar/index typed, got {ty!r}") + + def _require_condition_type(self, ty: SemanticType) -> None: + if isinstance(ty, SemanticIndexType): + return + if isinstance(ty, SemanticScalarType): + return + raise TypeError(f"if condition must be scalar/index typed, got {ty!r}") + + def _merge_loop_carried_types( + self, + outer_type: SemanticType, + final_type: SemanticType, + ) -> SemanticType | None: + if final_type == outer_type: + return outer_type + if ( + isinstance(outer_type, SemanticIndexType) + and isinstance(final_type, SemanticScalarType) + and final_type.dtype == i32 + ): + return final_type + if ( + isinstance(final_type, SemanticIndexType) + and isinstance(outer_type, SemanticScalarType) + and outer_type.dtype == i32 + ): + return outer_type + return None + + def _require_index_typed_expr(self, expr: SemanticExpr) -> SemanticExpr: + if isinstance(expr.type, SemanticIndexType): + return expr + if isinstance(expr.type, SemanticScalarType) and is_integer_dtype(expr.type.dtype): + bits = integer_bitwidth(expr.type.dtype) + if bits in {8, 16, 32, 64}: + if isinstance(expr, SemanticLiteralExpr) and isinstance(expr.value, int) and not isinstance(expr.value, bool): + coerced: SemanticExpr = SemanticLiteralExpr(value=expr.value, type=SemanticIndexType()) + else: + coerced = SemanticIndexCastExpr(value=expr, type=SemanticIndexType()) + source_location = self._expr_source_location(expr) + if source_location is not None: + object.__setattr__(coerced, "source_location", source_location) + return coerced + self._raise_expr_type_error( + "slice bounds and vector offsets must be index-typed in TileLang DSL v1", + expr, + ) + + def _try_static_dtype(self, expr: SemanticExpr) -> ScalarType | None: + if ( + isinstance(expr, SemanticSymbolExpr) + and expr.type.kind == "dtype" + and isinstance(expr.value, ScalarType) + ): + return expr.value + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticMetaType) + and expr.type.kind == "dtype" + and isinstance(expr.binding.value, ScalarType) + ): + return expr.binding.value + return None + + def _try_static_subscript_value(self, expr: SemanticSubscriptAccess) -> Any | None: + index_value = self._try_static_value(expr.index) + if not isinstance(index_value, int): + return None + + base = expr.base + if isinstance(base, SemanticAttributeAccess) and isinstance(base.base, SemanticBindingRef): + binding_ref = base.base + binding_type = binding_ref.type + if isinstance(binding_type, SemanticTileType): + if base.attr == "shape" and binding_type.shape is not None: + if 0 <= index_value < len(binding_type.shape): + return binding_type.shape[index_value] + if base.attr == "valid_shape" and binding_type.valid_shape is not None: + if 0 <= index_value < len(binding_type.valid_shape): + return binding_type.valid_shape[index_value] + return None + if isinstance(binding_type, (SemanticTensorViewType, SemanticPartitionTensorViewType)): + return None + + base_value = self._try_static_value(base) + if isinstance(base_value, (tuple, list)): + if 0 <= index_value < len(base_value): + return base_value[index_value] + return None + return None + + def _try_static_value(self, expr: SemanticExpr | None) -> Any | None: + if expr is None: + return None + if isinstance(expr, SemanticSymbolExpr): + return expr.value + if isinstance(expr, SemanticLiteralExpr): + return expr.value + if isinstance(expr, SemanticIndexCastExpr): + value = self._try_static_value(expr.value) + if isinstance(value, int) and not isinstance(value, bool): + return value + return None + if isinstance(expr, SemanticBindingRef): + return expr.binding.value + if isinstance(expr, SemanticTupleExpr): + elements = [] + for element in expr.elements: + static_element = self._try_static_value(element) + if static_element is None: + return None + elements.append(static_element) + return tuple(elements) + if isinstance(expr, SemanticSubscriptAccess): + return self._try_static_subscript_value(expr) + if isinstance(expr, SemanticBinaryExpr): + if expr.op in {"and", "or"}: + lhs_bool = self._try_static_condition_bool(expr.lhs) + rhs_bool = self._try_static_condition_bool(expr.rhs) + if lhs_bool is None or rhs_bool is None: + return None + if expr.op == "and": + return lhs_bool and rhs_bool + return lhs_bool or rhs_bool + lhs = self._try_static_value(expr.lhs) + rhs = self._try_static_value(expr.rhs) + if lhs is None or rhs is None: + return None + if expr.op == "add": + if ( + isinstance(lhs, (int, float)) + and isinstance(rhs, (int, float)) + and not isinstance(lhs, bool) + and not isinstance(rhs, bool) + ): + return lhs + rhs + return None + if expr.op == "sub": + if ( + isinstance(lhs, (int, float)) + and isinstance(rhs, (int, float)) + and not isinstance(lhs, bool) + and not isinstance(rhs, bool) + ): + return lhs - rhs + return None + if expr.op == "mul": + if ( + isinstance(lhs, (int, float)) + and isinstance(rhs, (int, float)) + and not isinstance(lhs, bool) + and not isinstance(rhs, bool) + ): + return lhs * rhs + return None + if expr.op == "mod": + if isinstance(lhs, int) and isinstance(rhs, int): + if rhs == 0: + return None + return lhs % rhs + return None + if expr.op == "floordiv": + if isinstance(lhs, int) and isinstance(rhs, int): + if rhs == 0: + return None + return lhs // rhs + return None + if expr.op == "bitand": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool): + return None + return lhs & rhs + return None + if expr.op == "bitor": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool): + return None + return lhs | rhs + return None + if expr.op == "bitxor": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool): + return None + return lhs ^ rhs + return None + if expr.op == "lshift": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool) or rhs < 0: + return None + return lhs << rhs + return None + if expr.op == "rshift": + if isinstance(lhs, int) and isinstance(rhs, int): + if isinstance(lhs, bool) or isinstance(rhs, bool) or rhs < 0: + return None + return lhs >> rhs + return None + if expr.op == "eq": + return lhs == rhs + if expr.op == "ne": + return lhs != rhs + if expr.op == "gt": + try: + return lhs > rhs + except TypeError: + return None + if expr.op == "lt": + try: + return lhs < rhs + except TypeError: + return None + if expr.op == "ge": + try: + return lhs >= rhs + except TypeError: + return None + if expr.op == "le": + try: + return lhs <= rhs + except TypeError: + return None + return None + if isinstance(expr, SemanticCallExpr): + if expr.namespace != "pto": + return None + if expr.name == "bytewidth": + if len(expr.args) != 1: + return None + dtype = self._try_static_dtype(expr.args[0]) + if dtype is None: + return None + return bytewidth(dtype) + if expr.name in {"get_lanes", "elements_per_vreg"}: + if len(expr.args) != 1: + return None + dtype = self._try_static_dtype(expr.args[0]) + if dtype is None: + return None + return self._vreg_type_for_dtype(dtype).lanes + return None + + def _try_static_condition_bool(self, expr: SemanticExpr | None) -> bool | None: + value = self._try_static_value(expr) + if isinstance(value, bool): + return value + if isinstance(value, int): + return value != 0 + return None + + def _require_constexpr_condition_bool( + self, + expr: SemanticExpr, + *, + context: str, + ) -> bool: + value = self._try_static_condition_bool(expr) + if value is None: + raise TypeError( + f"{context} must be a compile-time bool in TileLang DSL v1" + ) + return value + + def _static_index_value(self, expr: SemanticExpr | None, *, default: int | None) -> int | None: + if expr is None: + return default + value = self._try_static_value(expr) + if isinstance(value, int) and not isinstance(value, bool): + return value + return None + + def _require_optional_index_typed_expr(self, expr: SemanticExpr | None) -> SemanticExpr | None: + if expr is None: + return None + return self._require_index_typed_expr(expr) + + def _static_bool_value(self, expr: SemanticExpr | None, *, default: bool | None) -> bool | None: + if expr is None: + return default + if isinstance(expr, SemanticLiteralExpr): + if ( + isinstance(expr.type, SemanticScalarType) + and expr.type.dtype == i1 + and isinstance(expr.value, bool) + ): + return expr.value + return None + if ( + isinstance(expr, SemanticBindingRef) + and isinstance(expr.type, SemanticScalarType) + and expr.type.dtype == i1 + and isinstance(expr.binding.value, bool) + ): + return expr.binding.value + return None + + def _require_static_bool_value( + self, + expr: SemanticExpr | None, + *, + context: str, + default: bool, + ) -> bool: + value = self._static_bool_value(expr, default=default) + if value is None: + raise TypeError( + f"{context} must be a compile-time bool in the stable frontend-only DMA profile" + ) + return value + + def _require_static_non_negative_index_value( + self, + expr: SemanticExpr | None, + *, + context: str, + default: int, + ) -> int: + value = self._static_index_value(expr, default=default) + if value is None: + raise TypeError( + f"{context} must be a static non-negative index in the stable frontend-only DMA profile" + ) + if value < 0: + raise TypeError( + f"{context} must be a non-negative index in the stable frontend-only DMA profile" + ) + return value + + def _normalize_optional_index_expr( + self, + expr: SemanticExpr | None, + *, + default: int, + ) -> SemanticExpr: + if expr is not None: + return expr + return SemanticLiteralExpr(value=default, type=SemanticIndexType()) + + def _normalized_tensor_slice_extent(self, expr: SemanticSliceExpr) -> int | None: + start = self._static_index_value(expr.start, default=0) + stop = self._static_index_value(expr.stop, default=None) + step = self._static_index_value(expr.step, default=1) + if stop is None or start is None or step is None: + return None + if step <= 0: + raise TypeError("TensorView slicing requires a positive static step in TileLang DSL v1") + distance = stop - start + if distance <= 0: + raise TypeError("TensorView slicing requires positive extents in TileLang DSL v1") + return (distance + step - 1) // step + + +def analyze_frontend_kernel(node: FrontendKernelNode) -> SemanticKernel: + """Normalize descriptor-owned AST into a lowering semantic model.""" + + return _SemanticAnalyzer(node).analyze() + + +__all__ = [ + "SemanticAssignStmt", + "SemanticAttributeAccess", + "SemanticBinaryExpr", + "SemanticBinding", + "SemanticBindingRef", + "SemanticCallExpr", + "SemanticDmaOptions", + "SemanticDmaLoadStmt", + "SemanticDmaStoreStmt", + "SemanticExpr", + "SemanticExprStmt", + "SemanticForStmt", + "SemanticGetBufStmt", + "SemanticAlignStoreStmt", + "SemanticAlignType", + "SemanticIfResult", + "SemanticIfStmt", + "SemanticIndexCastExpr", + "SemanticIndexType", + "SemanticKernel", + "SemanticLiteralExpr", + "SemanticMemBarStmt", + "SemanticMaskType", + "SemanticPadValueType", + "SemanticParameter", + "SemanticPipeBarrierStmt", + "SemanticPredicateStoreStmt", + "SemanticRlsBufStmt", + "SemanticReturnStmt", + "SemanticScalarType", + "SemanticSetCrossCoreStmt", + "SemanticSetFlagStmt", + "SemanticSetIntraBlockStmt", + "SemanticSetIntraCoreStmt", + "SemanticShapeType", + "SemanticSliceExpr", + "SemanticSliceType", + "SemanticStmt", + "SemanticVecscopeStmt", + "SemanticStrictVecscopeStmt", + "SemanticSubscriptAccess", + "SemanticSymbolExpr", + "SemanticTensorSliceAxis", + "SemanticTensorSliceExpr", + "SemanticTensorSliceType", + "SemanticTensorViewType", + "SemanticPartitionTensorViewType", + "SemanticTileBinding", + "SemanticTileConfigType", + "SemanticTileType", + "SemanticTupleExpr", + "SemanticTupleType", + "SemanticType", + "SemanticVectorType", + "SemanticVRegType", + "SemanticVScatterStmt", + "SemanticVectorPairStoreStmt", + "SemanticVectorStoreStmt", + "SemanticWaitFlagDevStmt", + "SemanticWaitFlagStmt", + "SemanticWaitIntraCoreStmt", + "analyze_frontend_kernel", +] diff --git a/tilelang-dsl/python/tilelang_dsl/support_matrix.py b/tilelang-dsl/python/tilelang_dsl/support_matrix.py new file mode 100644 index 000000000..8a1e79b7b --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/support_matrix.py @@ -0,0 +1,495 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Support-matrix definitions and diagnostics for TileLang DSL v1.""" + +from __future__ import annotations + +FOLLOW_UP_CHANGE = "extend-tilelang-dsl-matcher-and-advanced-surface" + +# Tier definitions for TileLang DSL surface classification +# These tiers represent the user-facing support level of language features: +# - BASIC: Core surface that is fully supported and recommended for general use +# - ADVANCED: Features requiring advanced=True, suitable for expert users + +BASIC_TIER = "basic" +ADVANCED_TIER = "advanced" + +# Tier metadata for PTO calls and language constructs +# This provides a unified source of truth for documentation and testing + +SUPPORTED_TOPLEVEL_PTO_CALLS = frozenset( + { + "constexpr", + "bytewidth", + "get_lanes", + "elements_per_vreg", + "get_op_attr", + "vreg", + "i1", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + "f16", + "bf16", + "f32", + "get_block_idx", + "get_subblock_idx", + "get_block_num", + "get_subblock_num", + "set_flag", + "wait_flag", + "pipe_barrier", + "barrier", + "get_buf", + "rls_buf", + "mem_bar", + "set_cross_core", + "set_intra_block", + "set_intra_core", + "wait_flag_dev", + "wait_intra_core", + } +) + +SUPPORTED_VECSCOPE_PTO_CALLS = frozenset( + { + "make_mask", + "init_align", + "vlds", + "vldas", + "vldus", + "vldsx2", + "plds", + "psts", + "pstu", + "vsst", + "vsta", + "vstas", + "vstar", + "vsts", + "vstsx2", + "vstus", + "vstur", + "vabs", + "vrelu", + "vexp", + "vln", + "vsqrt", + "vrec", + "vnot", + "vcadd", + "vcmax", + "vbcnt", + "vneg", + "vcls", + "vcmin", + "vrsqrt", + "vmov", + "vsunpack", + "vzunpack", + "vusqz", + "vsqz", + "vexpdif", + "vexpdiff", + "vtrc", + "vbr", + "vdup", + "vadd", + "vsub", + "vmul", + "vdiv", + "vmod", + "vmax", + "vmin", + "vand", + "vor", + "vxor", + "vaddrelu", + "vaddreluconv", + "vsubrelu", + "vaxpy", + "vmulconv", + "vmull", + "vmula", + "vshl", + "vshr", + "vprelu", + "vadds", + "vsubs", + "vmuls", + "vdivs", + "vmaxs", + "vmins", + "vlrelu", + "vshls", + "vshrs", + "vands", + "vors", + "vxors", + "vcgadd", + "vcgmax", + "vcgmin", + "vcpadd", + "vpack", + "vperm", + "vshift", + "vslide", + "vsort32", + "vmrgsort", + "vcvt", + "vbitcast", + "pbitcast", + "vci", + } +) + +ADVANCED_VECSCOPE_PTO_CALLS = frozenset( + { + "vscatter", + "vcmp", + "vcmps", + "vsel", + "vselr", + "vselrv2", + "pset_b8", + "pset_b16", + "pset_b32", + "pge_b8", + "pge_b16", + "pge_b32", + "plt_b8", + "plt_b16", + "plt_b32", + "pnot", + "psel", + "pand", + "por", + "pxor", + "ppack", + "punpack", + "pld", + "pldi", + "pst", + "psti", + "pdintlv_b8", + "pdintlv_b16", + "pdintlv_b32", + "pintlv_b8", + "pintlv_b16", + "pintlv_b32", + "vaddc", + "vsubc", + "vaddcs", + "vsubcs", + "vintlv", + "vdintlv", + "vintlvv2", + "vdintlvv2", + "vbitsort", + "vmrgsort4", + "get_vms4_sr", + } +) + +ADVANCED_EXPR_PTO_CALLS = frozenset( + { + "ptr", + "castptr", + "addptr", + "load_scalar", + } +) + +ADVANCED_TOPLEVEL_PTO_CALLS = frozenset( + { + "strict_vecscope", + "store_scalar", + "set_mov_pad_val", + "copy_gm_to_ubuf", + "copy_ubuf_to_gm", + "copy_ubuf_to_ubuf", + "set_loop2_stride_outtoub", + "set_loop1_stride_outtoub", + "set_loop_size_outtoub", + "set_loop2_stride_ubtoout", + "set_loop1_stride_ubtoout", + "set_loop_size_ubtoout", + } +) + +CUBE_ONLY_PTO_CALLS = frozenset( + { + "cube_load", + "cube_store", + "cube_load_frac", + "bias_load", + "left_load", + "right_load", + "left_load_mx", + "right_load_mx", + "mad", + "mad_acc", + "mad_bias", + "mad_mx", + "mad_mx_acc", + "mad_mx_bias", + "acc_store", + "acc_store_gm", + "acc_store_ub", + } +) + +DEFERRED_PTO_SURFACES = frozenset( + { + "vreduce", + } +) + +# Public surface groupings used by the guide, migration notes, and tests. +# These groupings intentionally mirror the user-facing authoring tiers rather +# than the internal lowering organization. + +BASIC_TENSORVIEW_SURFACES = frozenset({"TensorView"}) +BASIC_TILE_SURFACES = frozenset({"Tile"}) +BASIC_HIGH_LEVEL_DMA_SURFACES = frozenset() +BASIC_BASE_VECTOR_SURFACES = frozenset( + f"pto.{name}" for name in sorted(SUPPORTED_VECSCOPE_PTO_CALLS) +) + +ADVANCED_RAW_POINTER_SURFACES = frozenset( + { + "ptr", + "pto.ptr", + "PointerType", + "pto.castptr", + "pto.addptr", + } +) +ADVANCED_LOW_LEVEL_DMA_SURFACES = frozenset( + { + "pto.set_mov_pad_val", + "pto.copy_gm_to_ubuf", + "pto.copy_ubuf_to_gm", + "pto.copy_ubuf_to_ubuf", + "pto.set_loop2_stride_outtoub", + "pto.set_loop1_stride_outtoub", + "pto.set_loop_size_outtoub", + "pto.set_loop2_stride_ubtoout", + "pto.set_loop1_stride_ubtoout", + "pto.set_loop_size_ubtoout", + } +) +ADVANCED_EXPLICIT_VECSCOPE_SURFACES = frozenset({"pto.strict_vecscope"}) +ADVANCED_TILE_HELPER_SURFACES = frozenset( + { + "tile.slice", + "tile.reshape", + "tile.as_ptr", + "tensorview.as_ptr", + "pto.tile_from_ptr", + "pto.tile_with_strides", + "pto.tile_config", + } +) +BASIC_TILE_INDEXING_SURFACES = frozenset( + { + "tile[start:]", + "tile[row, col:]", + } +) + +AUTHORING_TIER_SURFACE_GROUPS = { + "TensorView": BASIC_TENSORVIEW_SURFACES, + "Tile": BASIC_TILE_SURFACES, + "base_vector_ops": BASIC_BASE_VECTOR_SURFACES, + "tile_indexing_sugar": BASIC_TILE_INDEXING_SURFACES, + "strict_vecscope": ADVANCED_EXPLICIT_VECSCOPE_SURFACES, + "raw_pointer_family": ADVANCED_RAW_POINTER_SURFACES, + "low_level_dma_family": ADVANCED_LOW_LEVEL_DMA_SURFACES, + "tile_helper_family": ADVANCED_TILE_HELPER_SURFACES, +} + +AUTHORING_TIER_GROUP_TIERS = { + "TensorView": BASIC_TIER, + "Tile": BASIC_TIER, + "base_vector_ops": BASIC_TIER, + "tile_indexing_sugar": BASIC_TIER, + "strict_vecscope": ADVANCED_TIER, + "raw_pointer_family": ADVANCED_TIER, + "low_level_dma_family": ADVANCED_TIER, + "tile_helper_family": ADVANCED_TIER, +} + + +def unsupported_feature_message(feature: str) -> str: + return ( + f"{feature} is not supported in TileLang DSL v1; " + f"see follow-up change `{FOLLOW_UP_CHANGE}`" + ) + + +def deferred_surface_message(name: str) -> str: + return unsupported_feature_message(f"advanced family surface `pto.{name}`") + + +def advanced_mode_message(name: str) -> str: + return f"surface `pto.{name}` requires advanced=True in TileLang DSL" + + +# Tier mapping for PTO calls +def get_pto_call_tier(call_name: str) -> str: + """Return the tier of a PTO call. + + Args: + call_name: Name of the PTO call (without 'pto.' prefix) + + Returns: + One of BASIC_TIER or ADVANCED_TIER + + Raises: + KeyError: If the PTO call is not part of the supported DSL surface + """ + if call_name in SUPPORTED_TOPLEVEL_PTO_CALLS: + return BASIC_TIER + if call_name in SUPPORTED_VECSCOPE_PTO_CALLS: + return BASIC_TIER + if call_name in ADVANCED_VECSCOPE_PTO_CALLS: + return ADVANCED_TIER + if call_name in ADVANCED_EXPR_PTO_CALLS: + return ADVANCED_TIER + if call_name in ADVANCED_TOPLEVEL_PTO_CALLS: + return ADVANCED_TIER + raise KeyError(unsupported_feature_message(f"pto.{call_name}")) + + +UNSUPPORTED_LANGUAGE_CONSTRUCTS = frozenset( + { + "dma_load", + "dma_store", + "pto.dma_load", + "pto.dma_store", + "pto.dma_copy", + "pto.vreduce", + "pto.tile", + "SyncOpType", + } +) + + +# Tier mapping for language constructs (non-PTO-call features) +# These are higher-level abstractions in the TileLang DSL +LANGUAGE_CONSTRUCT_TIERS = { + # Basic tier constructs + "TensorView": BASIC_TIER, + "Tile": BASIC_TIER, + "VRegType": BASIC_TIER, + "MaskType": BASIC_TIER, + "pto.vreg": BASIC_TIER, + "pto.mask_b8": BASIC_TIER, + "pto.mask_b16": BASIC_TIER, + "pto.mask_b32": BASIC_TIER, + "BarrierType": BASIC_TIER, + "PadMode": BASIC_TIER, + "BLayout": BASIC_TIER, + "SLayout": BASIC_TIER, + "PadValue": BASIC_TIER, + "constexpr": BASIC_TIER, + "pto.constexpr": BASIC_TIER, + "tile[start:]": BASIC_TIER, + "tile[row, col:]": BASIC_TIER, + # Advanced tier constructs + "ptr": ADVANCED_TIER, # raw pointer constructor + "strict_vecscope": ADVANCED_TIER, # explicit vecscope management + "pto.strict_vecscope": ADVANCED_TIER, + "tile.slice": ADVANCED_TIER, + "tile.reshape": ADVANCED_TIER, + "tile.as_ptr": ADVANCED_TIER, + "tensorview.as_ptr": ADVANCED_TIER, + "pto.tile_from_ptr": ADVANCED_TIER, + "pto.tile_with_strides": ADVANCED_TIER, + "pto.tile_config": ADVANCED_TIER, +} + + +def get_feature_tier(feature_name: str) -> str: + """Return the tier of a TileLang DSL feature. + + Args: + feature_name: Name of the feature, which can be: + - A PTO call name (e.g., 'vadd', 'ptr') + - A language construct (e.g., 'TensorView', 'dma_load') + - A qualified construct (e.g., 'tile.slice', 'pto.tile_from_ptr') + + Returns: + One of BASIC_TIER or ADVANCED_TIER + + Raises: + KeyError: If the feature is documented but not part of the supported DSL surface + """ + # First check if it's a known language construct + if feature_name in LANGUAGE_CONSTRUCT_TIERS: + return LANGUAGE_CONSTRUCT_TIERS[feature_name] + if feature_name in UNSUPPORTED_LANGUAGE_CONSTRUCTS: + raise KeyError(unsupported_feature_message(feature_name)) + + # Check if it's a PTO call (might be qualified with 'pto.' prefix) + call_name = feature_name + if feature_name.startswith("pto."): + call_name = feature_name[4:] + + # Check PTO call tier + return get_pto_call_tier(call_name) + + +def get_surface_group_tier(group_name: str) -> str: + """Return the authoring tier for a documented public-surface group.""" + + return AUTHORING_TIER_GROUP_TIERS[group_name] + + +__all__ = [ + "CUBE_ONLY_PTO_CALLS", + "DEFERRED_PTO_SURFACES", + "FOLLOW_UP_CHANGE", + "ADVANCED_EXPR_PTO_CALLS", + "ADVANCED_TOPLEVEL_PTO_CALLS", + "ADVANCED_VECSCOPE_PTO_CALLS", + "SUPPORTED_TOPLEVEL_PTO_CALLS", + "SUPPORTED_VECSCOPE_PTO_CALLS", + "BASIC_TIER", + "ADVANCED_TIER", + "BASIC_TENSORVIEW_SURFACES", + "BASIC_TILE_SURFACES", + "BASIC_HIGH_LEVEL_DMA_SURFACES", + "BASIC_BASE_VECTOR_SURFACES", + "BASIC_TILE_INDEXING_SURFACES", + "ADVANCED_EXPLICIT_VECSCOPE_SURFACES", + "ADVANCED_RAW_POINTER_SURFACES", + "ADVANCED_LOW_LEVEL_DMA_SURFACES", + "ADVANCED_TILE_HELPER_SURFACES", + "AUTHORING_TIER_SURFACE_GROUPS", + "AUTHORING_TIER_GROUP_TIERS", + "UNSUPPORTED_LANGUAGE_CONSTRUCTS", + "LANGUAGE_CONSTRUCT_TIERS", + "advanced_mode_message", + "deferred_surface_message", + "unsupported_feature_message", + "get_pto_call_tier", + "get_feature_tier", + "get_surface_group_tier", +] diff --git a/tilelang-dsl/python/tilelang_dsl/types.py b/tilelang-dsl/python/tilelang_dsl/types.py new file mode 100644 index 000000000..86f5edcff --- /dev/null +++ b/tilelang-dsl/python/tilelang_dsl/types.py @@ -0,0 +1,880 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Public type markers for the TileLang DSL v1 surface.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +import struct +from typing import Any, Mapping + + +@dataclass(frozen=True) +class ScalarType: + name: str + + def __repr__(self) -> str: + return self.name + + +_INTEGER_DTYPE_WIDTHS = { + "i8": 8, + "si8": 8, + "ui8": 8, + "i16": 16, + "si16": 16, + "ui16": 16, + "i32": 32, + "si32": 32, + "ui32": 32, + "i64": 64, + "si64": 64, + "ui64": 64, +} + +_INTEGER_DTYPE_SIGNS = { + "i8": "signless", + "si8": "signed", + "ui8": "unsigned", + "i16": "signless", + "si16": "signed", + "ui16": "unsigned", + "i32": "signless", + "si32": "signed", + "ui32": "unsigned", + "i64": "signless", + "si64": "signed", + "ui64": "unsigned", +} + +_FLOAT_DTYPE_WIDTHS = { + "f16": 16, + "bf16": 16, + "f32": 32, +} + +_DTYPE_BYTE_WIDTHS = { + name: bits // 8 for name, bits in _INTEGER_DTYPE_WIDTHS.items() +} +_DTYPE_BYTE_WIDTHS.update({name: bits // 8 for name, bits in _FLOAT_DTYPE_WIDTHS.items()}) + + +class TensorView: + """Bare TensorView annotation marker for TileLang DSL v1.""" + + +class PartitionTensorView: + """Bare PartitionTensorView annotation marker for TileLang DSL v1.""" + + +class Tile: + """Bare Tile annotation marker for TileLang DSL v1.""" + + +@dataclass(frozen=True) +class PointerType: + element_dtype: ScalarType + memory_space: "MemorySpace" + + def __repr__(self) -> str: + return f"ptr({self.element_dtype!r}, {self.memory_space!r})" + + +@dataclass(frozen=True) +class VRegType: + element_dtype: ScalarType + lanes: int + + def __repr__(self) -> str: + return f"vreg({self.element_dtype!r})" + + +@dataclass(frozen=True) +class VectorType: + element_dtype: ScalarType + shape: tuple[int, ...] + + def __repr__(self) -> str: + return f"vector({self.element_dtype!r}, {self.shape!r})" + + +@dataclass(frozen=True) +class MaskType: + granularity: str + + def __repr__(self) -> str: + return f"mask_{self.granularity}" + + +@dataclass(frozen=True) +class AlignType: + def __repr__(self) -> str: + return "align" + + +@dataclass(frozen=True) +class WildcardType: + name: str + + def __repr__(self) -> str: + return self.name + + +@dataclass(frozen=True) +class TypeVariable: + name: str + + def __repr__(self) -> str: + return f"TypeVar({self.name!r})" + + +class MemorySpace(str, Enum): + GM = "gm" + MAT = "mat" + LEFT = "left" + RIGHT = "right" + ACC = "acc" + BIAS = "bias" + UB = "ub" + + +class Pipe(str, Enum): + MTE1 = "PIPE_MTE1" + MTE2 = "PIPE_MTE2" + V = "PIPE_V" + MTE3 = "PIPE_MTE3" + ALL = "PIPE_ALL" + + +class Event(str, Enum): + ID0 = "EVENT_ID0" + ID1 = "EVENT_ID1" + ID2 = "EVENT_ID2" + ID3 = "EVENT_ID3" + ID4 = "EVENT_ID4" + ID5 = "EVENT_ID5" + ID6 = "EVENT_ID6" + ID7 = "EVENT_ID7" + ID8 = "EVENT_ID8" + ID9 = "EVENT_ID9" + ID10 = "EVENT_ID10" + ID11 = "EVENT_ID11" + ID12 = "EVENT_ID12" + ID13 = "EVENT_ID13" + ID14 = "EVENT_ID14" + ID15 = "EVENT_ID15" + ID16 = "EVENT_ID16" + ID17 = "EVENT_ID17" + ID18 = "EVENT_ID18" + ID19 = "EVENT_ID19" + ID20 = "EVENT_ID20" + ID21 = "EVENT_ID21" + ID22 = "EVENT_ID22" + ID23 = "EVENT_ID23" + ID24 = "EVENT_ID24" + ID25 = "EVENT_ID25" + ID26 = "EVENT_ID26" + ID27 = "EVENT_ID27" + ID28 = "EVENT_ID28" + ID29 = "EVENT_ID29" + ID30 = "EVENT_ID30" + ID31 = "EVENT_ID31" + + +class BarrierType(str, Enum): + VV_ALL = "VV_ALL" + VST_VLD = "VST_VLD" + VLD_VST = "VLD_VST" + VST_VST = "VST_VST" + VS_ALL = "VS_ALL" + VST_LD = "VST_LD" + VLD_ST = "VLD_ST" + VST_ST = "VST_ST" + SV_ALL = "SV_ALL" + ST_VLD = "ST_VLD" + LD_VST = "LD_VST" + ST_VST = "ST_VST" + + +class MaskPattern(str, Enum): + ALL = "PAT_ALL" + ALLF = "PAT_ALLF" + EVEN = "PAT_EVEN" + ODD = "PAT_ODD" + VL16 = "PAT_VL16" + VL32 = "PAT_VL32" + + +class PredicateDist(str, Enum): + NORM = "NORM" + US = "US" + DS = "DS" + PK = "PK" + + +class VLoadDist(str, Enum): + NORM = "NORM" + BRC_B8 = "BRC_B8" + BRC_B16 = "BRC_B16" + BRC_B32 = "BRC_B32" + US_B8 = "US_B8" + US_B16 = "US_B16" + DS_B8 = "DS_B8" + DS_B16 = "DS_B16" + UNPK_B8 = "UNPK_B8" + UNPK_B16 = "UNPK_B16" + UNPK_B32 = "UNPK_B32" + BRC_BLK = "BRC_BLK" + E2B_B16 = "E2B_B16" + E2B_B32 = "E2B_B32" + UNPK4 = "UNPK4" + SPLT4CHN = "SPLT4CHN" + SPLT2CHN_B8 = "SPLT2CHN_B8" + SPLT2CHN_B16 = "SPLT2CHN_B16" + + +class VStoreDist(str, Enum): + NORM_B8 = "NORM_B8" + NORM_B16 = "NORM_B16" + NORM_B32 = "NORM_B32" + ONE_POINT_B8 = "1PT_B8" + ONE_POINT_B16 = "1PT_B16" + ONE_POINT_B32 = "1PT_B32" + PK_B16 = "PK_B16" + PK_B32 = "PK_B32" + PK_B64 = "PK_B64" + PK4_B32 = "PK4_B32" + MRG4CHN_B8 = "MRG4CHN_B8" + MRG2CHN_B8 = "MRG2CHN_B8" + MRG2CHN_B16 = "MRG2CHN_B16" + + +class PredicatePart(str, Enum): + LOWER = "LOWER" + HIGHER = "HIGHER" + + +class CmpMode(str, Enum): + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + +class PadMode(str, Enum): + PadNull = "PadNull" + PadFirstElem = "PadFirstElem" + PadValue = "PadValue" + + +class BLayout(str, Enum): + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + +class SLayout(str, Enum): + NONE_BOX = "none_box" + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + +def _float32_from_bits(bits: int) -> float: + return struct.unpack(">f", bits.to_bytes(4, byteorder="big", signed=False))[0] + + +_FLOAT_DTYPE_MAX = { + "f16": 65504.0, + "bf16": _float32_from_bits(0x7F7F0000), + "f32": _float32_from_bits(0x7F7FFFFF), +} +_FLOAT_DTYPE_MIN = { + "f16": -65504.0, + "bf16": _float32_from_bits(0xFF7F0000), + "f32": _float32_from_bits(0xFF7FFFFF), +} + + +@dataclass(frozen=True) +class PadValue: + """Tile pad descriptor matching the C++ PadValue design. + + Standard values occupy the low integer range: + - NULL = 0 + - ZERO = 1 + - MAX = 2 + - MIN = 3 + + Custom values use the C++ `CustomBase` convention and carry an f32 bit + pattern authored through `custom_f32(...)`. + """ + + encoded: int + _symbol_name: str | None = None + _float32_bits: int | None = None + + CustomBase = 0x100000000 + _STANDARD_TEXT = { + 0: "null", + 1: "zero", + 2: "max", + 3: "min", + } + + def __post_init__(self) -> None: + if isinstance(self.encoded, bool) or not isinstance(self.encoded, int): + raise TypeError("PadValue.encoded must be a uint64-compatible integer") + if self.encoded < 0 or self.encoded >= (1 << 64): + raise ValueError("PadValue.encoded must be in uint64 range") + if self._float32_bits is not None and not (0 <= self._float32_bits < (1 << 32)): + raise ValueError("PadValue custom float32 payload must be a 32-bit integer") + + @property + def name(self) -> str: + if self._symbol_name is not None: + return self._symbol_name + return "CUSTOM" + + @property + def value(self) -> int: + raise AttributeError( + "PadValue.value is not available; use PadValue.encoded for host-side payload access " + "or pad.eval(...) for scalar materialization" + ) + + @property + def text(self) -> str: + standard = self._STANDARD_TEXT.get(self.encoded) + if standard is not None: + return standard + return f"0x{self.encoded:016X}" + + @property + def is_custom(self) -> bool: + return self._symbol_name is None and self.encoded >= self.CustomBase + + @property + def float32_bits(self) -> int: + if not self.is_custom: + raise ValueError("only custom PadValue instances carry a float32 payload") + if self._float32_bits is not None: + return self._float32_bits + return (self.encoded >> 32) & 0xFFFFFFFF + + def as_float32(self) -> float: + return _float32_from_bits(self.float32_bits) + + def eval(self, dtype: ScalarType) -> int | float | None: + if not isinstance(dtype, ScalarType): + raise TypeError("PadValue.eval expects a TileLang scalar dtype") + if self == PadValue.NULL: + return None + if self == PadValue.ZERO: + return 0.0 if is_float_dtype(dtype) else 0 + if self == PadValue.MAX: + if is_float_dtype(dtype): + return _FLOAT_DTYPE_MAX[dtype.name] + width = integer_bitwidth(dtype) + signedness = integer_signedness(dtype) + if width is None or signedness is None: + raise TypeError(f"PadValue.MAX does not support dtype `{dtype.name}`") + if signedness == "unsigned": + return (1 << width) - 1 + return (1 << (width - 1)) - 1 + if self == PadValue.MIN: + if is_float_dtype(dtype): + return _FLOAT_DTYPE_MIN[dtype.name] + width = integer_bitwidth(dtype) + signedness = integer_signedness(dtype) + if width is None or signedness is None: + raise TypeError(f"PadValue.MIN does not support dtype `{dtype.name}`") + if signedness == "unsigned": + return 0 + return -(1 << (width - 1)) + if self.is_custom: + if not is_float_dtype(dtype): + raise TypeError( + "custom Tile pad_value currently only materializes for floating Tile element dtypes" + ) + return self.as_float32() + raise TypeError(f"unsupported PadValue payload {self!r}") + + @classmethod + def from_uint64(cls, value: int) -> "PadValue": + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError("PadValue.from_uint64 expects an integer") + if value == 0: + return cls.NULL + if value == 1: + return cls.ZERO + if value == 2: + return cls.MAX + if value == 3: + return cls.MIN + if value < 0 or value >= (1 << 64): + raise ValueError("PadValue.from_uint64 expects a uint64-compatible integer") + return cls(value) + + @classmethod + def custom_f32(cls, value: float | str | int) -> "PadValue": + bits = cls._normalize_custom_f32_bits(value) + encoded = cls.CustomBase | (bits << 32) + return cls(encoded=encoded, _float32_bits=bits) + + @staticmethod + def _normalize_custom_f32_bits(value: float | str | int) -> int: + if isinstance(value, bool): + raise TypeError("PadValue.custom_f32 does not accept bool") + if isinstance(value, int): + if value < 0 or value >= (1 << 32): + raise ValueError("PadValue.custom_f32 integer payload must fit in 32 bits") + return value + if isinstance(value, str): + text = value.strip() + if text.lower().startswith("0x"): + bits = int(text, 16) + if bits < 0 or bits >= (1 << 32): + raise ValueError("PadValue.custom_f32 hex payload must fit in 32 bits") + return bits + value = float(text) + packed = struct.pack(">f", float(value)) + return int.from_bytes(packed, byteorder="big", signed=False) + + def __repr__(self) -> str: + if self == PadValue.NULL: + return "PadValue.NULL" + if self == PadValue.ZERO: + return "PadValue.ZERO" + if self == PadValue.MAX: + return "PadValue.MAX" + if self == PadValue.MIN: + return "PadValue.MIN" + return f"PadValue.custom_f32(0x{self.float32_bits:08X})" + + +PadValue.NULL = PadValue(0, "NULL") +PadValue.ZERO = PadValue(1, "ZERO") +PadValue.MAX = PadValue(2, "MAX") +PadValue.MIN = PadValue(3, "MIN") + + +class DeinterleaveDist(str, Enum): + DINTLV = "DINTLV" + BDINTLV = "BDINTLV" + B8 = "DINTLV" + B16 = "DINTLV" + B32 = "DINTLV" + BD = "BDINTLV" + + +class InterleaveDist(str, Enum): + INTLV = "INTLV" + B8 = "INTLV" + B16 = "INTLV" + B32 = "INTLV" + + +class PositionMode(str, Enum): + LOWEST = "LOWEST" + HIGHEST = "HIGHEST" + + +class OrderMode(str, Enum): + ASC = "ASC" + DESC = "DESC" + + +class VcvtRoundMode(str, Enum): + R = "R" + A = "A" + F = "F" + C = "C" + Z = "Z" + O = "O" + + +class VcvtSatMode(str, Enum): + SAT = "SAT" + NOSAT = "NOSAT" + + +class VcvtPartMode(str, Enum): + EVEN = "EVEN" + ODD = "ODD" + P0 = "P0" + P1 = "P1" + P2 = "P2" + P3 = "P3" + + +class PostUpdateMode(str, Enum): + POST_UPDATE = "POST_UPDATE" + NO_POST_UPDATE = "NO_POST_UPDATE" + + +class FractalMode(str, Enum): + ND2NZ = "nd2nz" + DN2NZ = "dn2nz" + NZ2ND = "nz2nd" + NZ2DN = "nz2dn" + NZ2NZ = "nz2nz" + + +@dataclass(frozen=True) +class TileConfig: + fields: tuple[tuple[str, Any], ...] = () + + @classmethod + def from_mapping(cls, mapping: Mapping[str, Any]) -> "TileConfig": + if not isinstance(mapping, Mapping): + raise TypeError("TileConfig.from_mapping expects a mapping") + normalized: dict[str, Any] = {} + for key, value in mapping.items(): + canonical_key = cls._canonical_key(key) + if canonical_key in normalized: + raise ValueError(f"duplicate TileConfig field '{canonical_key}'") + normalized[canonical_key] = cls._normalize_field_value(canonical_key, value) + return cls(tuple(sorted(normalized.items()))) + + @staticmethod + def _canonical_key(key: Any) -> str: + if not isinstance(key, str): + raise TypeError("TileConfig field names must be strings") + aliases = { + "layout": "b_layout", + "blayout": "b_layout", + "b_layout": "b_layout", + "slayout": "s_layout", + "s_layout": "s_layout", + "fractal": "s_fractal_size", + "s_fractal_size": "s_fractal_size", + "pad": "pad_value", + "pad_value": "pad_value", + } + return aliases.get(key, key) + + @staticmethod + def _normalize_field_value(key: str, value: Any) -> Any: + if key == "b_layout": + return TileConfig._normalize_b_layout(value) + if key == "s_layout": + return TileConfig._normalize_s_layout(value) + if key == "s_fractal_size": + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError("TileConfig.s_fractal_size must be an integer") + return value + if key == "pad_value": + return TileConfig._normalize_pad_value(value) + return value + + @staticmethod + def _normalize_b_layout(value: Any) -> BLayout: + if isinstance(value, BLayout): + return value + if isinstance(value, str): + normalized = value.strip().upper().replace("-", "_") + if normalized == "ROW_MAJOR": + return BLayout.ROW_MAJOR + if normalized == "COL_MAJOR": + return BLayout.COL_MAJOR + raise ValueError(f"unsupported TileConfig b_layout value {value!r}") + + @staticmethod + def _normalize_s_layout(value: Any) -> SLayout: + if isinstance(value, SLayout): + return value + if isinstance(value, str): + normalized = value.strip().upper().replace("-", "_") + if normalized == "NONE_BOX": + return SLayout.NONE_BOX + if normalized == "ROW_MAJOR": + return SLayout.ROW_MAJOR + if normalized == "COL_MAJOR": + return SLayout.COL_MAJOR + raise ValueError(f"unsupported TileConfig s_layout value {value!r}") + + @staticmethod + def _normalize_pad_value(value: Any) -> PadValue: + if isinstance(value, PadValue): + return value + if isinstance(value, int) and not isinstance(value, bool): + return PadValue.from_uint64(value) + if isinstance(value, str): + text = value.strip() + if text.lower().startswith("0x"): + return PadValue.from_uint64(int(text, 16)) + normalized = value.strip().upper().replace("-", "_") + if normalized == "NULL": + return PadValue.NULL + if normalized == "ZERO": + return PadValue.ZERO + if normalized == "MAX": + return PadValue.MAX + if normalized == "MIN": + return PadValue.MIN + raise ValueError(f"unsupported TileConfig pad_value value {value!r}") + + @property + def b_layout(self) -> BLayout: + value = dict(self.fields).get("b_layout", BLayout.ROW_MAJOR) + return self._normalize_b_layout(value) + + @property + def s_layout(self) -> SLayout: + value = dict(self.fields).get("s_layout", SLayout.NONE_BOX) + return self._normalize_s_layout(value) + + @property + def s_fractal_size(self) -> int: + value = dict(self.fields).get("s_fractal_size", 512) + if isinstance(value, bool) or not isinstance(value, int): + raise TypeError("TileConfig.s_fractal_size must be an integer") + return value + + @property + def pad_value(self) -> PadValue: + value = dict(self.fields).get("pad_value", PadValue.NULL) + return self._normalize_pad_value(value) + + @classmethod + def for_memory_space(cls, memory_space: MemorySpace) -> "TileConfig": + if not isinstance(memory_space, MemorySpace): + raise TypeError("TileConfig.for_memory_space expects a TileLang MemorySpace") + defaults: dict[str, Any] + if memory_space in {MemorySpace.MAT, MemorySpace.LEFT}: + defaults = { + "b_layout": BLayout.COL_MAJOR, + "s_layout": SLayout.ROW_MAJOR, + "s_fractal_size": 512, + "pad_value": PadValue.NULL, + } + elif memory_space == MemorySpace.RIGHT: + defaults = { + "b_layout": BLayout.ROW_MAJOR, + "s_layout": SLayout.COL_MAJOR, + "s_fractal_size": 512, + "pad_value": PadValue.NULL, + } + elif memory_space == MemorySpace.ACC: + defaults = { + "b_layout": BLayout.COL_MAJOR, + "s_layout": SLayout.ROW_MAJOR, + "s_fractal_size": 1024, + "pad_value": PadValue.NULL, + } + elif memory_space == MemorySpace.BIAS: + defaults = { + "b_layout": BLayout.ROW_MAJOR, + "s_layout": SLayout.NONE_BOX, + "s_fractal_size": 512, + "pad_value": PadValue.NULL, + } + else: + defaults = { + "b_layout": BLayout.ROW_MAJOR, + "s_layout": SLayout.NONE_BOX, + "s_fractal_size": 512, + "pad_value": PadValue.NULL, + } + return cls(tuple(sorted(defaults.items()))) + + +@dataclass(frozen=True) +class TileSpecialization: + shape: tuple[int, ...] + memory_space: MemorySpace + config: TileConfig | None = None + valid_shape: tuple[int | None, ...] | None = None + + +i1 = ScalarType("i1") +i8 = ScalarType("i8") +si8 = ScalarType("si8") +ui8 = ScalarType("ui8") +i16 = ScalarType("i16") +si16 = ScalarType("si16") +ui16 = ScalarType("ui16") +i32 = ScalarType("i32") +si32 = ScalarType("si32") +ui32 = ScalarType("ui32") +i64 = ScalarType("i64") +si64 = ScalarType("si64") +ui64 = ScalarType("ui64") +f16 = ScalarType("f16") +bf16 = ScalarType("bf16") +f32 = ScalarType("f32") +PIPE = Pipe +EVENT = Event +PAT = MaskPattern +AnyFloat = WildcardType("AnyFloat") +AnyInt = WildcardType("AnyInt") +AnyType = WildcardType("AnyType") +AnyMask = WildcardType("AnyMask") +mask_b8 = MaskType("b8") +mask_b16 = MaskType("b16") +mask_b32 = MaskType("b32") +align = AlignType() + + +def TypeVar(name: str) -> TypeVariable: + if not isinstance(name, str) or not name: + raise TypeError("TypeVar name must be a non-empty string") + return TypeVariable(name) + + +def ptr(dtype: ScalarType, memory_space: MemorySpace) -> PointerType: + if not isinstance(dtype, ScalarType): + raise TypeError("ptr() expects a TileLang scalar dtype") + if not isinstance(memory_space, MemorySpace): + raise TypeError("ptr() expects a TileLang MemorySpace") + return PointerType(element_dtype=dtype, memory_space=memory_space) + + +def vreg(dtype: ScalarType) -> VRegType: + if not isinstance(dtype, ScalarType): + raise TypeError("vreg() expects a TileLang scalar dtype") + return VRegType(element_dtype=dtype, lanes=get_lanes(dtype)) + + +def vector(dtype: ScalarType, shape: tuple[int, ...] | list[int] | int) -> VectorType: + if not isinstance(dtype, ScalarType): + raise TypeError("vector() expects a TileLang scalar dtype") + if isinstance(shape, int) and not isinstance(shape, bool): + normalized_shape = (shape,) + elif isinstance(shape, (list, tuple)): + normalized_shape = tuple(shape) + else: + raise TypeError("vector() expects a shape integer or a non-empty sequence of integers") + if not normalized_shape: + raise TypeError("vector() expects a non-empty shape") + for dim in normalized_shape: + if not isinstance(dim, int) or isinstance(dim, bool): + raise TypeError("vector() shape entries must be integers") + if dim <= 0: + raise TypeError("vector() shape entries must be positive") + return VectorType(element_dtype=dtype, shape=normalized_shape) + + +def integer_bitwidth(dtype: ScalarType) -> int | None: + if not isinstance(dtype, ScalarType): + return None + return _INTEGER_DTYPE_WIDTHS.get(dtype.name) + + +def integer_signedness(dtype: ScalarType) -> str | None: + if not isinstance(dtype, ScalarType): + return None + return _INTEGER_DTYPE_SIGNS.get(dtype.name) + + +def is_integer_dtype(dtype: ScalarType) -> bool: + return integer_bitwidth(dtype) is not None + + +def is_float_dtype(dtype: ScalarType) -> bool: + return isinstance(dtype, ScalarType) and dtype.name in _FLOAT_DTYPE_WIDTHS + + +def bytewidth(dtype: ScalarType) -> int: + if not isinstance(dtype, ScalarType): + raise TypeError("bytewidth expects a TileLang scalar dtype") + width = _DTYPE_BYTE_WIDTHS.get(dtype.name) + if width is None: + raise TypeError(f"dtype `{dtype.name}` is not supported by bytewidth") + return width + + +def get_lanes(dtype: ScalarType) -> int: + return 256 // bytewidth(dtype) + + +def elements_per_vreg(dtype: ScalarType) -> int: + return get_lanes(dtype) + + +def constexpr(value: bool) -> bool: + return value + + +def get_op_attr(name: str, default: Any = None) -> Any: + if not isinstance(name, str) or not name: + raise TypeError("get_op_attr expects a non-empty string attribute name") + return default + + +__all__ = [ + "ScalarType", + "WildcardType", + "TypeVariable", + "TypeVar", + "TensorView", + "PartitionTensorView", + "Tile", + "PointerType", + "VRegType", + "VectorType", + "MaskType", + "ptr", + "vreg", + "vector", + "MemorySpace", + "Pipe", + "Event", + "PIPE", + "EVENT", + "MaskPattern", + "PredicateDist", + "VLoadDist", + "VStoreDist", + "PredicatePart", + "CmpMode", + "PAT", + "BarrierType", + "PadMode", + "BLayout", + "SLayout", + "PadValue", + "DeinterleaveDist", + "InterleaveDist", + "PositionMode", + "OrderMode", + "PostUpdateMode", + "TileConfig", + "TileSpecialization", + "i1", + "i8", + "si8", + "ui8", + "i16", + "si16", + "ui16", + "i32", + "si32", + "ui32", + "i64", + "si64", + "ui64", + "f16", + "bf16", + "f32", + "AnyFloat", + "AnyInt", + "AnyType", + "AnyMask", + "mask_b8", + "mask_b16", + "mask_b32", + "constexpr", + "get_op_attr", + "bytewidth", + "get_lanes", + "elements_per_vreg", +] diff --git a/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md b/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md new file mode 100644 index 000000000..d610f283b --- /dev/null +++ b/tilelang-dsl/skills/auto-update-vpto-spec/SKILL.md @@ -0,0 +1,80 @@ +--- +name: auto-update-vpto-spec +description: 自动对齐 TileLang DSL 与最新 VPTO 规范:比较 spec 差异并指导实现、lowering 与测试同步更新。 +license: MIT +--- + +根据最新 VPTO 规范自动更新 TileLang DSL 的规范与实现。 + +--- + +## 输入 + +- 最新规范文件(建议命名:`vpto-latest.md`) +- 当前 DSL 对齐规范文件(建议命名:`vpto-current.md`) + +如果用户没有提供路径,先询问文件路径后再继续。 + +--- + +## 执行步骤 + +1. 读取最新 VPTO 规范 `vpto-latest.md`(如果最新版本来自网络,则先下载保存)。 +2. 读取当前 DSL 使用的规范 `vpto-current.md`。 +3. 对比两者差异: + - 若无差异:输出“无需更新”,结束。 + - 若有差异:生成差异报告。 +4. 根据差异报告,逐项与用户确认每个差异变更的处理方式(新增/修改/删除),按照分类进行处理 + +5. 差异分类与处理规则: + +### A. 新增 op + +- 更新 DSL spec 中对应章节,添加新 op 的描述、参数、返回值、示例等。 +- 在 DSL 实现中新增该 op(包含前端定义与必要 lowering)。 +- 补齐对应测试用例(语法、语义、lowering/代码生成路径)。 + +### B. 修改 op 语义 + +- 同步修改 DSL spec 对应 op 的语义描述。 +- 评估并更新 DSL 实现语义行为。 +- 增加/更新回归测试覆盖新语义。 + +### C. 修改 op 参数格式 + +- 优先保持 DSL 前端接口不变(向后兼容用户调用方式)。 +- 在 lowering/转换逻辑层吸收格式变化。 +- 增加测试验证旧接口与新规范语义一致。 + +### D. 删除 op + +- 在 DSL spec 中删除对应 op。 +- 在 DSL 实现中将该 op 标记为不受支持,并在用户使用时显式报错。 +- 增加测试验证报错信息清晰可见。 + +6. 统一补充测试: + - 至少覆盖:新增/变更/删除的 golden path。 + - 包含失败路径(非法参数、已删除 op 调用)验证。 + +7. 将vpto-spec-current.md改名为vpto-spec-*.md(如vpto-spec-2024-06.md),并将vpto-latest.md改名为vpto-spec-current.md,保持版本迭代记录。 + +--- + +## 输出要求 + +- 输出变更摘要: + - 差异总览(新增/修改/删除清单) + - 更新的 DSL spec 章节 + - 更新的实现文件 + - 新增/修改的测试文件 +- 若存在无法自动判定的语义映射,先向用户提问后再继续。 + +--- + +## 护栏 + +- 不在未确认语义的情况下擅自改变 DSL 前端接口。 +- 优先在 lowering 层处理规范格式变更。 +- 删除 op 时必须提供显式错误信息与测试覆盖。 +- 若某一差异无法映射到当前 DSL 架构,先报告阻塞点并请求用户决策。 +- 始终先改文档再落实现 diff --git a/tilelang-dsl/tests/README.md b/tilelang-dsl/tests/README.md new file mode 100644 index 000000000..c1370a85a --- /dev/null +++ b/tilelang-dsl/tests/README.md @@ -0,0 +1,5 @@ +TileLang DSL tests live here. + +Keep tests for this frontend isolated from the legacy `test/python/` and +other repository-wide test trees unless a follow-up task explicitly wires +shared coverage. diff --git a/tilelang-dsl/tests/conftest.py b/tilelang-dsl/tests/conftest.py new file mode 100644 index 000000000..3b3667567 --- /dev/null +++ b/tilelang-dsl/tests/conftest.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +"""Pytest bootstrap for the TileLang DSL test tree.""" + +from __future__ import annotations + +import sys +from pathlib import Path + + +def _ensure_tilelang_dsl_import_path() -> None: + # Always prefer the in-tree Python sources so pytest exercises the current + # workspace edits rather than stale build artifacts. Keep the build-tree + # path as a fallback when the source package is unavailable. + repo_root = Path(__file__).resolve().parents[2] + source_path = repo_root / "tilelang-dsl" / "python" + build_path = repo_root / "build" / "python" + + source_text = str(source_path) + if source_path.exists(): + if source_text in sys.path: + sys.path.remove(source_text) + sys.path.insert(0, source_text) + return + + build_text = str(build_path) + if build_path.exists() and build_text not in sys.path: + sys.path.insert(0, build_text) + + +_ensure_tilelang_dsl_import_path() diff --git a/tilelang-dsl/tests/import_tilelang_dsl.py b/tilelang-dsl/tests/import_tilelang_dsl.py new file mode 100644 index 000000000..ac0d7c964 --- /dev/null +++ b/tilelang-dsl/tests/import_tilelang_dsl.py @@ -0,0 +1,20 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import tilelang_dsl + + +def main() -> None: + package_file = getattr(tilelang_dsl, "__file__", None) + if not package_file: + raise SystemExit("tilelang_dsl import did not expose __file__") + print(package_file) + + +if __name__ == "__main__": + main() diff --git a/tilelang-dsl/tests/test_tilelang_dsl_v1.py b/tilelang-dsl/tests/test_tilelang_dsl_v1.py new file mode 100644 index 000000000..a1e5f4912 --- /dev/null +++ b/tilelang-dsl/tests/test_tilelang_dsl_v1.py @@ -0,0 +1,9804 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import tempfile +import unittest +import io +import subprocess +import sys +from contextlib import redirect_stderr +from unittest import mock +from importlib import util +from pathlib import Path + +import tilelang_dsl as pto +import tilelang_dsl.expand_helper as expand_helper +import tilelang_dsl.kernel as kernel_impl +from tilelang_dsl.support_matrix import ( + ADVANCED_EXPLICIT_VECSCOPE_SURFACES, + ADVANCED_LOW_LEVEL_DMA_SURFACES, + ADVANCED_RAW_POINTER_SURFACES, + ADVANCED_TILE_HELPER_SURFACES, + ADVANCED_TIER, + AUTHORING_TIER_SURFACE_GROUPS, + BASIC_TIER, + BASIC_TILE_INDEXING_SURFACES, + ADVANCED_VECSCOPE_PTO_CALLS, + SUPPORTED_VECSCOPE_PTO_CALLS, + get_feature_tier, + get_surface_group_tier, +) +from tilelang_dsl.frontend_ast import ( + FrontendAssignStmt, + FrontendCallExpr, + FrontendExprStmt, + FrontendForStmt, + FrontendIfStmt, + FrontendStrictVecscopeStmt, + FrontendVecscopeStmt, + FrontendNoOpStmt, + build_frontend_kernel_node, +) +from tilelang_dsl.lowering import AuthoringModule, lower_semantic_kernel +from tilelang_dsl.semantic import ( + SemanticAlignStoreStmt, + SemanticAlignType, + SemanticAssignStmt, + SemanticBindingRef, + SemanticBinaryExpr, + SemanticCallExpr, + SemanticDmaConfigStmt, + SemanticDmaUnaryConfigStmt, + SemanticExprStmt, + SemanticForStmt, + SemanticGetBufStmt, + SemanticIfStmt, + SemanticIndexType, + SemanticLiteralExpr, + SemanticMemBarStmt, + SemanticLowLevelCopyStmt, + SemanticMaskType, + SemanticPadValueType, + SemanticPartitionTensorViewType, + SemanticPipeBarrierStmt, + SemanticPtrType, + SemanticPredicateStoreStmt, + SemanticReturnStmt, + SemanticRlsBufStmt, + SemanticScalarStoreStmt, + SemanticScalarType, + SemanticTupleExpr, + SemanticSetCrossCoreStmt, + SemanticSetFlagStmt, + SemanticSetIntraBlockStmt, + SemanticSetIntraCoreStmt, + SemanticStrictVecscopeStmt, + SemanticSymbolExpr, + SemanticTensorViewType, + SemanticTileConfigType, + SemanticTileType, + SemanticVecscopeStmt, + SemanticVScatterStmt, + SemanticVectorPairStoreStmt, + SemanticVectorStoreStmt, + SemanticVRegType, + SemanticWaitFlagDevStmt, + SemanticWaitFlagStmt, + SemanticWaitIntraCoreStmt, + analyze_frontend_kernel, +) + +GLOBAL_TILELANG_LITERAL_BLOCK_SIZE = 32 +INLINE_PROC_GLOBAL_LANE = 0 + + +def _walk_semantic_stmts(statements): + for stmt in statements: + yield stmt + if isinstance(stmt, SemanticVecscopeStmt): + yield from _walk_semantic_stmts(stmt.body) + elif isinstance(stmt, SemanticForStmt): + yield from _walk_semantic_stmts(stmt.body) + elif isinstance(stmt, SemanticIfStmt): + yield from _walk_semantic_stmts(stmt.then_body) + yield from _walk_semantic_stmts(stmt.else_body) + + +def _find_inline_helper(semantic_kernel, symbol_prefix): + return next( + helper for helper in semantic_kernel.inline_helpers if helper.symbol_name.startswith(symbol_prefix) + ) + + +def _find_helper_assign_by_ssa(helper, ssa_name): + return next( + stmt + for stmt in helper.body + if isinstance(stmt, SemanticAssignStmt) + and any(target.ssa_name == ssa_name for target in stmt.targets) + ) + + +def _find_last_helper_assign_by_name(helper, name): + return next( + stmt + for stmt in reversed(helper.body) + if isinstance(stmt, SemanticAssignStmt) + and any(target.name == name for target in stmt.targets) + ) + + +def _find_helper_return_stmt(helper): + return next(stmt for stmt in helper.body if isinstance(stmt, SemanticReturnStmt)) + + +def _resolve_helper_expr(helper, expr): + if isinstance(expr, SemanticBindingRef): + assign = _find_helper_assign_by_ssa(helper, expr.binding.ssa_name) + return _resolve_helper_expr(helper, assign.value) + return expr + + +def _resolve_helper_broadcast_scalar_literal(helper, expr): + resolved = _resolve_helper_expr(helper, expr) + if isinstance(resolved, SemanticLiteralExpr): + return resolved.value + if isinstance(resolved, SemanticCallExpr) and resolved.namespace == "pto" and resolved.name == "vbr": + return _resolve_helper_broadcast_scalar_literal(helper, resolved.args[0]) + raise AssertionError(f"expected helper scalar literal or broadcast, got {resolved!r}") + + +class TileLangDSLPackageTests(unittest.TestCase): + def test_package_exports_surface(self) -> None: + self.assertIsNotNone(pto.__file__) + self.assertTrue(hasattr(pto, "vkernel")) + self.assertTrue(hasattr(pto, "KernelRegistry")) + self.assertTrue(hasattr(pto, "select_kernel")) + self.assertTrue(hasattr(pto, "TensorView")) + self.assertTrue(hasattr(pto, "Tile")) + self.assertTrue(hasattr(pto, "TileSpecialization")) + self.assertTrue(hasattr(pto, "PointerType")) + self.assertTrue(hasattr(pto, "VectorType")) + self.assertTrue(hasattr(pto, "VRegType")) + self.assertTrue(hasattr(pto, "MaskType")) + self.assertTrue(hasattr(pto, "AlignType")) + self.assertTrue(hasattr(pto, "ptr")) + self.assertTrue(hasattr(pto, "vector")) + self.assertTrue(hasattr(pto, "vreg")) + self.assertTrue(hasattr(pto, "MemorySpace")) + self.assertTrue(hasattr(pto, "align")) + self.assertTrue(hasattr(pto, "mask_b8")) + self.assertTrue(hasattr(pto, "mask_b16")) + self.assertTrue(hasattr(pto, "mask_b32")) + self.assertTrue(hasattr(pto, "constexpr")) + self.assertTrue(hasattr(pto, "bytewidth")) + self.assertTrue(hasattr(pto, "get_lanes")) + self.assertTrue(hasattr(pto, "elements_per_vreg")) + self.assertTrue(hasattr(pto, "PAT")) + self.assertTrue(hasattr(pto, "PredicateDist")) + self.assertTrue(hasattr(pto, "PadMode")) + self.assertTrue(hasattr(pto, "BarrierType")) + self.assertTrue(hasattr(pto, "BLayout")) + self.assertTrue(hasattr(pto, "DeinterleaveDist")) + self.assertTrue(hasattr(pto, "InterleaveDist")) + self.assertTrue(hasattr(pto, "PositionMode")) + self.assertTrue(hasattr(pto, "OrderMode")) + self.assertTrue(hasattr(pto, "PadValue")) + self.assertTrue(hasattr(pto, "VcvtRoundMode")) + self.assertTrue(hasattr(pto, "VcvtSatMode")) + self.assertTrue(hasattr(pto, "VcvtPartMode")) + self.assertTrue(hasattr(pto, "PostUpdateMode")) + self.assertTrue(hasattr(pto, "FractalMode")) + self.assertTrue(hasattr(pto, "SLayout")) + self.assertTrue(hasattr(pto, "PIPE")) + self.assertTrue(hasattr(pto, "EVENT")) + self.assertTrue(hasattr(pto, "si8")) + self.assertTrue(hasattr(pto, "ui8")) + self.assertTrue(hasattr(pto, "si16")) + self.assertTrue(hasattr(pto, "ui16")) + self.assertTrue(hasattr(pto, "si32")) + self.assertTrue(hasattr(pto, "ui32")) + self.assertTrue(hasattr(pto, "si64")) + self.assertTrue(hasattr(pto, "ui64")) + self.assertEqual(pto.BarrierType.VST_VLD.value, "VST_VLD") + self.assertEqual(pto.BarrierType.VST_VST.value, "VST_VST") + self.assertEqual(pto.BarrierType.VS_ALL.value, "VS_ALL") + self.assertEqual(pto.BarrierType.VST_LD.value, "VST_LD") + self.assertEqual(pto.BarrierType.VLD_ST.value, "VLD_ST") + self.assertEqual(pto.BarrierType.VST_ST.value, "VST_ST") + self.assertEqual(pto.BarrierType.SV_ALL.value, "SV_ALL") + self.assertEqual(pto.BarrierType.ST_VLD.value, "ST_VLD") + self.assertEqual(pto.BarrierType.LD_VST.value, "LD_VST") + self.assertEqual(pto.BarrierType.ST_VST.value, "ST_VST") + self.assertEqual(pto.PadMode.PadNull.value, "PadNull") + self.assertEqual(pto.PadMode.PadFirstElem.value, "PadFirstElem") + self.assertEqual(pto.PadMode.PadValue.value, "PadValue") + self.assertEqual(pto.BLayout.ROW_MAJOR.value, "row_major") + self.assertEqual(pto.SLayout.NONE_BOX.value, "none_box") + self.assertEqual(pto.PadValue.NULL.encoded, 0) + self.assertEqual(pto.PadValue.ZERO.encoded, 1) + self.assertEqual(pto.PadValue.MAX.encoded, 2) + self.assertEqual(pto.PadValue.MIN.encoded, 3) + self.assertEqual(pto.PadValue.NULL.text, "null") + self.assertEqual(pto.DeinterleaveDist.DINTLV.value, "DINTLV") + self.assertEqual(pto.DeinterleaveDist.BDINTLV.value, "BDINTLV") + self.assertEqual(pto.InterleaveDist.INTLV.value, "INTLV") + self.assertEqual(pto.PositionMode.LOWEST.value, "LOWEST") + self.assertEqual(pto.PositionMode.HIGHEST.value, "HIGHEST") + self.assertEqual(pto.OrderMode.ASC.value, "ASC") + self.assertEqual(pto.OrderMode.DESC.value, "DESC") + self.assertEqual(pto.PredicateDist.NORM.value, "NORM") + self.assertEqual(pto.PredicateDist.US.value, "US") + self.assertEqual(pto.PredicateDist.DS.value, "DS") + self.assertEqual(pto.PredicateDist.PK.value, "PK") + self.assertTrue(hasattr(pto, "PredicatePart")) + self.assertEqual(pto.PredicatePart.LOWER.value, "LOWER") + self.assertEqual(pto.PredicatePart.HIGHER.value, "HIGHER") + self.assertTrue(hasattr(pto, "CmpMode")) + self.assertEqual(pto.CmpMode.EQ.value, "eq") + self.assertEqual(pto.CmpMode.NE.value, "ne") + self.assertEqual(pto.CmpMode.LT.value, "lt") + self.assertEqual(pto.CmpMode.LE.value, "le") + self.assertEqual(pto.CmpMode.GT.value, "gt") + self.assertEqual(pto.CmpMode.GE.value, "ge") + self.assertEqual(pto.VcvtRoundMode.R.value, "R") + self.assertEqual(pto.VcvtSatMode.SAT.value, "SAT") + self.assertEqual(pto.VcvtPartMode.EVEN.value, "EVEN") + self.assertEqual(pto.VcvtPartMode.ODD.value, "ODD") + self.assertEqual(pto.VcvtPartMode.P0.value, "P0") + self.assertEqual(pto.VcvtPartMode.P1.value, "P1") + self.assertEqual(pto.VcvtPartMode.P2.value, "P2") + self.assertEqual(pto.VcvtPartMode.P3.value, "P3") + self.assertEqual(pto.PostUpdateMode.POST_UPDATE.value, "POST_UPDATE") + self.assertEqual(pto.PostUpdateMode.NO_POST_UPDATE.value, "NO_POST_UPDATE") + self.assertEqual(pto.FractalMode.ND2NZ.value, "nd2nz") + self.assertEqual(pto.FractalMode.DN2NZ.value, "dn2nz") + self.assertEqual(pto.FractalMode.NZ2ND.value, "nz2nd") + self.assertEqual(pto.FractalMode.NZ2DN.value, "nz2dn") + self.assertEqual(pto.FractalMode.NZ2NZ.value, "nz2nz") + self.assertEqual(pto.Event.ID31.value, "EVENT_ID31") + self.assertEqual(pto.MemorySpace.GM.value, "gm") + self.assertEqual(pto.MemorySpace.MAT.value, "mat") + self.assertEqual(pto.MemorySpace.LEFT.value, "left") + self.assertEqual(pto.MemorySpace.RIGHT.value, "right") + self.assertEqual(pto.MemorySpace.ACC.value, "acc") + self.assertEqual(pto.MemorySpace.BIAS.value, "bias") + self.assertEqual(pto.MemorySpace.UB.value, "ub") + self.assertIs(pto.DeinterleaveDist.B32, pto.DeinterleaveDist.DINTLV) + self.assertIs(pto.InterleaveDist.B32, pto.InterleaveDist.INTLV) + self.assertEqual(pto.si8.name, "si8") + self.assertEqual(pto.ui16.name, "ui16") + self.assertEqual(pto.si32.name, "si32") + self.assertEqual(pto.ui64.name, "ui64") + self.assertIsNot(pto.si8, pto.i8) + self.assertIsNot(pto.ui32, pto.i32) + self.assertEqual(pto.bytewidth(pto.si16), 2) + self.assertEqual(pto.bytewidth(pto.ui64), 8) + self.assertEqual(pto.get_lanes(pto.ui32), 64) + self.assertEqual(pto.get_lanes(pto.i64), 32) + self.assertEqual(pto.elements_per_vreg(pto.si8), 256) + self.assertEqual(repr(pto.align), "align") + + def test_tile_config_exposes_normalized_query_properties(self) -> None: + default_config = pto.TileConfig() + self.assertEqual(default_config.b_layout, pto.BLayout.ROW_MAJOR) + self.assertEqual(default_config.s_layout, pto.SLayout.NONE_BOX) + self.assertEqual(default_config.s_fractal_size, 512) + self.assertEqual(default_config.pad_value, pto.PadValue.NULL) + + config = pto.TileConfig.from_mapping( + { + "layout": "col_major", + "s_layout": "row_major", + "fractal": 16, + "pad": "max", + } + ) + self.assertEqual(config.b_layout, pto.BLayout.COL_MAJOR) + self.assertEqual(config.s_layout, pto.SLayout.ROW_MAJOR) + self.assertEqual(config.s_fractal_size, 16) + self.assertEqual(config.pad_value, pto.PadValue.MAX) + + def test_pad_value_supports_standard_and_custom_payloads(self) -> None: + custom = pto.PadValue.custom_f32(-1.0) + self.assertTrue(custom.is_custom) + self.assertEqual(custom.float32_bits, 0xBF800000) + self.assertEqual(custom.encoded, pto.PadValue.CustomBase | (0xBF800000 << 32)) + self.assertAlmostEqual(custom.as_float32(), -1.0) + self.assertAlmostEqual(custom.eval(pto.f32), -1.0) + self.assertEqual(pto.PadValue.MAX.eval(pto.ui16), 0xFFFF) + self.assertEqual(pto.PadValue.MIN.eval(pto.ui16), 0) + self.assertEqual(pto.PadValue.MAX.eval(pto.i16), 0x7FFF) + self.assertEqual(pto.PadValue.MIN.eval(pto.i16), -0x8000) + self.assertIsNone(pto.PadValue.NULL.eval(pto.f16)) + with self.assertRaises(AttributeError): + _ = pto.PadValue.ZERO.value + + +class TileLangDSLExpandHelperTests(unittest.TestCase): + def test_cross_file_inline_proc_direct_import_materializes(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + shared_name = "shared_cross_file_positive_unique" + (root / f"{shared_name}.py").write_text( + """ +import tilelang_dsl as pto + +@pto.inline_proc +def shared_touch(): + return +""", + encoding="utf-8", + ) + template_path = root / "cross_file_positive_template_unique.py" + template_path.write_text( + f""" +import tilelang_dsl as pto +from {shared_name} import shared_touch + +@pto.vkernel(op="pto.cross_file_positive_unique", dtypes=[(pto.f32,)]) +def kernel(src: pto.Tile): + shared_touch() + return +""", + encoding="utf-8", + ) + + with expand_helper._template_import_context(root): + mod = expand_helper._import_py_file(template_path) + self.assertIsNotNone(mod) + desc = expand_helper._find_descriptors(mod)[0] + self.assertIn("shared_touch", desc.inline_procs) + + specialized = desc.specialize( + src=pto.TileSpecialization(shape=(1, 64), memory_space=pto.MemorySpace.UB) + ) + frontend = build_frontend_kernel_node(specialized) + self.assertIn("shared_touch", {proc.name for proc in frontend.inline_procs}) + + text = specialized.mlir_text() + self.assertRegex(text, r"func\.call @__tl_inline_shared_touch_") + self.assertRegex(text, r"func\.func private @__tl_inline_shared_touch_") + + def test_cross_file_inline_proc_package_import_materializes_without_leaking_sys_path(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + package_root = Path(tmpdir) + template_dir = package_root / "TileOps" + template_dir.mkdir() + (template_dir / "__init__.py").write_text("", encoding="utf-8") + + shared_name = "shared_cross_file_package_unique" + (template_dir / f"{shared_name}.py").write_text( + """ +import tilelang_dsl as pto + +@pto.inline_proc +def shared_touch(): + return +""", + encoding="utf-8", + ) + template_path = template_dir / "cross_file_package_template_unique.py" + template_path.write_text( + f""" +import tilelang_dsl as pto +from TileOps.{shared_name} import shared_touch + +@pto.vkernel(op="pto.cross_file_package_unique", dtypes=[(pto.f32,)]) +def kernel(src: pto.Tile): + shared_touch() + return +""", + encoding="utf-8", + ) + + before_counts = { + str(template_dir): sys.path.count(str(template_dir)), + str(package_root): sys.path.count(str(package_root)), + } + with expand_helper._template_import_context(template_dir): + self.assertGreaterEqual( + sys.path.count(str(template_dir)), + before_counts[str(template_dir)] + 1, + ) + self.assertGreaterEqual( + sys.path.count(str(package_root)), + before_counts[str(package_root)] + 1, + ) + mod = expand_helper._import_py_file(template_path) + self.assertIsNotNone(mod) + self.assertEqual(sys.path.count(str(template_dir)), before_counts[str(template_dir)]) + self.assertEqual(sys.path.count(str(package_root)), before_counts[str(package_root)]) + + desc = expand_helper._find_descriptors(mod)[0] + self.assertIn("shared_touch", desc.inline_procs) + + specialized = desc.specialize( + src=pto.TileSpecialization(shape=(1, 64), memory_space=pto.MemorySpace.UB) + ) + frontend = build_frontend_kernel_node(specialized) + self.assertIn("shared_touch", {proc.name for proc in frontend.inline_procs}) + + text = specialized.mlir_text() + self.assertRegex(text, r"func\.call @__tl_inline_shared_touch_") + self.assertRegex(text, r"func\.func private @__tl_inline_shared_touch_") + + def test_cross_file_inline_proc_collects_shared_helper_callees(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + shared_name = "shared_cross_file_nested_unique" + (root / f"{shared_name}.py").write_text( + """ +import tilelang_dsl as pto + +@pto.inline_proc +def shared_leaf(): + return + +@pto.inline_proc +def shared_entry(): + shared_leaf() + return +""", + encoding="utf-8", + ) + template_path = root / "cross_file_nested_template_unique.py" + template_path.write_text( + f""" +import tilelang_dsl as pto +from {shared_name} import shared_entry + +@pto.vkernel(op="pto.cross_file_nested_unique", dtypes=[(pto.f32,)]) +def kernel(src: pto.Tile): + shared_entry() + return +""", + encoding="utf-8", + ) + + with expand_helper._template_import_context(root): + mod = expand_helper._import_py_file(template_path) + self.assertIsNotNone(mod) + desc = expand_helper._find_descriptors(mod)[0] + self.assertIn("shared_entry", desc.inline_procs) + self.assertIn("shared_leaf", desc.inline_procs) + + specialized = desc.specialize( + src=pto.TileSpecialization(shape=(1, 64), memory_space=pto.MemorySpace.UB) + ) + frontend = build_frontend_kernel_node(specialized) + self.assertEqual( + {proc.name for proc in frontend.inline_procs}, + {"shared_entry", "shared_leaf"}, + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"func\.call @__tl_inline_shared_entry_") + self.assertRegex(text, r"func\.func private @__tl_inline_shared_entry_") + self.assertRegex(text, r"func\.func private @__tl_inline_shared_leaf_") + + def test_cross_file_imported_plain_function_is_rejected(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + shared_name = "shared_cross_file_plain_unique" + (root / f"{shared_name}.py").write_text( + """ +def plain_helper(): + return +""", + encoding="utf-8", + ) + template_path = root / "cross_file_plain_template_unique.py" + template_path.write_text( + f""" +import tilelang_dsl as pto +from {shared_name} import plain_helper + +@pto.vkernel(op="pto.cross_file_plain_unique", dtypes=[(pto.f32,)]) +def kernel(src: pto.Tile): + plain_helper() + return +""", + encoding="utf-8", + ) + + stderr = io.StringIO() + with redirect_stderr(stderr), expand_helper._template_import_context(root): + mod = expand_helper._import_py_file(template_path) + + self.assertIsNone(mod) + self.assertIn( + "arbitrary external call `plain_helper` is not supported", + stderr.getvalue(), + ) + + def test_cross_file_inline_proc_negative_diagnostics(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + recursive_name = "shared_cross_file_recursive_unique" + (root / f"{recursive_name}.py").write_text( + """ +import tilelang_dsl as pto + +@pto.inline_proc +def shared_recur(): + shared_recur() + return +""", + encoding="utf-8", + ) + recursive_template = root / "cross_file_recursive_template_unique.py" + recursive_template.write_text( + f""" +import tilelang_dsl as pto +from {recursive_name} import shared_recur + +@pto.vkernel(op="pto.cross_file_recursive_unique", dtypes=[(pto.f32,)]) +def kernel(src: pto.Tile): + shared_recur() + return +""", + encoding="utf-8", + ) + with expand_helper._template_import_context(root): + recursive_mod = expand_helper._import_py_file(recursive_template) + self.assertIsNotNone(recursive_mod) + recursive_desc = expand_helper._find_descriptors(recursive_mod)[0] + with self.assertRaises(pto.TileLangFrontendError) as recursive_ctx: + recursive_desc.specialize( + src=pto.TileSpecialization(shape=(1, 64), memory_space=pto.MemorySpace.UB) + ).mlir_text() + self.assertIn("recursive inline_proc call `shared_recur`", str(recursive_ctx.exception)) + + capture_name = "shared_cross_file_capture_unique" + (root / f"{capture_name}.py").write_text( + """ +import tilelang_dsl as pto + +scale = object() + +@pto.inline_proc +def shared_capture(): + value = scale + return +""", + encoding="utf-8", + ) + capture_template = root / "cross_file_capture_template_unique.py" + capture_template.write_text( + f""" +import tilelang_dsl as pto +from {capture_name} import shared_capture + +@pto.vkernel(op="pto.cross_file_capture_unique", dtypes=[(pto.f32,)]) +def kernel(src: pto.Tile): + shared_capture() + return +""", + encoding="utf-8", + ) + with expand_helper._template_import_context(root): + capture_mod = expand_helper._import_py_file(capture_template) + self.assertIsNotNone(capture_mod) + capture_desc = expand_helper._find_descriptors(capture_mod)[0] + with self.assertRaises(pto.TileLangFrontendError) as capture_ctx: + capture_desc.specialize( + src=pto.TileSpecialization(shape=(1, 64), memory_space=pto.MemorySpace.UB) + ).mlir_text() + self.assertIn("implicit capture of 'scale' is not allowed", str(capture_ctx.exception)) + + conflict_name = "shared_cross_file_conflict_unique" + (root / f"{conflict_name}.py").write_text( + """ +import tilelang_dsl as pto + +@pto.inline_proc +def helper(): + return + +@pto.inline_proc +def entry(): + return +""", + encoding="utf-8", + ) + conflict_template = root / "cross_file_conflict_template_unique.py" + conflict_template.write_text( + f""" +import tilelang_dsl as pto +from {conflict_name} import entry as helper + +@pto.vkernel(op="pto.cross_file_conflict_unique", dtypes=[(pto.f32,)]) +def kernel(src: pto.Tile): + helper() + return +""", + encoding="utf-8", + ) + stderr = io.StringIO() + with redirect_stderr(stderr), expand_helper._template_import_context(root): + conflict_mod = expand_helper._import_py_file(conflict_template) + self.assertIsNone(conflict_mod) + self.assertIn("ambiguous inline_proc name `helper`", stderr.getvalue()) + + def test_operand_specs_preserve_tile_valid_shape_and_pad_value(self) -> None: + source = """ +import tilelang_dsl as pto + +@pto.vkernel(op="pto.expand_helper_tile_config_unique", dtypes=[(pto.f32, pto.f32)]) +def kernel(src: pto.Tile, dst: pto.Tile): + rows, cols = src.valid_shape + pad = dst.pad_value + if pto.constexpr(pad != pto.PadValue.NULL): + scalar = pad.eval() + return None +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = Path(tmpdir) / "expand_helper_tile_config_unique.py" + module_path.write_text(source, encoding="utf-8") + + mod = expand_helper._import_py_file(module_path) + self.assertIsNotNone(mod) + descriptors = expand_helper._find_descriptors(mod) + self.assertTrue(descriptors) + + operand_specs = expand_helper._parse_operand_specs( + """ +[ + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [8, 48], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + }, + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [8, 48], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x1" + } + } +] +""" + ) + desc = expand_helper._select_descriptor( + descriptors, + target="a5", + op_name="pto.expand_helper_tile_config_unique", + operand_specs=operand_specs, + ) + self.assertIsNotNone(desc) + + tile_specs = {} + for param, operand_spec in zip(desc.parameters, operand_specs): + self.assertEqual(param.kind, "tile") + tile_specs[param.name] = pto.TileSpecialization( + shape=operand_spec["shape"], + memory_space=operand_spec["memory_space"], + config=operand_spec["config"], + valid_shape=operand_spec["valid_shape"], + ) + + mlir_text = desc.specialize(**tile_specs).mlir_text() + + self.assertIn("valid_shape=(8, 48)", mlir_text) + self.assertIn( + "!pto.tile_buf", + mlir_text, + ) + self.assertIn( + "!pto.tile_buf", + mlir_text, + ) + + def test_select_descriptor_uses_positional_context_for_named_constraints(self) -> None: + source = """ +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.expand_helper_positional_constraints_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src: src.rank == 5, + lambda src: src.strides[4] == 1, + lambda dst: dst.config.b_layout == pto.BLayout.ROW_MAJOR, + ], +) +def template_nd(src: pto.TensorView, dst: pto.Tile): + return None + +@pto.vkernel( + target="a5", + op="pto.expand_helper_positional_constraints_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda inp: inp.rank == 5, + lambda out: out.config.b_layout == pto.BLayout.COL_MAJOR, + ], + priority=9, +) +def template_dn(inp: pto.TensorView, out: pto.Tile): + return None +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = Path(tmpdir) / "expand_helper_positional_constraints_unique.py" + module_path.write_text(source, encoding="utf-8") + + mod = expand_helper._import_py_file(module_path) + self.assertIsNotNone(mod) + descriptors = expand_helper._find_descriptors(mod) + self.assertTrue(descriptors) + + operand_specs = expand_helper._parse_operand_specs( + """ +[ + { + "kind": "view", + "dtype": "f32", + "shape": [1, 1, 1, 16, 64], + "strides": [1024, 1024, 1024, 64, 1], + "memory_space": "gm" + }, + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [16, 64], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + } +] +""" + ) + + selected = expand_helper._select_descriptor( + descriptors, + target="a5", + op_name="pto.expand_helper_positional_constraints_unique", + operand_specs=operand_specs, + ) + + self.assertEqual(selected.name, "template_nd") + + def test_select_descriptor_accepts_aux_vector_operand_for_vector_annotation(self) -> None: + source = """ +import tilelang_dsl as pto + +@pto.vkernel( + target="a5", + op="pto.expand_helper_vector_operand_unique", + dtypes=[(pto.f32, pto.i16, pto.f32)], +) +def template(src: pto.Tile, ex_vec: pto.vector(pto.i16, (4,)), dst: pto.Tile): + return None +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = Path(tmpdir) / "expand_helper_vector_operand_unique.py" + module_path.write_text(source, encoding="utf-8") + + mod = expand_helper._import_py_file(module_path) + self.assertIsNotNone(mod) + descriptors = expand_helper._find_descriptors(mod) + self.assertTrue(descriptors) + + operand_specs = expand_helper._parse_operand_specs( + """ +[ + { + "kind": "tile", + "dtype": "f32", + "shape": [1, 256], + "valid_shape": [1, 256], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + }, + { + "kind": "vector", + "dtype": "i16", + "shape": [4] + }, + { + "kind": "tile", + "dtype": "f32", + "shape": [1, 256], + "valid_shape": [1, 256], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + } +] +""" + ) + + selected = expand_helper._select_descriptor( + descriptors, + target="a5", + op_name="pto.expand_helper_vector_operand_unique", + operand_specs=operand_specs, + ) + tile_specs = {} + for param, operand_spec in zip(selected.parameters, operand_specs): + if param.kind != "tile": + continue + tile_specs[param.name] = pto.TileSpecialization( + shape=operand_spec["shape"], + memory_space=operand_spec["memory_space"], + config=operand_spec["config"], + valid_shape=operand_spec["valid_shape"], + ) + mlir_text = selected.specialize(**tile_specs).mlir_text() + + self.assertEqual(selected.name, "template") + self.assertEqual(selected.parameters[1].kind, "vector") + self.assertEqual(selected.parameters[1].annotation, pto.vector(pto.i16, (4,))) + self.assertIn("vector<4xi16>", mlir_text) + context_attrs = expand_helper._build_positional_context_attrs(operand_specs) + self.assertEqual(context_attrs["arg1_kind"], "vector") + self.assertEqual(context_attrs["arg1_shape"], (4,)) + self.assertEqual(context_attrs["arg1_rank"], 1) + + +class TileLangDSLSupportMatrixTests(unittest.TestCase): + def test_stable_starter_surface_groups_map_to_stable_tier(self) -> None: + self.assertEqual(get_surface_group_tier("TensorView"), BASIC_TIER) + self.assertEqual(get_surface_group_tier("Tile"), BASIC_TIER) + self.assertEqual(get_surface_group_tier("base_vector_ops"), BASIC_TIER) + self.assertEqual(get_surface_group_tier("tile_indexing_sugar"), BASIC_TIER) + + self.assertIn("TensorView", AUTHORING_TIER_SURFACE_GROUPS["TensorView"]) + self.assertIn("Tile", AUTHORING_TIER_SURFACE_GROUPS["Tile"]) + self.assertNotIn("dma_load/store", AUTHORING_TIER_SURFACE_GROUPS) + self.assertIn("pto.vlds", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vsts", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vadd", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vmuls", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("pto.vmod", AUTHORING_TIER_SURFACE_GROUPS["base_vector_ops"]) + self.assertIn("tile[start:]", BASIC_TILE_INDEXING_SURFACES) + self.assertIn("tile[row, col:]", BASIC_TILE_INDEXING_SURFACES) + + self.assertEqual(get_feature_tier("TensorView"), BASIC_TIER) + self.assertEqual(get_feature_tier("Tile"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vlds"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vsts"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vadd"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vmuls"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vmod"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_buf"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.rls_buf"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_block_idx"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_subblock_num"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mem_bar"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.set_cross_core"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.set_intra_block"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.set_intra_core"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.wait_flag_dev"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.wait_intra_core"), BASIC_TIER) + self.assertEqual(get_feature_tier("BarrierType"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vaddrelu"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vaxpy"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vmull"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vands"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vbr"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vdup"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vci"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vpack"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vsort32"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vldsx2"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vstsx2"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vscatter"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.vbitsort"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.vmrgsort4"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.get_vms4_sr"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("PadMode"), BASIC_TIER) + self.assertEqual(get_feature_tier("VRegType"), BASIC_TIER) + self.assertEqual(get_feature_tier("MaskType"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.vreg"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mask_b8"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mask_b16"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.mask_b32"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.bytewidth"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.get_lanes"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.elements_per_vreg"), BASIC_TIER) + self.assertEqual(get_feature_tier("pto.constexpr"), BASIC_TIER) + self.assertEqual(get_feature_tier("constexpr"), BASIC_TIER) + self.assertEqual(get_feature_tier("tile[start:]"), BASIC_TIER) + self.assertEqual(get_feature_tier("tile[row, col:]"), BASIC_TIER) + + def test_non_stable_surface_groups_keep_advanced_boundaries(self) -> None: + self.assertEqual(get_surface_group_tier("strict_vecscope"), ADVANCED_TIER) + self.assertEqual(get_surface_group_tier("raw_pointer_family"), ADVANCED_TIER) + self.assertEqual(get_surface_group_tier("low_level_dma_family"), ADVANCED_TIER) + self.assertEqual(get_surface_group_tier("tile_helper_family"), ADVANCED_TIER) + + self.assertIn("pto.strict_vecscope", ADVANCED_EXPLICIT_VECSCOPE_SURFACES) + self.assertIn("pto.ptr", ADVANCED_RAW_POINTER_SURFACES) + self.assertIn("pto.castptr", ADVANCED_RAW_POINTER_SURFACES) + self.assertIn("pto.set_mov_pad_val", ADVANCED_LOW_LEVEL_DMA_SURFACES) + self.assertIn("pto.copy_ubuf_to_ubuf", ADVANCED_LOW_LEVEL_DMA_SURFACES) + self.assertIn("pto.tile_with_strides", ADVANCED_TILE_HELPER_SURFACES) + + self.assertEqual(get_feature_tier("strict_vecscope"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.strict_vecscope"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.ptr"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.castptr"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.load_scalar"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.store_scalar"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.set_mov_pad_val"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.copy_ubuf_to_ubuf"), ADVANCED_TIER) + self.assertEqual(get_feature_tier("pto.tile_with_strides"), ADVANCED_TIER) + + def test_unsupported_features_do_not_report_legacy_tiers(self) -> None: + with self.assertRaises(KeyError): + get_surface_group_tier("dma_load/store") + with self.assertRaises(KeyError): + get_feature_tier("pto.dma_load") + with self.assertRaises(KeyError): + get_feature_tier("pto.dma_store") + with self.assertRaises(KeyError): + get_feature_tier("pto.dma_copy") + with self.assertRaises(KeyError): + get_feature_tier("pto.vreduce") + +class TileLangDSLMatcherEntryTests(unittest.TestCase): + def test_select_kernel_returns_descriptor_from_default_registry(self) -> None: + @pto.vkernel(op="matcher_entry_default_registry_unique", dtypes=[(pto.f32, pto.i32)]) + def kernel(inp: pto.TensorView, scale: pto.i32): + return None + + selected = pto.select_kernel( + "a5", + "matcher_entry_default_registry_unique", + (pto.f32, pto.i32), + ) + + self.assertIs(selected, kernel) + + def test_select_kernel_uses_explicit_registry_without_falling_back(self) -> None: + @pto.vkernel(op="matcher_entry_registry_isolation_unique", dtypes=[(pto.f32,)]) + def default_kernel(inp: pto.TensorView): + return None + + empty_registry = pto.KernelRegistry() + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_entry_registry_isolation_unique", + (pto.f32,), + registry=empty_registry, + ) + self.assertIn("found no registered kernel", str(ctx.exception)) + + isolated_registry = pto.KernelRegistry() + isolated_registry.register(default_kernel) + selected = pto.select_kernel( + "a5", + "matcher_entry_registry_isolation_unique", + (pto.f32,), + registry=isolated_registry, + ) + + self.assertIs(selected, default_kernel) + self.assertEqual(len(isolated_registry.descriptors), 1) + + def test_select_kernel_binds_concrete_signature_from_multi_signature_descriptor(self) -> None: + @pto.vkernel( + op="matcher_multi_signature_unique", + dtypes=[ + (pto.f16, pto.f16), + (pto.f32, pto.f32), + ], + ) + def kernel(inp: pto.TensorView, tile: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_multi_signature_unique", + (pto.f32, pto.f32), + ) + + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tensorview", pto.f32), ("tile", "tile", pto.f32)], + ) + specialized = selected.specialize( + tile=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ) + self.assertIn( + "!pto.tile_buf None: + @pto.vkernel(op="matcher_default_dtypes_unique") + def kernel(inp: pto.Tile, out: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_default_dtypes_unique", + (pto.f16, pto.f16), + ) + + self.assertEqual(selected.dtype_signature, (pto.f16, pto.f16)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tile", pto.f16), ("out", "tile", pto.f16)], + ) + specialized = selected.specialize( + inp=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + out=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + self.assertIn( + "!pto.tile_buf None: + @pto.vkernel(op="matcher_default_dtypes_scalar_guard_unique") + def kernel(inp: pto.TensorView, scale: pto.i32): + return None + + selected = pto.select_kernel( + "a5", + "matcher_default_dtypes_scalar_guard_unique", + (pto.f32, pto.i32), + ) + self.assertEqual(selected.dtype_signature, (pto.f32, pto.i32)) + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_default_dtypes_scalar_guard_unique", + (pto.f32, pto.f16), + ) + self.assertIn("found no registered kernel", str(ctx.exception)) + + def test_select_kernel_matches_wildcards_deterministically(self) -> None: + @pto.vkernel( + op="matcher_wildcard_unique", + dtypes=[ + (pto.AnyInt, pto.AnyType), + (pto.AnyFloat, pto.AnyType), + ], + ) + def kernel(lhs: pto.TensorView, rhs: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_wildcard_unique", + (pto.f32, pto.i32), + ) + + self.assertEqual(selected.dtype_signature, (pto.f32, pto.i32)) + self.assertEqual(selected.parameters[0].dtype, pto.f32) + self.assertEqual(selected.parameters[1].dtype, pto.i32) + + selected_int = pto.select_kernel( + "a5", + "matcher_wildcard_unique", + (pto.ui16, pto.si16), + ) + self.assertEqual(selected_int.dtype_signature, (pto.ui16, pto.si16)) + self.assertEqual(selected_int.parameters[0].dtype, pto.ui16) + self.assertEqual(selected_int.parameters[1].dtype, pto.si16) + + def test_select_kernel_enforces_typevar_consistency_per_signature(self) -> None: + @pto.vkernel( + op="matcher_typevar_unique", + dtypes=[(pto.TypeVar("T"), pto.TypeVar("T"))], + ) + def kernel(lhs: pto.TensorView, rhs: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_typevar_unique", + (pto.f32, pto.f32), + ) + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_typevar_unique", + (pto.f32, pto.i32), + ) + self.assertIn("found no registered kernel", str(ctx.exception)) + + def test_scalar_typevar_annotation_tracks_selected_dtype(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel( + op="scalar_typevar_binding_unique", + dtypes=[(elem, elem, elem)], + ) + def kernel(inp: pto.Tile, scale: elem, out: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "scalar_typevar_binding_unique", + (pto.bf16, pto.bf16, pto.bf16), + ) + + self.assertEqual(selected.dtype_signature, (pto.bf16, pto.bf16, pto.bf16)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tile", pto.bf16), ("scale", "scalar", pto.bf16), ("out", "tile", pto.bf16)], + ) + + def test_scalar_wildcard_annotation_accepts_selected_dtype(self) -> None: + @pto.vkernel( + op="scalar_wildcard_binding_unique", + dtypes=[(pto.AnyType, pto.AnyType, pto.AnyType)], + ) + def kernel(inp: pto.Tile, scale: pto.AnyType, out: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "scalar_wildcard_binding_unique", + (pto.i16, pto.i16, pto.i16), + ) + + self.assertEqual(selected.dtype_signature, (pto.i16, pto.i16, pto.i16)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [("inp", "tile", pto.i16), ("scale", "scalar", pto.i16), ("out", "tile", pto.i16)], + ) + + def test_polymorphic_descriptor_requires_select_kernel_before_materialization(self) -> None: + @pto.vkernel( + op="matcher_materialization_gate_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(ValueError) as ctx: + kernel.mlir_text() + self.assertIn("requires pto.select_kernel(...)", str(ctx.exception)) + + def test_select_kernel_evaluates_constraints_before_priority(self) -> None: + def requires_large_batch(batch=0): + return batch >= 1024 + + @pto.vkernel( + op="matcher_constraint_priority_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[requires_large_batch], + priority=100, + ) + def high_priority_kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_constraint_priority_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[], + priority=10, + ) + def fallback_kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_constraint_priority_unique", + (pto.f32, pto.f32), + context_attrs={"batch": 128}, + ) + self.assertIs(selected.py_fn, fallback_kernel.py_fn) + self.assertEqual(selected.priority, 10) + + selected = pto.select_kernel( + "a5", + "matcher_constraint_priority_unique", + (pto.f32, pto.f32), + context_attrs={"batch": 4096}, + ) + self.assertIs(selected.py_fn, high_priority_kernel.py_fn) + self.assertEqual(selected.priority, 100) + + def test_select_kernel_raises_tie_error_for_equal_highest_priority(self) -> None: + @pto.vkernel( + op="matcher_priority_tie_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=50, + ) + def lhs(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_priority_tie_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=50, + ) + def rhs(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_priority_tie_unique", + (pto.f32, pto.f32), + ) + self.assertIn("multiple highest-priority kernels", str(ctx.exception)) + self.assertIn("lhs(priority=50", str(ctx.exception)) + self.assertIn("rhs(priority=50", str(ctx.exception)) + + def test_select_kernel_reports_no_candidate_after_constraint_evaluation(self) -> None: + @pto.vkernel( + op="matcher_constraint_empty_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[lambda enabled=False: enabled], + priority=1, + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_constraint_empty_unique", + (pto.f32, pto.f32), + context_attrs={"enabled": False}, + ) + self.assertIn("after constraint evaluation", str(ctx.exception)) + + def test_select_kernel_report_mode_keeps_default_descriptor_path_compatible(self) -> None: + @pto.vkernel(op="matcher_report_default_compat_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_report_default_compat_unique", + (pto.f32, pto.f32), + return_metadata=False, + include_mlir=False, + ) + + self.assertIsInstance(selected, pto.VKernelDescriptor) + self.assertIs(selected.py_fn, kernel.py_fn) + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + + def test_select_kernel_report_mode_records_dtype_mismatch_candidates(self) -> None: + @pto.vkernel( + op="matcher_report_dtype_mismatch_unique", + dtypes=[(pto.f32, pto.f32)], + priority=5, + ) + def mismatch(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_report_dtype_mismatch_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=10, + ) + def fallback(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_dtype_mismatch_unique", + (pto.bf16, pto.bf16), + return_metadata=True, + include_mlir=False, + ) + + self.assertIsInstance(report, pto.KernelSelectionReport) + self.assertEqual(report.final_status, "selected") + self.assertIsNotNone(report.selected) + assert report.selected is not None + self.assertEqual(report.selected.py_fn, fallback.py_fn) + self.assertEqual( + [(candidate.name, candidate.status) for candidate in report.candidates], + [("mismatch", "dtype_mismatch"), ("fallback", "selected")], + ) + + def test_select_kernel_report_mode_records_constraint_failure_candidates(self) -> None: + constrained_check = lambda enabled=False: enabled + + @pto.vkernel( + op="matcher_report_constraint_failure_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + constraints=[constrained_check], + priority=20, + ) + def constrained(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_report_constraint_failure_unique", + dtypes=[(pto.AnyFloat, pto.AnyFloat)], + priority=5, + ) + def fallback(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_constraint_failure_unique", + (pto.f32, pto.f32), + context_attrs={"enabled": False}, + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "selected") + self.assertIsNotNone(report.selected) + assert report.selected is not None + self.assertEqual(report.selected.py_fn, fallback.py_fn) + expected_location = ( + f"{constrained_check.__code__.co_filename}:{constrained_check.__code__.co_firstlineno}" + ) + self.assertEqual( + [ + ( + candidate.name, + candidate.status, + candidate.failed_constraint_index, + candidate.failed_constraint_location, + ) + for candidate in report.candidates + ], + [ + ("constrained", "constraint_failed", 0, expected_location), + ("fallback", "selected", None, None), + ], + ) + self.assertIn(expected_location, report.candidates[0].reason) + + def test_select_kernel_report_mode_records_constraint_exceptions(self) -> None: + bad_constraint = lambda missing: missing + + @pto.vkernel( + op="matcher_report_constraint_exception_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[bad_constraint], + ) + def bad(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_constraint_exception_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "no_candidate") + self.assertIsNone(report.selected) + self.assertIn("requires unsupported parameter", report.final_error) + self.assertEqual(len(report.candidates), 1) + expected_location = ( + f"{bad_constraint.__code__.co_filename}:{bad_constraint.__code__.co_firstlineno}" + ) + candidate = report.candidates[0] + self.assertEqual(candidate.name, "bad") + self.assertEqual(candidate.status, "constraint_error") + self.assertEqual(candidate.failed_constraint_index, 0) + self.assertEqual(candidate.failed_constraint_location, expected_location) + self.assertEqual(candidate.error_type, "TypeError") + self.assertIn("requires unsupported parameter", candidate.error_message) + self.assertIn(expected_location, candidate.error_message) + + def test_select_kernel_report_mode_reports_priority_ties(self) -> None: + @pto.vkernel( + op="matcher_report_priority_tie_unique", + dtypes=[(pto.f32, pto.f32)], + priority=33, + ) + def lhs(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + op="matcher_report_priority_tie_unique", + dtypes=[(pto.f32, pto.f32)], + priority=33, + ) + def rhs(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_priority_tie_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "priority_tie") + self.assertIsNone(report.selected) + self.assertIn("multiple highest-priority kernels", report.final_error) + self.assertEqual( + [(candidate.name, candidate.status) for candidate in report.candidates], + [("lhs", "priority_tie"), ("rhs", "priority_tie")], + ) + + def test_select_kernel_report_mode_reports_no_candidate_without_candidates(self) -> None: + empty_registry = pto.KernelRegistry() + + report = pto.select_kernel( + "a5", + "matcher_report_empty_registry_unique", + (pto.f32,), + registry=empty_registry, + return_metadata=True, + include_mlir=False, + ) + + self.assertEqual(report.final_status, "no_candidate") + self.assertIsNone(report.selected) + self.assertEqual(report.candidates, ()) + self.assertIn("found no registered kernel", report.final_error) + + def test_select_kernel_report_mode_includes_mlir_text_for_materializable_candidate(self) -> None: + @pto.vkernel( + op="matcher_report_mlir_text_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_mlir_text_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=True, + ) + + self.assertEqual(report.final_status, "selected") + self.assertEqual(len(report.candidates), 1) + candidate = report.candidates[0] + self.assertEqual(candidate.status, "selected") + self.assertIsNotNone(candidate.mlir_text) + self.assertIsNone(candidate.mlir_error) + self.assertIn("module attributes", candidate.mlir_text) + self.assertIn("@kernel", candidate.mlir_text) + self.assertIn("!pto.tensor_view", candidate.mlir_text) + + def test_select_kernel_report_mode_includes_mlir_error_for_unspecialized_tile_candidate(self) -> None: + @pto.vkernel( + op="matcher_report_mlir_error_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.Tile): + return None + + report = pto.select_kernel( + "a5", + "matcher_report_mlir_error_unique", + (pto.f32, pto.f32), + return_metadata=True, + include_mlir=True, + ) + + self.assertEqual(report.final_status, "selected") + self.assertEqual(len(report.candidates), 1) + candidate = report.candidates[0] + self.assertEqual(candidate.status, "selected") + self.assertIsNone(candidate.mlir_text) + self.assertIsNotNone(candidate.mlir_error) + self.assertIn("requires specialize() bindings for bare Tile parameters", candidate.mlir_error) + + def test_materialization_constraints_can_see_specializations_and_selected_context_attrs(self) -> None: + @pto.vkernel( + op="matcher_materialization_constraint_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src: src.rank == 5, + lambda dst, expected_rows=None: dst.shape[0] == expected_rows, + lambda src, dst: dst.valid_shape[1] <= src.shape[4], + ], + ) + def kernel(src: pto.TensorView, dst: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_materialization_constraint_unique", + (pto.f32, pto.f32), + context_attrs={"expected_rows": 8, "src_shape": (2, 2, 1, 1, 16), "src_strides": (32, 16, 16, 16, 1)}, + ).specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB, valid_shape=(4, 16)), + ) + text = selected.mlir_text() + self.assertIn("!pto.tensor_view", text) + self.assertIn("!pto.tile_buf None: + @pto.vkernel( + op="matcher_parameter_style_constraints_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src, dst: src.rank == 5, + lambda src: src.strides[4] == 1, + lambda src, dst: src.shape[0] <= dst.shape[0], + ], + ) + def kernel(src: pto.TensorView, dst: pto.Tile): + return None + + selected = pto.select_kernel( + "a5", + "matcher_parameter_style_constraints_unique", + (pto.f32, pto.f32), + context_attrs={"src_shape": (4, 1, 1, 1, 16), "src_strides": (16, 16, 16, 16, 1)}, + ).specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + self.assertIn("!pto.tile_buf None: + @pto.vkernel( + op="matcher_positional_context_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda src: src.rank == 5, + lambda src: src.strides[4] == 1, + lambda dst: dst.config.b_layout == pto.BLayout.ROW_MAJOR, + ], + ) + def template_nd(src: pto.TensorView, dst: pto.Tile): + return None + + @pto.vkernel( + op="matcher_positional_context_unique", + dtypes=[(pto.f32, pto.f32)], + constraints=[ + lambda inp: inp.rank == 5, + lambda out: out.config.b_layout == pto.BLayout.COL_MAJOR, + ], + priority=9, + ) + def template_dn(inp: pto.TensorView, out: pto.Tile): + return None + + operand_specs = expand_helper._parse_operand_specs( + """ +[ + { + "kind": "view", + "dtype": "f32", + "shape": [1, 1, 1, 16, 64], + "strides": [1024, 1024, 1024, 64, 1], + "memory_space": "gm" + }, + { + "kind": "tile", + "dtype": "f32", + "shape": [16, 64], + "valid_shape": [16, 64], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0" + } + } +] +""" + ) + + registry = pto.KernelRegistry((template_nd, template_dn)) + selected = pto.select_kernel( + "a5", + "matcher_positional_context_unique", + (pto.f32, pto.f32), + context_attrs=expand_helper._build_positional_context_attrs(operand_specs), + registry=registry, + ) + + self.assertEqual(selected.name, "template_nd") + + def test_select_kernel_binds_selected_op_for_multi_op_descriptor(self) -> None: + @pto.vkernel( + ops=["matcher_multi_op_bind_add_unique", "matcher_multi_op_bind_sub_unique"], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_multi_op_bind_sub_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(selected.py_fn, kernel.py_fn) + self.assertEqual(selected.match_ops, ("matcher_multi_op_bind_add_unique", "matcher_multi_op_bind_sub_unique")) + self.assertEqual(selected.selected_op, "matcher_multi_op_bind_sub_unique") + self.assertEqual(selected.op, "matcher_multi_op_bind_sub_unique") + self.assertEqual(selected.dtype_signature, (pto.f32, pto.f32)) + + def test_select_kernel_hits_same_multi_op_descriptor_for_multiple_query_ops(self) -> None: + @pto.vkernel( + ops=[ + "matcher_multi_hit_add_unique", + "matcher_multi_hit_mul_unique", + "matcher_multi_hit_div_unique", + ], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + add_selected = pto.select_kernel( + "a5", + "matcher_multi_hit_add_unique", + (pto.f32, pto.f32), + ) + mul_selected = pto.select_kernel( + "a5", + "matcher_multi_hit_mul_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(add_selected.py_fn, kernel.py_fn) + self.assertIs(mul_selected.py_fn, kernel.py_fn) + self.assertEqual(add_selected.match_ops, kernel.match_ops) + self.assertEqual(mul_selected.match_ops, kernel.match_ops) + self.assertEqual(add_selected.selected_op, "matcher_multi_hit_add_unique") + self.assertEqual(mul_selected.selected_op, "matcher_multi_hit_mul_unique") + self.assertEqual(add_selected.op, "matcher_multi_hit_add_unique") + self.assertEqual(mul_selected.op, "matcher_multi_hit_mul_unique") + + def test_select_kernel_prefers_higher_priority_single_op_over_multi_op(self) -> None: + @pto.vkernel( + op="matcher_single_beats_multi_priority_unique", + dtypes=[(pto.f32, pto.f32)], + priority=12, + ) + def single(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + ops=[ + "matcher_single_beats_multi_priority_unique", + "matcher_single_beats_multi_priority_alt_unique", + ], + dtypes=[(pto.f32, pto.f32)], + priority=4, + ) + def multi(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_single_beats_multi_priority_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(selected.py_fn, single.py_fn) + self.assertEqual(selected.selected_op, "matcher_single_beats_multi_priority_unique") + self.assertEqual(selected.priority, 12) + + def test_select_kernel_prefers_priority_over_single_op_specificity(self) -> None: + @pto.vkernel( + op="matcher_single_vs_multi_priority_unique", + dtypes=[(pto.f32, pto.f32)], + priority=5, + ) + def single(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + ops=["matcher_single_vs_multi_priority_unique", "matcher_single_vs_multi_priority_alt_unique"], + dtypes=[(pto.f32, pto.f32)], + priority=9, + ) + def multi(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "matcher_single_vs_multi_priority_unique", + (pto.f32, pto.f32), + ) + + self.assertIs(selected.py_fn, multi.py_fn) + self.assertEqual(selected.selected_op, "matcher_single_vs_multi_priority_unique") + self.assertEqual(selected.priority, 9) + + def test_select_kernel_raises_tie_error_when_single_and_multi_op_candidates_tie(self) -> None: + @pto.vkernel( + op="matcher_single_multi_tie_unique", + dtypes=[(pto.f32, pto.f32)], + priority=17, + ) + def single(inp: pto.TensorView, out: pto.TensorView): + return None + + @pto.vkernel( + ops=["matcher_single_multi_tie_unique", "matcher_single_multi_tie_alt_unique"], + dtypes=[(pto.f32, pto.f32)], + priority=17, + ) + def multi(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(LookupError) as ctx: + pto.select_kernel( + "a5", + "matcher_single_multi_tie_unique", + (pto.f32, pto.f32), + ) + + self.assertIn("multiple highest-priority kernels", str(ctx.exception)) + self.assertIn("single(priority=17", str(ctx.exception)) + self.assertIn("multi(priority=17", str(ctx.exception)) + + +class TileLangDSLDescriptorTests(unittest.TestCase): + def test_descriptor_metadata_and_parameter_binding(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], verify=False) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + self.assertEqual(kernel.target, "a5") + self.assertEqual(kernel.op, "eltwise") + self.assertEqual(kernel.name, "kernel") + self.assertFalse(kernel.verify_enabled) + self.assertFalse(kernel.advanced_enabled) + self.assertEqual(kernel.metadata["verify"], False) + self.assertEqual(kernel.metadata["advanced"], False) + self.assertEqual(kernel.dtype_signature, (pto.f32, pto.f16, pto.i32)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in kernel.parameters], + [("inp", "tensorview", pto.f32), ("tile", "tile", pto.f16), ("scale", "scalar", pto.i32)], + ) + self.assertEqual(kernel.parameters[0].element_dtype, pto.f32) + self.assertEqual(kernel.parameters[1].element_dtype, pto.f16) + self.assertIsNone(kernel.parameters[2].element_dtype) + + def test_descriptor_accepts_multi_op_matcher_metadata(self) -> None: + @pto.vkernel(ops=["tadd", "tsub"], dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + self.assertEqual(kernel.match_ops, ("tadd", "tsub")) + self.assertIsNone(kernel.selected_op) + self.assertIsNone(kernel.metadata["op"]) + self.assertEqual(kernel.metadata["match_ops"], ("tadd", "tsub")) + self.assertIsNone(kernel.metadata["selected_op"]) + self.assertEqual(kernel.dtype_signature, (pto.f32, pto.f32)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in kernel.parameters], + [("inp", "tensorview", pto.f32), ("out", "tensorview", pto.f32)], + ) + with self.assertRaises(ValueError) as ctx: + _ = kernel.op + self.assertIn("bind a concrete op", str(ctx.exception)) + + def test_descriptor_defaults_dtypes_for_beginner_tile_kernels(self) -> None: + @pto.vkernel(op="default_dtypes_unique") + def kernel(inp: pto.Tile, out: pto.Tile): + return None + + self.assertEqual(kernel.match_ops, ("default_dtypes_unique",)) + self.assertEqual(kernel.dtypes, ((pto.AnyType, pto.AnyType),)) + self.assertEqual(kernel.metadata["dtypes"], ((pto.AnyType, pto.AnyType),)) + with self.assertRaises(ValueError) as ctx: + _ = kernel.dtype_signature + self.assertIn("choose a concrete dtype signature", str(ctx.exception)) + + def test_descriptor_defaults_scalar_typevar_to_anytype(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel(op="default_scalar_typevar_unique") + def kernel(inp: pto.Tile, scale: elem, out: pto.Tile): + return None + + self.assertEqual(kernel.match_ops, ("default_scalar_typevar_unique",)) + self.assertEqual(kernel.dtypes, ((pto.AnyType, pto.AnyType, pto.AnyType),)) + self.assertEqual(kernel.metadata["dtypes"], ((pto.AnyType, pto.AnyType, pto.AnyType),)) + with self.assertRaises(ValueError) as ctx: + _ = kernel.dtype_signature + self.assertIn("choose a concrete dtype signature", str(ctx.exception)) + + def test_descriptor_defaults_scalar_wildcard_to_anytype(self) -> None: + @pto.vkernel(op="default_scalar_wildcard_unique") + def kernel(inp: pto.Tile, scale: pto.AnyType, out: pto.Tile): + return None + + self.assertEqual(kernel.match_ops, ("default_scalar_wildcard_unique",)) + self.assertEqual(kernel.dtypes, ((pto.AnyType, pto.AnyType, pto.AnyType),)) + self.assertEqual(kernel.metadata["dtypes"], ((pto.AnyType, pto.AnyType, pto.AnyType),)) + with self.assertRaises(ValueError) as ctx: + _ = kernel.dtype_signature + self.assertIn("choose a concrete dtype signature", str(ctx.exception)) + + def test_descriptor_accepts_templates_metadata(self) -> None: + @pto.vkernel( + ops=["tadd", "tsub", "tmul"], + dtypes=[(pto.f32, pto.f32)], + templates={ + "core": { + "tadd": "vadd", + "tsub": "vsub", + }, + "post": { + "tmul": "vrelu", + }, + }, + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + self.assertEqual( + kernel.templates, + { + "core": { + "tadd": "vadd", + "tsub": "vsub", + }, + "post": { + "tmul": "vrelu", + }, + }, + ) + self.assertEqual(kernel.metadata["templates"], kernel.templates) + + def test_descriptor_rejects_op_and_ops_together(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel(op="tadd", ops=["tsub"], dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("either op= or ops=", str(ctx.exception)) + + def test_descriptor_requires_one_of_op_or_ops(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel(dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("exactly one of op= or ops=", str(ctx.exception)) + + def test_descriptor_rejects_template_slot_with_non_string_name(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + ops=["tadd"], + dtypes=[(pto.f32,)], + templates={1: {"tadd": "vadd"}}, + ) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("template slot names must be non-empty strings", str(ctx.exception)) + + def test_descriptor_rejects_template_op_outside_matcher_set(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel( + ops=["tadd", "tsub"], + dtypes=[(pto.f32, pto.f32)], + templates={"core": {"tmul": "vmul"}}, + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + self.assertIn("outside descriptor matcher set", str(ctx.exception)) + + def test_descriptor_rejects_template_mapping_to_unknown_pto_op(self) -> None: + with self.assertRaises(ValueError) as ctx: + @pto.vkernel( + ops=["tadd"], + dtypes=[(pto.f32,)], + templates={"core": {"tadd": "vunknown"}}, + ) + def kernel(inp: pto.TensorView): + return None + + self.assertIn("maps to unsupported pto op", str(ctx.exception)) + + def test_pointer_parameter_annotation_binds_as_ptr_kind(self) -> None: + @pto.vkernel(op="ptr_surface", dtypes=[(pto.f32, pto.i64)], advanced=True) + def kernel(src: pto.ptr(pto.f32, pto.MemorySpace.UB), addr: pto.i64): + return None + + self.assertEqual(kernel.parameters[0].kind, "ptr") + self.assertEqual(kernel.parameters[0].dtype, pto.f32) + self.assertEqual(kernel.parameters[0].annotation, pto.ptr(pto.f32, pto.MemorySpace.UB)) + self.assertEqual(kernel.parameters[0].element_dtype, pto.f32) + + def test_get_vms4_sr_lowers_to_four_i16_results(self) -> None: + @pto.vkernel(op="get_vms4_sr_surface", dtypes=[(pto.i16,)], advanced=True) + def kernel(dst: pto.ptr(pto.i16, pto.MemorySpace.GM)): + list0, list1, list2, list3 = pto.get_vms4_sr() + pto.store_scalar(dst, 0, list2) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + status_assign = next( + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticAssignStmt) + and isinstance(stmt.value, SemanticCallExpr) + and stmt.value.name == "get_vms4_sr" + ) + self.assertEqual(len(status_assign.targets), 4) + self.assertTrue(all(isinstance(target.type, SemanticScalarType) for target in status_assign.targets)) + self.assertTrue(all(target.type.dtype == pto.i16 for target in status_assign.targets)) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"%list0_\d+, %list1_\d+, %list2_\d+, %list3_\d+ = pto\.get_vms4_sr : i16, i16, i16, i16", + ) + self.assertIn("pto.store_scalar", text) + + def test_vreg_type_constructor_exposes_inferred_lane_count(self) -> None: + vec_type = pto.vreg(pto.f32) + self.assertIsInstance(vec_type, pto.VRegType) + self.assertEqual(vec_type.element_dtype, pto.f32) + self.assertEqual(vec_type.lanes, 64) + self.assertEqual(repr(vec_type), "vreg(f32)") + + def test_vector_type_constructor_exposes_shape(self) -> None: + vec_type = pto.vector(pto.i16, (4,)) + self.assertIsInstance(vec_type, pto.VectorType) + self.assertEqual(vec_type.element_dtype, pto.i16) + self.assertEqual(vec_type.shape, (4,)) + self.assertEqual(repr(vec_type), "vector(i16, (4,))") + + def test_vector_parameter_annotation_binds_as_vector_kind(self) -> None: + @pto.vkernel(op="vector_surface_unique", dtypes=[(pto.f32, pto.i16, pto.f32)], advanced=True) + def kernel(src: pto.Tile, ex_vec: pto.vector(pto.i16, (4,)), dst: pto.Tile): + return None + + self.assertEqual(kernel.parameters[1].kind, "vector") + self.assertEqual(kernel.parameters[1].dtype, pto.i16) + self.assertEqual(kernel.parameters[1].annotation, pto.vector(pto.i16, (4,))) + self.assertEqual(kernel.parameters[1].element_dtype, pto.i16) + + def test_mask_type_constants_expose_granularity(self) -> None: + self.assertIsInstance(pto.mask_b8, pto.MaskType) + self.assertIsInstance(pto.mask_b16, pto.MaskType) + self.assertIsInstance(pto.mask_b32, pto.MaskType) + self.assertEqual(pto.mask_b8.granularity, "b8") + self.assertEqual(pto.mask_b16.granularity, "b16") + self.assertEqual(pto.mask_b32.granularity, "b32") + self.assertEqual(repr(pto.mask_b32), "mask_b32") + + def test_mask_parameter_annotation_binds_as_mask_kind(self) -> None: + @pto.vkernel(op="mask_surface", dtypes=[(pto.mask_b32, pto.f32)], advanced=True) + def kernel(mask: pto.mask_b32, dst: pto.Tile): + return None + + self.assertEqual(kernel.parameters[0].kind, "mask") + self.assertEqual(kernel.parameters[0].dtype, pto.mask_b32) + self.assertEqual(kernel.parameters[0].annotation, pto.mask_b32) + self.assertIsNone(kernel.parameters[0].element_dtype) + + def test_specialization_enables_materialization_apis(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "layout": "row_major", + "pad_value": pto.PadValue.ZERO, + } + ), + ) + ) + + self.assertIn("tile", specialized.specializations_by_name) + text = specialized.mlir_text() + self.assertIn("// tilelang.target = a5", text) + self.assertIn("// tilelang.specialize tile shape=(16, 32) memory_space=ub", text) + self.assertIn('module attributes {pto.target_arch = "a5"} {', text) + self.assertIn( + "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {", + text, + ) + module = specialized.mlir_module() + self.assertEqual(type(module).__name__, "MaterializedMLIRModule") + self.assertEqual(module.text, text) + + with tempfile.TemporaryDirectory() as tmpdir: + out = Path(tmpdir) / "kernel.mlir" + specialized.emit(out) + self.assertEqual(out.read_text(encoding="utf-8"), text) + + def test_ckernel_specialize_accepts_cube_bare_tile_profiles(self) -> None: + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16, pto.f16, pto.f32)], name="pure_compute_cube_tiles_unique") + def kernel(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + return None + + specialized = kernel.specialize( + lhs=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT), + rhs=pto.TileSpecialization(shape=(32, 16), memory_space=pto.MemorySpace.RIGHT), + acc=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.ACC), + ) + + self.assertEqual(specialized.specializations_by_name["lhs"].memory_space, pto.MemorySpace.LEFT) + self.assertEqual(specialized.specializations_by_name["rhs"].memory_space, pto.MemorySpace.RIGHT) + self.assertEqual(specialized.specializations_by_name["acc"].memory_space, pto.MemorySpace.ACC) + + text = specialized.mlir_text() + self.assertIn("!pto.tile_buf None: + @pto.ckernel( + op="cube_shared_registry_select_unique", + dtypes=[(pto.f16, pto.f16, pto.f16, pto.f16, pto.f16, pto.f32)], + name="cube_shared_registry_select_unique", + ) + def kernel( + inp: pto.TensorView, + part: pto.PartitionTensorView, + l1: pto.Tile, + left: pto.Tile, + right: pto.Tile, + acc: pto.Tile, + ): + gm_ptr = inp.as_ptr() + _ = part.as_ptr() + _ = pto.addptr(gm_ptr, 64) + pto.mad(left.as_ptr(), right.as_ptr(), acc.as_ptr(), 16, 16, 32) + return None + + selected = pto.select_kernel( + "a5", + "cube_shared_registry_select_unique", + (pto.f16, pto.f16, pto.f16, pto.f16, pto.f16, pto.f32), + ) + + self.assertIs(selected, kernel) + self.assertEqual(selected.dtype_signature, (pto.f16, pto.f16, pto.f16, pto.f16, pto.f16, pto.f32)) + self.assertEqual( + [(param.name, param.kind, param.dtype) for param in selected.parameters], + [ + ("inp", "tensorview", pto.f16), + ("part", "partition_tensor_view", pto.f16), + ("l1", "tile", pto.f16), + ("left", "tile", pto.f16), + ("right", "tile", pto.f16), + ("acc", "tile", pto.f32), + ], + ) + + def test_ckernel_cube_as_ptr_and_addptr_bind_typed_pointers(self) -> None: + @pto.ckernel( + op="pto.mad", + dtypes=[(pto.f16, pto.f16, pto.f16, pto.f16, pto.f16, pto.f32)], + name="cube_as_ptr_addptr_unique", + ) + def kernel(inp: pto.TensorView, part: pto.PartitionTensorView, l1: pto.Tile, left: pto.Tile, right: pto.Tile, acc: pto.Tile): + gm_ptr = inp.as_ptr() + gm_offset = pto.addptr(gm_ptr, 64) + part_ptr = part.as_ptr() + l1_ptr = l1.as_ptr() + left_ptr = left.as_ptr() + right_ptr = right.as_ptr() + acc_ptr = acc.as_ptr() + pto.mad(left_ptr, right_ptr, acc_ptr, 16, 16, 32) + _ = part_ptr + _ = gm_offset + return None + + specialized = kernel.specialize( + l1=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.MAT), + left=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT), + right=pto.TileSpecialization(shape=(32, 16), memory_space=pto.MemorySpace.RIGHT), + acc=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.ACC), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + assign_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticAssignStmt)] + ptr_assigns = { + stmt.targets[0].name: stmt + for stmt in assign_stmts + if isinstance(stmt.value, SemanticCallExpr) + and stmt.targets + and stmt.targets[0].name in {"gm_ptr", "gm_offset", "part_ptr", "l1_ptr", "left_ptr", "right_ptr", "acc_ptr"} + } + + self.assertEqual(ptr_assigns["gm_ptr"].value.name, "tensor_view_as_ptr") + self.assertIsInstance(ptr_assigns["gm_ptr"].targets[0].type, SemanticPtrType) + self.assertEqual(ptr_assigns["gm_ptr"].targets[0].type.element_dtype, pto.f16) + self.assertEqual(ptr_assigns["gm_ptr"].targets[0].type.memory_space, "gm") + + self.assertEqual(ptr_assigns["gm_offset"].value.name, "addptr") + self.assertIsInstance(ptr_assigns["gm_offset"].targets[0].type, SemanticPtrType) + self.assertEqual(ptr_assigns["gm_offset"].targets[0].type.element_dtype, pto.f16) + self.assertEqual(ptr_assigns["gm_offset"].targets[0].type.memory_space, "gm") + + self.assertEqual(ptr_assigns["part_ptr"].value.name, "tensor_view_as_ptr") + self.assertIsInstance(ptr_assigns["part_ptr"].targets[0].type, SemanticPtrType) + self.assertEqual(ptr_assigns["part_ptr"].targets[0].type.element_dtype, pto.f16) + self.assertEqual(ptr_assigns["part_ptr"].targets[0].type.memory_space, "gm") + + self.assertEqual(ptr_assigns["l1_ptr"].value.name, "tile_as_ptr") + self.assertIsInstance(ptr_assigns["l1_ptr"].targets[0].type, SemanticPtrType) + self.assertEqual(ptr_assigns["l1_ptr"].targets[0].type.element_dtype, pto.f16) + self.assertEqual(ptr_assigns["l1_ptr"].targets[0].type.memory_space, "mat") + + self.assertEqual(ptr_assigns["left_ptr"].value.name, "tile_as_ptr") + self.assertEqual(ptr_assigns["left_ptr"].targets[0].type.element_dtype, pto.f16) + self.assertEqual(ptr_assigns["left_ptr"].targets[0].type.memory_space, "left") + + self.assertEqual(ptr_assigns["right_ptr"].value.name, "tile_as_ptr") + self.assertEqual(ptr_assigns["right_ptr"].targets[0].type.element_dtype, pto.f16) + self.assertEqual(ptr_assigns["right_ptr"].targets[0].type.memory_space, "right") + + self.assertEqual(ptr_assigns["acc_ptr"].value.name, "tile_as_ptr") + self.assertEqual(ptr_assigns["acc_ptr"].targets[0].type.element_dtype, pto.f32) + self.assertEqual(ptr_assigns["acc_ptr"].targets[0].type.memory_space, "acc") + + def test_ckernel_materializes_cube_kernel_kind_without_vecscope_carrier(self) -> None: + @pto.ckernel(op="cube_kernel_kind_query_unique", dtypes=[(pto.f16, pto.f16, pto.f32)], name="cube_kernel_kind_unique") + def kernel(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + if pto.constexpr(True): + return None + return None + + specialized = kernel.specialize( + lhs=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT), + rhs=pto.TileSpecialization(shape=(32, 16), memory_space=pto.MemorySpace.RIGHT), + acc=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.ACC), + ) + + text = specialized.mlir_text() + self.assertIn("pto.kernel_kind = #pto.kernel_kind", text) + self.assertIn("pto.tilelang.instance", text) + self.assertIn("func.func @cube_kernel_kind_unique", text) + self.assertNotIn("pto.strict_vecscope", text) + self.assertNotIn("pto.vecscope", text) + + def test_ckernel_cube_pointer_helpers_lower_to_typed_authoring_values(self) -> None: + @pto.ckernel( + op="cube_pointer_helpers_unique", + dtypes=[(pto.f16, pto.f16, pto.f16, pto.f16, pto.f16, pto.f32)], + name="cube_pointer_helpers_unique", + ) + def kernel(inp: pto.TensorView, part: pto.PartitionTensorView, l1: pto.Tile, left: pto.Tile, right: pto.Tile, acc: pto.Tile): + gm_ptr = inp.as_ptr() + gm_offset = pto.addptr(gm_ptr, 64) + part_ptr = part.as_ptr() + l1_ptr = l1.as_ptr() + left_ptr = left.as_ptr() + right_ptr = right.as_ptr() + acc_ptr = acc.as_ptr() + return gm_offset + + specialized = kernel.specialize( + l1=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.MAT), + left=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT), + right=pto.TileSpecialization(shape=(32, 16), memory_space=pto.MemorySpace.RIGHT), + acc=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.ACC), + ) + + text = specialized.mlir_text() + self.assertIn("pto.kernel_kind = #pto.kernel_kind", text) + self.assertRegex( + text, + r"%gm_ptr_\d+ = pto\.tensor_view_addr %arg0 : !pto\.tensor_view<\?x\?x\?x\?x\?xf16> -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%gm_offset_\d+ = pto\.addptr %gm_ptr_\d+, %c64 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%part_ptr_\d+ = pto\.tensor_view_addr %arg1 : !pto\.partition_tensor_view<\?x\?x\?x\?x\?xf16> -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%l1_ptr_\d+ = pto\.tile_buf_addr %arg2 : !pto\.tile_buf -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%left_ptr_\d+ = pto\.tile_buf_addr %arg3 : !pto\.tile_buf -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%right_ptr_\d+ = pto\.tile_buf_addr %arg4 : !pto\.tile_buf -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%acc_ptr_\d+ = pto\.tile_buf_addr %arg5 : !pto\.tile_buf -> !pto\.ptr", + ) + + def test_ckernel_full_pipeline_bridge_ops_lower_one_to_one_in_authoring_form(self) -> None: + @pto.ckernel(op="cube_bridge_pipeline_query_unique", dtypes=[(pto.f16,)], name="cube_bridge_pipeline_unique") + def kernel(inp: pto.TensorView): + gm = inp.as_ptr() + l1 = pto.Tile((16, 32), pto.f16, pto.MemorySpace.MAT) + left = pto.Tile((16, 32), pto.f16, pto.MemorySpace.LEFT) + right = pto.Tile((32, 16), pto.f16, pto.MemorySpace.RIGHT) + acc = pto.Tile((16, 16), pto.f32, pto.MemorySpace.ACC) + bias = pto.Tile((1, 16), pto.f32, pto.MemorySpace.BIAS) + ub = pto.Tile((16, 16), pto.f32, pto.MemorySpace.UB) + + pto.cube_load(gm, l1.as_ptr(), 16, nburst=(1, 0, 0), loops=((2, 32, 64),)) + pto.bias_load(l1.as_ptr(), bias.as_ptr(), 16, nburst=(1, 0, 0)) + pto.left_load(l1.as_ptr(), left.as_ptr(), 16, 32) + pto.right_load(l1.as_ptr(), right.as_ptr(), 32, 16) + pto.mad(left.as_ptr(), right.as_ptr(), acc.as_ptr(), 16, 16, 32, unit_flag_ctrl=2, disable_gemv=pto.i1(True)) + pto.cube_load_frac( + gm, + l1.as_ptr(), + pto.FractalMode.ND2NZ, + shape=(16, 16), + src_layout=(4, 8), + dst_group=(1, 2, 3, 4), + ctrl=(0, False), + ) + pto.acc_store(acc.as_ptr(), l1.as_ptr(), 16, 16, 16, 16, mode=pto.FractalMode.NZ2DN, loop0_src_stride=64, loop3=(3, 4, 5)) + pto.acc_store_gm( + acc.as_ptr(), + gm, + 16, + 16, + 16, + 16, + mode=pto.FractalMode.NZ2NZ, + split=7, + sid=4, + l2_cache_ctrl=5, + ) + pto.acc_store_ub( + acc.as_ptr(), + ub.as_ptr(), + 16, + 16, + 16, + 16, + mode=pto.FractalMode.NZ2ND, + dual_dst_mode=6, + sub_blockid=7, + ) + return None + + text = pto.select_kernel( + "a5", + "cube_bridge_pipeline_query_unique", + (pto.f16,), + registry=pto.KernelRegistry((kernel,)), + ).mlir_text() + self.assertIn("pto.kernel_kind = #pto.kernel_kind", text) + self.assertRegex( + text, + r"pto\.cube_load %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ nburst\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+\) loop\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+\) : !pto\.ptr, !pto\.ptr, i64, i64, i64, i64, loop i64, i64, i64", + ) + self.assertRegex( + text, + r"pto\.bias_load %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ nburst\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+\) : !pto\.ptr, !pto\.ptr, i64, i64, i64, i64", + ) + self.assertRegex( + text, + r"pto\.left_load %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ : !pto\.ptr, !pto\.ptr, i64, i64", + ) + self.assertRegex( + text, + r"pto\.right_load %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ : !pto\.ptr, !pto\.ptr, i64, i64", + ) + self.assertRegex( + text, + r"pto\.mad %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ \{unit_flag_ctrl = 2 : i32\} : !pto\.ptr, !pto\.ptr, !pto\.ptr, i64, i64, i64", + ) + self.assertRegex( + text, + r"pto\.cube_load_frac %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, nd2nz, shape\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+\), src_layout\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+\), dst_group\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+\), ctrl\(%[A-Za-z0-9_]+, %false\) : !pto\.ptr, !pto\.ptr, nd2nz, shape i64, i64, src_layout\(i64, i64\), dst_group i64, i64, i64, i64, ctrl i64, i1", + ) + self.assertRegex( + text, + r"pto\.acc_store %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, nz2dn\(%[A-Za-z0-9_]+\), loop3\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+\) : !pto\.ptr, !pto\.ptr, i64, i64, i64, i64, i64, i64, i64, i64", + ) + self.assertRegex( + text, + r"pto\.acc_store_gm %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, nz2nz\(%[A-Za-z0-9_]+\) : !pto\.ptr, !pto\.ptr, i64, i64, i64, i64, i64, i64, i64", + ) + self.assertRegex( + text, + r"pto\.acc_store_ub %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, nz2nd : !pto\.ptr, !pto\.ptr, i64, i64, i64, i64, i64, i64", + ) + self.assertNotIn("copy_gm_to_cbuf", text) + self.assertNotIn("copy_matrix_cc_to_gm", text) + + def test_ckernel_cube_bridge_variant_ops_lower_one_to_one(self) -> None: + @pto.ckernel(op="cube_bridge_variants_query_unique", dtypes=[(pto.f16,)], name="cube_bridge_variants_unique") + def kernel(inp: pto.TensorView): + gm = inp.as_ptr() + l1 = pto.Tile((16, 32), pto.f16, pto.MemorySpace.MAT) + ub = pto.Tile((16, 32), pto.f16, pto.MemorySpace.UB) + left = pto.Tile((16, 64), pto.f16, pto.MemorySpace.LEFT) + right = pto.Tile((64, 16), pto.f16, pto.MemorySpace.RIGHT) + acc = pto.Tile((16, 16), pto.f32, pto.MemorySpace.ACC) + bias = pto.Tile((1, 16), pto.f32, pto.MemorySpace.BIAS) + + pto.cube_store(l1.as_ptr(), ub.as_ptr(), 16, nburst=(1, 2, 3)) + pto.left_load_mx(l1.as_ptr(), left.as_ptr(), 16, 64) + pto.right_load_mx(l1.as_ptr(), right.as_ptr(), 64, 16) + pto.mad_acc(left.as_ptr(), right.as_ptr(), acc.as_ptr(), 16, 16, 64, unit_flag_ctrl=3, disable_gemv=pto.i1(False)) + pto.mad_bias(left.as_ptr(), right.as_ptr(), acc.as_ptr(), bias.as_ptr(), 16, 16, 64) + return None + + text = pto.select_kernel( + "a5", + "cube_bridge_variants_query_unique", + (pto.f16,), + registry=pto.KernelRegistry((kernel,)), + ).mlir_text() + self.assertRegex( + text, + r"pto\.cube_store %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ nburst\(%[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+\) : !pto\.ptr, !pto\.ptr, i64, i64, i64, i64", + ) + self.assertRegex( + text, + r"pto\.left_load_mx %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ : !pto\.ptr, !pto\.ptr, i64, i64", + ) + self.assertRegex( + text, + r"pto\.right_load_mx %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ : !pto\.ptr, !pto\.ptr, i64, i64", + ) + self.assertRegex( + text, + r"pto\.mad_acc %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ \{unit_flag_ctrl = 3 : i32, disable_gemv = false\} : !pto\.ptr, !pto\.ptr, !pto\.ptr, i64, i64, i64", + ) + self.assertRegex( + text, + r"pto\.mad_bias %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+, %[A-Za-z0-9_]+ \{disable_gemv = false\} : !pto\.ptr, !pto\.ptr, !pto\.ptr, !pto\.ptr, i64, i64, i64", + ) + + def test_ckernel_specialize_rejects_gm_bare_tile_profile(self) -> None: + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16,)], name="cube_tile_reject_gm_unique") + def kernel(tile: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + kernel.specialize(tile=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.GM)) + + self.assertIn("cube v1 only supports MemorySpace.MAT/LEFT/RIGHT/ACC/BIAS/UB", str(ctx.exception)) + + def test_vkernel_specialize_still_rejects_non_ub_bare_tile_profile(self) -> None: + @pto.vkernel(op="vector_tile_reject_cube_space_unique", dtypes=[(pto.f16,)]) + def kernel(tile: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + kernel.specialize(tile=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.LEFT)) + + self.assertIn("vector v1 only supports MemorySpace.UB", str(ctx.exception)) + + def test_multi_op_descriptor_requires_select_kernel_before_materialization_apis(self) -> None: + @pto.vkernel( + ops=["multi_op_gate_add_unique", "multi_op_gate_sub_unique"], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + with self.assertRaises(ValueError) as text_ctx: + kernel.mlir_text() + self.assertIn("mlir_text() requires pto.select_kernel(...) to bind a concrete op", str(text_ctx.exception)) + + with self.assertRaises(ValueError) as module_ctx: + kernel.mlir_module() + self.assertIn( + "mlir_module() requires pto.select_kernel(...) to bind a concrete op", + str(module_ctx.exception), + ) + + with tempfile.TemporaryDirectory() as tmpdir: + out = Path(tmpdir) / "kernel.mlir" + with self.assertRaises(ValueError) as emit_ctx: + kernel.emit(out) + self.assertIn("emit() requires pto.select_kernel(...) to bind a concrete op", str(emit_ctx.exception)) + + def test_ckernel_multi_op_descriptor_requires_select_kernel_before_materialization(self) -> None: + @pto.ckernel( + ops=["cube_multi_op_gate_mad_unique", "cube_multi_op_gate_mad_acc_unique"], + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="cube_multi_op_gate_unique", + ) + def kernel(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + return None + + self.assertIsNone(kernel.selected_op) + + with self.assertRaises(ValueError) as text_ctx: + kernel.mlir_text() + self.assertIn("mlir_text() requires pto.select_kernel(...) to bind a concrete op", str(text_ctx.exception)) + + def test_selected_multi_op_descriptor_can_materialize_normally(self) -> None: + @pto.vkernel( + ops=["multi_op_materialize_add_unique", "multi_op_materialize_sub_unique"], + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(inp: pto.TensorView, out: pto.TensorView): + return None + + selected = pto.select_kernel( + "a5", + "multi_op_materialize_sub_unique", + (pto.f32, pto.f32), + ) + + text = selected.mlir_text() + self.assertIn("// tilelang.target = a5", text) + self.assertIn("// tilelang.op = multi_op_materialize_sub_unique", text) + self.assertIn( + 'func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tensor_view) attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {', + text, + ) + + def test_ckernel_full_pipeline_mlir_text_and_verify_regression(self) -> None: + @pto.ckernel( + op="cube_full_pipeline_verify_regression_unique", + dtypes=[(pto.f16,)], + name="cube_full_pipeline_verify_regression_unique", + ) + def kernel(inp: pto.TensorView): + gm = inp.as_ptr() + l1 = pto.Tile((16, 32), pto.f16, pto.MemorySpace.MAT) + left = pto.Tile((16, 32), pto.f16, pto.MemorySpace.LEFT) + right = pto.Tile((32, 16), pto.f16, pto.MemorySpace.RIGHT) + acc = pto.Tile((16, 16), pto.f32, pto.MemorySpace.ACC) + pto.cube_load(gm, l1.as_ptr(), 16, nburst=(1, 0, 0)) + pto.left_load(l1.as_ptr(), left.as_ptr(), 16, 32) + pto.right_load(l1.as_ptr(), right.as_ptr(), 32, 16) + pto.mad(left.as_ptr(), right.as_ptr(), acc.as_ptr(), 16, 16, 32) + pto.acc_store_gm(acc.as_ptr(), gm, 16, 16, 16, 16) + return None + + selected = pto.select_kernel( + "a5", + "cube_full_pipeline_verify_regression_unique", + (pto.f16,), + registry=pto.KernelRegistry((kernel,)), + ) + + text = selected.mlir_text() + self.assertIn("// tilelang.op = cube_full_pipeline_verify_regression_unique", text) + self.assertIn("pto.kernel_kind = #pto.kernel_kind", text) + self.assertIn("pto.cube_load ", text) + self.assertIn("pto.left_load ", text) + self.assertIn("pto.right_load ", text) + self.assertIn("pto.mad ", text) + self.assertIn("pto.acc_store_gm ", text) + + def test_ckernel_pure_compute_mlir_text_and_verify_regression(self) -> None: + @pto.ckernel( + op="cube_pure_compute_verify_regression_unique", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="cube_pure_compute_verify_regression_unique", + ) + def kernel(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + pto.mad_acc(lhs.as_ptr(), rhs.as_ptr(), acc.as_ptr(), 16, 16, 32) + return None + + specialized = kernel.specialize( + lhs=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT), + rhs=pto.TileSpecialization(shape=(32, 16), memory_space=pto.MemorySpace.RIGHT), + acc=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.ACC), + ) + + text = specialized.mlir_text() + self.assertIn("// tilelang.op = cube_pure_compute_verify_regression_unique", text) + self.assertIn("pto.kernel_kind = #pto.kernel_kind", text) + self.assertIn("!pto.tile_buf None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertEqual(frontend_kernel.name, "kernel") + self.assertEqual( + [(param.name, param.kind) for param in frontend_kernel.parameters], + [("inp", "tensorview"), ("tile", "tile"), ("scale", "scalar")], + ) + self.assertEqual(frontend_kernel.tile_specializations[0].shape, (8, 16)) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + self.assertEqual(semantic_kernel.symbol_name, "kernel") + self.assertEqual(semantic_kernel.tile_bindings[0].memory_space, "ub") + + authoring_module = lower_semantic_kernel(semantic_kernel) + self.assertIsInstance(authoring_module, AuthoringModule) + self.assertEqual(authoring_module.render(), specialized.mlir_text()) + self.assertIn("return", authoring_module.render()) + + def test_descriptor_pipeline_ignores_kernel_docstring_expression(self) -> None: + @pto.vkernel(op="docstring_passthrough_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + """This docstring should be ignored as a no-op expression statement.""" + return None + + frontend_kernel = build_frontend_kernel_node(kernel) + self.assertEqual(len(frontend_kernel.body), 2) + self.assertIsInstance(frontend_kernel.body[0], FrontendExprStmt) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + self.assertEqual(len(semantic_kernel.body), 1) + + text = lower_semantic_kernel(semantic_kernel).render() + self.assertIn("// tilelang.op = docstring_passthrough_unique", text) + self.assertIn("func.func @kernel", text) + self.assertIn("return", text) + + def test_frontend_rejects_hidden_dma_load_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_load_hidden", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + pto.dma_load(inp[0:16, 0:16], tile) + return None + + self.assertIn("unsupported op surface `pto.dma_load`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rejects_hidden_dma_store_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_store_hidden", dtypes=[(pto.f32, pto.f32)]) + def kernel(out: pto.TensorView, tile: pto.Tile): + pto.dma_store(tile, out[0:16, 0:16]) + return None + + self.assertIn("unsupported op surface `pto.dma_store`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rejects_hidden_dma_copy_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_copy_hidden", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + pto.dma_copy(src, dst) + return None + + self.assertIn("unsupported op surface `pto.dma_copy`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rejects_keyword_arguments_on_public_surfaces(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="dma_kw_wrong_surface", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + pto.vlds(tile, offset=0) + return None + + self.assertIn( + "unsupported keyword `offset` for `pto.vlds` in TileLang DSL v1", + str(ctx.exception), + ) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_frontend_rewrites_template_slot_to_selected_real_op(self) -> None: + @pto.vkernel( + ops=["template_slot_add_unique", "template_slot_sub_unique"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "template_slot_add_unique": "vadd", + "template_slot_sub_unique": "vsub", + } + }, + ) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as ( + out_tile, + lhs_tile, + rhs_tile, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs = pto.vlds(lhs_tile, lane) + rhs = pto.vlds(rhs_tile, lane) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, out_tile, lane, mask) + return None + + add_selected = pto.select_kernel( + "a5", + "template_slot_add_unique", + (pto.f32, pto.f32, pto.f32), + ).specialize( + dst=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + ) + sub_selected = pto.select_kernel( + "a5", + "template_slot_sub_unique", + (pto.f32, pto.f32, pto.f32), + ).specialize( + dst=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + ) + + add_frontend = build_frontend_kernel_node(add_selected) + sub_frontend = build_frontend_kernel_node(sub_selected) + + add_vecscope = add_frontend.body[0] + sub_vecscope = sub_frontend.body[0] + self.assertIsInstance(add_vecscope, FrontendStrictVecscopeStmt) + self.assertIsInstance(sub_vecscope, FrontendStrictVecscopeStmt) + + add_loop = add_vecscope.body[0] + sub_loop = sub_vecscope.body[0] + self.assertIsInstance(add_loop, FrontendForStmt) + self.assertIsInstance(sub_loop, FrontendForStmt) + + add_out_assign = add_loop.body[3] + sub_out_assign = sub_loop.body[3] + self.assertIsInstance(add_out_assign, FrontendAssignStmt) + self.assertIsInstance(sub_out_assign, FrontendAssignStmt) + self.assertIsInstance(add_out_assign.value, FrontendCallExpr) + self.assertIsInstance(sub_out_assign.value, FrontendCallExpr) + self.assertEqual(add_out_assign.value.namespace, "pto") + self.assertEqual(sub_out_assign.value.namespace, "pto") + self.assertEqual(add_out_assign.value.name, "vadd") + self.assertEqual(sub_out_assign.value.name, "vsub") + + add_text = add_selected.mlir_text() + sub_text = sub_selected.mlir_text() + self.assertIn("pto.vadd", add_text) + self.assertNotIn("pto.vsub", add_text) + self.assertIn("pto.vsub", sub_text) + self.assertNotIn("pto.vadd", sub_text) + + def test_template_slot_shared_kernel_body_expands_for_four_ops(self) -> None: + @pto.vkernel( + ops=[ + "template_slot_tadd_unique", + "template_slot_tsub_unique", + "template_slot_tmul_unique", + "template_slot_tdiv_unique", + ], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "template_slot_tadd_unique": "vadd", + "template_slot_tsub_unique": "vsub", + "template_slot_tmul_unique": "vmul", + "template_slot_tdiv_unique": "vdiv", + } + }, + ) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as ( + out_tile, + lhs_tile, + rhs_tile, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + lhs = pto.vlds(lhs_tile, lane) + rhs = pto.vlds(rhs_tile, lane) + out = pto.tpl("core", lhs, rhs, mask) + pto.vsts(out, out_tile, lane, mask) + return None + + isolated_registry = pto.KernelRegistry((kernel,)) + expected_ops = { + "template_slot_tadd_unique": "vadd", + "template_slot_tsub_unique": "vsub", + "template_slot_tmul_unique": "vmul", + "template_slot_tdiv_unique": "vdiv", + } + + for query_op, real_op in expected_ops.items(): + selected = pto.select_kernel( + "a5", + query_op, + (pto.f32, pto.f32, pto.f32), + registry=isolated_registry, + ).specialize( + dst=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(selected) + vecscope = frontend_kernel.body[0] + self.assertIsInstance(vecscope, FrontendStrictVecscopeStmt) + loop_stmt = vecscope.body[0] + self.assertIsInstance(loop_stmt, FrontendForStmt) + out_assign = loop_stmt.body[3] + self.assertIsInstance(out_assign, FrontendAssignStmt) + self.assertIsInstance(out_assign.value, FrontendCallExpr) + self.assertEqual(out_assign.value.name, real_op) + + text = selected.mlir_text() + self.assertIn(f"pto.{real_op}", text) + self.assertNotIn("pto.tpl(", text) + + def test_template_slot_rejects_non_literal_slot_name(self) -> None: + slot_name = "core" + + @pto.vkernel( + op="template_slot_non_literal_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={"core": {"template_slot_non_literal_unique": "vadd"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl(slot_name, lhs_tile, rhs_tile, mask) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("pto.tpl() requires a non-empty string literal slot name", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_rejects_unknown_slot_before_ir_generation(self) -> None: + @pto.vkernel( + op="template_slot_unknown_slot_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={"core": {"template_slot_unknown_slot_unique": "vadd"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl("missing", lhs_tile, rhs_tile, mask) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("unknown template slot 'missing'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_rejects_missing_selected_op_mapping(self) -> None: + @pto.vkernel( + ops=["template_slot_missing_map_add_unique", "template_slot_missing_map_sub_unique"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={"core": {"template_slot_missing_map_add_unique": "vadd"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl("core", lhs_tile, rhs_tile, mask) + return None + + selected = pto.select_kernel( + "a5", + "template_slot_missing_map_sub_unique", + (pto.f32, pto.f32, pto.f32), + ) + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(selected) + + self.assertIn("template slot 'core' does not define an implementation for selected op", str(ctx.exception)) + self.assertIn("template_slot_missing_map_sub_unique", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_requires_selected_op_before_expansion(self) -> None: + @pto.vkernel( + ops=["template_slot_unbound_add_unique", "template_slot_unbound_sub_unique"], + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + templates={ + "core": { + "template_slot_unbound_add_unique": "vadd", + "template_slot_unbound_sub_unique": "vsub", + } + }, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as (out_tile, lhs_tile, rhs_tile, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = pto.tpl("core", lhs_tile, rhs_tile, mask) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("pto.tpl() requires pto.select_kernel(...) to bind a concrete op before expansion", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_template_slot_respects_resolved_op_surface_rules(self) -> None: + @pto.vkernel( + op="template_slot_advanced_surface_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + templates={"cmp": {"template_slot_advanced_surface_unique": "vcmp"}}, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + out = pto.tpl("cmp", dst, src0, mask, "lt") + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("surface `pto.vcmp` requires advanced=True", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_ckernel_template_slot_expands_after_selected_op_binding(self) -> None: + @pto.ckernel( + ops=[ + "cube_template_slot_mad_query_unique", + "cube_template_slot_mad_acc_query_unique", + ], + dtypes=[(pto.f16, pto.f16, pto.f32)], + templates={ + "compute": { + "cube_template_slot_mad_query_unique": "mad", + "cube_template_slot_mad_acc_query_unique": "mad_acc", + } + }, + name="cube_template_slot_unique", + ) + def kernel(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + lhs_ptr = lhs.as_ptr() + rhs_ptr = rhs.as_ptr() + acc_ptr = acc.as_ptr() + pto.tpl("compute", lhs_ptr, rhs_ptr, acc_ptr, 16, 16, 32) + return None + + mad_selected = pto.select_kernel( + "a5", + "cube_template_slot_mad_query_unique", + (pto.f16, pto.f16, pto.f32), + ).specialize( + lhs=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT), + rhs=pto.TileSpecialization(shape=(32, 16), memory_space=pto.MemorySpace.RIGHT), + acc=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.ACC), + ) + mad_acc_selected = pto.select_kernel( + "a5", + "cube_template_slot_mad_acc_query_unique", + (pto.f16, pto.f16, pto.f32), + ).specialize( + lhs=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT), + rhs=pto.TileSpecialization(shape=(32, 16), memory_space=pto.MemorySpace.RIGHT), + acc=pto.TileSpecialization(shape=(16, 16), memory_space=pto.MemorySpace.ACC), + ) + + mad_frontend = build_frontend_kernel_node(mad_selected) + mad_acc_frontend = build_frontend_kernel_node(mad_acc_selected) + mad_expr_stmt = mad_frontend.body[3] + mad_acc_expr_stmt = mad_acc_frontend.body[3] + self.assertIsInstance(mad_expr_stmt, FrontendExprStmt) + self.assertIsInstance(mad_acc_expr_stmt, FrontendExprStmt) + self.assertIsInstance(mad_expr_stmt.expr, FrontendCallExpr) + self.assertIsInstance(mad_acc_expr_stmt.expr, FrontendCallExpr) + self.assertEqual(mad_expr_stmt.expr.namespace, "pto") + self.assertEqual(mad_acc_expr_stmt.expr.namespace, "pto") + self.assertEqual(mad_expr_stmt.expr.name, "mad") + self.assertEqual(mad_acc_expr_stmt.expr.name, "mad_acc") + + mad_semantic = analyze_frontend_kernel(mad_frontend) + mad_acc_semantic = analyze_frontend_kernel(mad_acc_frontend) + mad_call = next( + stmt.expr for stmt in mad_semantic.body + if isinstance(stmt, SemanticExprStmt) + and isinstance(stmt.expr, SemanticCallExpr) + and stmt.expr.namespace == "pto" + and stmt.expr.name in {"mad", "mad_acc"} + ) + mad_acc_call = next( + stmt.expr for stmt in mad_acc_semantic.body + if isinstance(stmt, SemanticExprStmt) + and isinstance(stmt.expr, SemanticCallExpr) + and stmt.expr.namespace == "pto" + and stmt.expr.name in {"mad", "mad_acc"} + ) + self.assertEqual(mad_call.name, "mad") + self.assertEqual(mad_acc_call.name, "mad_acc") + + def test_ckernel_template_slot_rejects_missing_selected_op_mapping(self) -> None: + @pto.ckernel( + ops=["cube_template_slot_missing_map_mad_unique", "cube_template_slot_missing_map_mad_acc_unique"], + dtypes=[(pto.f16,)], + templates={"compute": {"cube_template_slot_missing_map_mad_unique": "mad"}}, + name="cube_template_slot_missing_map_unique", + ) + def kernel(inp: pto.TensorView): + left = pto.Tile((16, 32), pto.f16, pto.MemorySpace.LEFT) + right = pto.Tile((32, 16), pto.f16, pto.MemorySpace.RIGHT) + acc = pto.Tile((16, 16), pto.f32, pto.MemorySpace.ACC) + pto.tpl("compute", left.as_ptr(), right.as_ptr(), acc.as_ptr(), 16, 16, 32) + return None + + selected = pto.select_kernel( + "a5", + "cube_template_slot_missing_map_mad_acc_unique", + (pto.f16,), + registry=pto.KernelRegistry((kernel,)), + ) + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(selected) + + self.assertIn("template slot 'compute' does not define an implementation for selected op", str(ctx.exception)) + self.assertIn("cube_template_slot_missing_map_mad_acc_unique", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_callable_based_runtime_template_dispatch_remains_rejected(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel( + op="template_slot_callable_dispatch_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.TensorView, src0: pto.TensorView, src1: pto.TensorView): + table = {"core": pto.vadd} + with pto.strict_vecscope(dst, src0, src1, 0, 64, 64) as ( + out_tile, + lhs_tile, + rhs_tile, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + out = table["core"](lhs_tile, rhs_tile, mask) + return None + + self.assertIn("unsupported call surface in TileLang DSL v1", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_semantic_pipeline_binds_parameter_loop_and_strict_vecscope_types(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + rows = tile.shape[0] + step = rows + with pto.strict_vecscope(inp, tile, scale, 0, rows, step) as ( + vin, + vtmp, + factor, + lb, + ub, + vec_step, + ): + for lane in range(lb, ub, vec_step): + current = factor + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertEqual(len(frontend_kernel.body), 4) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + self.assertIsInstance(semantic_kernel.parameters[0].type, SemanticTensorViewType) + self.assertEqual(semantic_kernel.parameters[0].type.rank, 5) + self.assertIsInstance(semantic_kernel.parameters[1].type, SemanticTileType) + self.assertEqual(semantic_kernel.parameters[1].type.shape, (8, 16)) + self.assertIsInstance(semantic_kernel.parameters[2].type, SemanticScalarType) + + rows_assign = semantic_kernel.body[0] + self.assertIsInstance(rows_assign, SemanticAssignStmt) + self.assertIsInstance(rows_assign.targets[0].type, SemanticIndexType) + self.assertTrue(rows_assign.targets[0].ssa_name.startswith("%rows_")) + + vecscope_stmt = semantic_kernel.body[2] + self.assertIsInstance(vecscope_stmt, SemanticStrictVecscopeStmt) + self.assertEqual( + [binding.name for binding in vecscope_stmt.block_arguments], + ["vin", "vtmp", "factor", "lb", "ub", "vec_step"], + ) + self.assertIsInstance(vecscope_stmt.block_arguments[0].type, SemanticTensorViewType) + self.assertIsInstance(vecscope_stmt.block_arguments[1].type, SemanticTileType) + self.assertIsInstance(vecscope_stmt.block_arguments[2].type, SemanticScalarType) + self.assertIsInstance(vecscope_stmt.block_arguments[3].type, SemanticIndexType) + self.assertIsInstance(vecscope_stmt.block_arguments[4].type, SemanticIndexType) + self.assertIsInstance(vecscope_stmt.block_arguments[5].type, SemanticIndexType) + self.assertTrue(vecscope_stmt.block_arguments[0].ssa_name.startswith("%vin_")) + + loop_stmt = vecscope_stmt.body[0] + self.assertIsInstance(loop_stmt, SemanticForStmt) + self.assertEqual(loop_stmt.induction_variable.name, "lane") + self.assertIsInstance(loop_stmt.induction_variable.type, SemanticIndexType) + self.assertTrue(loop_stmt.induction_variable.ssa_name.startswith("%lane_")) + self.assertEqual(loop_stmt.loop_carried, ()) + + text = specialized.mlir_text() + self.assertIn("%rows_", text) + self.assertIn("= arith.constant 8 : index", text) + self.assertRegex( + text, + r"pto\.strict_vecscope\(%tmp_\d+, %tmp_\d+, %arg2, %c0, %rows_\d+, %rows_\d+\)", + ) + self.assertIn("^bb0(", text) + self.assertIn("scf.for %lane_", text) + self.assertIn("to %ub_6 step %vec_step_7 {", text) + + def test_tensorview_defaults_to_5d_shape_profile(self) -> None: + @pto.vkernel(op="tensorview_5d_shape_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + d0, d1, d2, d3, d4 = inp.valid_shape + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.parameters[0].type, SemanticTensorViewType) + self.assertEqual(semantic_kernel.parameters[0].type.rank, 5) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("inp", "tensorview")], + ) + + text = kernel.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.tensor_view) " + "attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {", + text, + ) + self.assertEqual(text.count("pto.get_tensor_view_dim"), 5) + + def test_tensorview_strides_profile_lowers_through_explicit_stride_queries(self) -> None: + @pto.vkernel(op="tensorview_5d_stride_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + s0, s1, s2, s3, s4 = inp.strides + for lane in range(0, s4, 1): + current = lane + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("inp", "tensorview")], + ) + + text = kernel.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.tensor_view) " + "attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {", + text, + ) + self.assertEqual(text.count("pto.get_tensor_view_stride"), 5) + self.assertRegex(text, r"scf\.for %lane_\d+ = %c0 to %s4_\d+ step %c1 \{") + + def test_tensorview_accepts_full_5d_slice_profile(self) -> None: + @pto.vkernel(op="tensorview_5d_slice_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + view = inp[0:1, 0:2, 0:3, 0:4, 0:5] + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertEqual(slice_assign.value.type.rank, 5) + self.assertEqual(slice_assign.value.type.extents, (1, 2, 3, 4, 5)) + self.assertEqual(slice_assign.value.type.physical_axes, (0, 1, 2, 3, 4)) + + def test_tensorview_3d_slice_profile_right_aligns_into_5d_descriptor(self) -> None: + @pto.vkernel(op="tensorview_3d_slice_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + view = inp[0:8, 0:16, 0:32] + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertEqual(slice_assign.value.type.rank, 3) + self.assertEqual(slice_assign.value.type.extents, (8, 16, 32)) + self.assertEqual(slice_assign.value.type.physical_axes, (2, 3, 4)) + + def test_tensorview_2d_slice_profile_right_aligns_into_5d_descriptor(self) -> None: + @pto.vkernel(op="tensorview_2d_slice_profile_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + view = inp[0:16, 0:32] + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertEqual(slice_assign.value.type.rank, 2) + self.assertEqual(slice_assign.value.type.extents, (16, 32)) + self.assertEqual(slice_assign.value.type.physical_axes, (3, 4)) + + def test_tensorview_slice_binding_lowers_to_partition_tensor_view_descriptor(self) -> None: + @pto.vkernel(op="tensorview_slice_partition_binding_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + part = inp[0:16, 0:32] + rows, cols = part.shape + s0, s1 = part.strides + if rows != 0 and cols != 0: + rows = s0 + s1 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertIsInstance(slice_assign.targets[0].type, SemanticPartitionTensorViewType) + self.assertEqual(slice_assign.targets[0].type.rank, 2) + + text = kernel.mlir_text() + self.assertIn(" = pto.partition_view %arg0, offsets = [%c0, %c0], sizes = [%c16, %c32] : ", text) + self.assertIn("-> !pto.partition_tensor_view<16x32xf32>", text) + self.assertEqual(text.count("pto.get_tensor_view_dim"), 2) + self.assertEqual(text.count("pto.get_tensor_view_stride"), 2) + + def test_partition_tensor_view_annotation_accepts_tensorview_slice_binding(self) -> None: + @pto.vkernel(op="partition_tensor_view_annotation_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + part: pto.PartitionTensorView = inp[0:8, 0:8] + r0, r1 = part.shape + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + slice_assign = semantic_kernel.body[0] + self.assertIsInstance(slice_assign, SemanticAssignStmt) + self.assertIsInstance(slice_assign.targets[0].type, SemanticPartitionTensorViewType) + self.assertEqual(slice_assign.targets[0].type.rank, 2) + + text = kernel.mlir_text() + self.assertIn(" = pto.partition_view %arg0, offsets = [%c0, %c0], sizes = [%c8, %c8] : ", text) + self.assertIn("-> !pto.partition_tensor_view<8x8xf32>", text) + self.assertEqual(text.count("pto.get_tensor_view_dim"), 2) + + def test_dynamic_tensorview_shape_profile_supports_runtime_bound_without_high_level_dma(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + rows = inp.shape[0] + for lane in range(0, rows, 1): + current = lane + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("inp", "tensorview"), ("tile", "tile")], + ) + + rows_assign = semantic_kernel.body[0] + self.assertIsInstance(rows_assign, SemanticAssignStmt) + self.assertIsInstance(rows_assign.targets[0].type, SemanticIndexType) + + loop_stmt = semantic_kernel.body[1] + self.assertIsInstance(loop_stmt, SemanticForStmt) + + text = specialized.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf) attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {", + text, + ) + self.assertIn("scf.for %lane_", text) + self.assertIn("pto.get_tensor_view_dim", text) + + def test_semantic_recognizes_padmode_symbol(self) -> None: + @pto.vkernel(op="pad_mode_symbol", dtypes=[(pto.f32, pto.f32)]) + def kernel(inp: pto.TensorView, tile: pto.Tile): + mode = pto.PadMode.PadFirstElem + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + assign_stmt = semantic_kernel.body[0] + self.assertIsInstance(assign_stmt, SemanticAssignStmt) + self.assertIsInstance(assign_stmt.value, SemanticSymbolExpr) + self.assertEqual(assign_stmt.value.value, pto.PadMode.PadFirstElem) + self.assertEqual(assign_stmt.value.type.kind, "pad_mode") + + def test_tile_config_attributes_bind_as_static_metadata(self) -> None: + @pto.vkernel(op="tile_config_attrs_unique", dtypes=[(pto.f16,)]) + def kernel(tile: pto.Tile): + config = tile.config + layout = config.b_layout + secondary = config.s_layout + fractal = config.s_fractal_size + pad = config.pad_value + pad_direct = tile.pad_value + pad_scalar = pad.eval() + pad_direct_scalar = pad_direct.eval() + rank = tile.rank + space = tile.memory_space + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.ROW_MAJOR, + "s_fractal_size": 16, + "pad_value": pto.PadValue.ZERO, + } + ), + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + ( + config_assign, + layout_assign, + secondary_assign, + fractal_assign, + pad_assign, + pad_direct_assign, + pad_scalar_assign, + pad_direct_scalar_assign, + rank_assign, + space_assign, + ) = ( + semantic_kernel.body[:10] + ) + + self.assertIsInstance(config_assign, SemanticAssignStmt) + self.assertIsInstance(config_assign.targets[0].type, SemanticTileConfigType) + self.assertIsInstance(config_assign.value, SemanticLiteralExpr) + self.assertEqual(config_assign.targets[0].value, config_assign.value.value) + self.assertIsInstance(config_assign.value.type, SemanticTileConfigType) + + self.assertIsInstance(layout_assign.value, SemanticSymbolExpr) + self.assertEqual(layout_assign.value.value, pto.BLayout.COL_MAJOR) + self.assertEqual(layout_assign.value.type.kind, "b_layout") + + self.assertIsInstance(secondary_assign.value, SemanticSymbolExpr) + self.assertEqual(secondary_assign.value.value, pto.SLayout.ROW_MAJOR) + self.assertEqual(secondary_assign.value.type.kind, "s_layout") + + self.assertIsInstance(fractal_assign.value, SemanticLiteralExpr) + self.assertEqual(fractal_assign.value.value, 16) + self.assertIsInstance(fractal_assign.targets[0].type, SemanticScalarType) + self.assertEqual(fractal_assign.targets[0].type.dtype, pto.i32) + + self.assertIsInstance(pad_assign.value, SemanticSymbolExpr) + self.assertEqual(pad_assign.value.value, pto.PadValue.ZERO) + self.assertIsInstance(pad_assign.targets[0].type, SemanticPadValueType) + self.assertEqual(pad_assign.targets[0].type.element_dtype, pto.f16) + + self.assertIsInstance(pad_direct_assign.value, SemanticSymbolExpr) + self.assertEqual(pad_direct_assign.value.value, pto.PadValue.ZERO) + self.assertIsInstance(pad_direct_assign.targets[0].type, SemanticPadValueType) + self.assertEqual(pad_direct_assign.targets[0].type.element_dtype, pto.f16) + + self.assertIsInstance(pad_scalar_assign.value, SemanticLiteralExpr) + self.assertEqual(pad_scalar_assign.value.value, 0.0) + self.assertIsInstance(pad_scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(pad_scalar_assign.targets[0].type.dtype, pto.f16) + + self.assertIsInstance(pad_direct_scalar_assign.value, SemanticLiteralExpr) + self.assertEqual(pad_direct_scalar_assign.value.value, 0.0) + self.assertIsInstance(pad_direct_scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(pad_direct_scalar_assign.targets[0].type.dtype, pto.f16) + + self.assertEqual(rank_assign.value.value, 2) + self.assertIsInstance(rank_assign.targets[0].type, SemanticIndexType) + + self.assertIsInstance(space_assign.value, SemanticSymbolExpr) + self.assertEqual(space_assign.value.value, pto.MemorySpace.UB) + self.assertEqual(space_assign.value.type.kind, "memory_space") + + def test_pad_value_eval_requires_non_null_enum(self) -> None: + @pto.vkernel(op="tile_pad_value_null_eval", dtypes=[(pto.f16,)]) + def kernel(tile: pto.Tile): + scalar = tile.pad_value.eval() + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + self.assertIn("PadValue.NULL.eval() is invalid", str(ctx.exception)) + + def test_standalone_pad_value_eval_accepts_explicit_dtype(self) -> None: + @pto.vkernel(op="standalone_pad_value_eval_dtype", dtypes=[(pto.f32,)]) + def kernel(tile: pto.Tile): + scalar = pto.PadValue.MAX.eval(pto.f32) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + scalar_assign = semantic_kernel.body[0] + + self.assertIsInstance(scalar_assign, SemanticAssignStmt) + self.assertIsInstance(scalar_assign.value, SemanticLiteralExpr) + self.assertAlmostEqual(scalar_assign.value.value, pto.PadValue.MAX.eval(pto.f32)) + self.assertIsInstance(scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(scalar_assign.targets[0].type.dtype, pto.f32) + + def test_standalone_pad_value_eval_accepts_static_dtype_binding(self) -> None: + @pto.vkernel(op="standalone_pad_value_eval_dtype_binding", dtypes=[(pto.f32,)]) + def kernel(tile: pto.Tile): + dtype = tile.element_type + scalar = pto.PadValue.MAX.eval(dtype) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + dtype_assign, scalar_assign = semantic_kernel.body[:2] + + self.assertIsInstance(dtype_assign, SemanticAssignStmt) + self.assertIsInstance(dtype_assign.value, SemanticSymbolExpr) + self.assertEqual(dtype_assign.value.value, pto.f32) + + self.assertIsInstance(scalar_assign, SemanticAssignStmt) + self.assertIsInstance(scalar_assign.value, SemanticLiteralExpr) + self.assertAlmostEqual(scalar_assign.value.value, pto.PadValue.MAX.eval(pto.f32)) + self.assertIsInstance(scalar_assign.targets[0].type, SemanticScalarType) + self.assertEqual(scalar_assign.targets[0].type.dtype, pto.f32) + + def test_static_dtype_binding_supports_constructor_call_surface(self) -> None: + @pto.vkernel(op="static_dtype_binding_constructor_unique", dtypes=[(pto.i32,)]) + def kernel(tile: pto.Tile): + idx_dtype = tile.element_type + cols = tile.shape[1] + zero_idx = idx_dtype(0) + v_col = idx_dtype(cols) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + dtype_assign, cols_assign, zero_assign, cast_assign = semantic_kernel.body[:4] + + self.assertIsInstance(dtype_assign, SemanticAssignStmt) + self.assertIsInstance(dtype_assign.value, SemanticSymbolExpr) + self.assertEqual(dtype_assign.value.value, pto.i32) + + self.assertIsInstance(cols_assign, SemanticAssignStmt) + self.assertIsInstance(cols_assign.targets[0].type, SemanticIndexType) + + self.assertIsInstance(zero_assign, SemanticAssignStmt) + self.assertIsInstance(zero_assign.value, SemanticLiteralExpr) + self.assertEqual(zero_assign.value.value, 0) + self.assertIsInstance(zero_assign.targets[0].type, SemanticScalarType) + self.assertEqual(zero_assign.targets[0].type.dtype, pto.i32) + + self.assertIsInstance(cast_assign, SemanticAssignStmt) + self.assertIsInstance(cast_assign.value, SemanticCallExpr) + self.assertEqual(cast_assign.value.namespace, "pto") + self.assertEqual(cast_assign.value.name, "i32") + self.assertIsInstance(cast_assign.targets[0].type, SemanticScalarType) + self.assertEqual(cast_assign.targets[0].type.dtype, pto.i32) + + def test_unsigned_integer_constants_lower_with_signless_arith_types(self) -> None: + @pto.vkernel(op="tile_pad_value_ui32_max_eval_unique", dtypes=[(pto.ui32,)]) + def kernel(tile: pto.Tile): + scalar = tile.pad_value.eval() + explicit = pto.ui32(4294967295) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "pad_value": pto.PadValue.MAX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("dtype=ui32", text) + self.assertIn("arith.constant 4294967295 : i32", text) + self.assertNotIn("arith.constant 4294967295 : ui32", text) + + def test_cached_unsigned_integer_constructor_constant_preserves_typed_bridge(self) -> None: + @pto.vkernel( + op="cached_ui16_constructor_constant_bridge_unique", + dtypes=[(pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + biased = pto.vadds(vec, pto.ui16(1), all_mask) + out = pto.vadds(biased, pto.ui16(1), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertEqual(text.count("arith.constant 1 : i16"), 1) + self.assertEqual(text.count(": i16 to ui16"), 1) + self.assertNotIn("arith.constant 1 : ui16", text) + + def test_narrow_typed_integer_zero_constructors_lower_with_signless_bridge(self) -> None: + @pto.vkernel(op="si16_zero_constructor_bridge_unique", dtypes=[(pto.si16, pto.si16)], advanced=True) + def si16_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.si16(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + @pto.vkernel(op="ui16_zero_constructor_bridge_unique", dtypes=[(pto.ui16, pto.ui16)], advanced=True) + def ui16_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.ui16(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + @pto.vkernel(op="si8_zero_constructor_bridge_unique", dtypes=[(pto.si8, pto.si8)], advanced=True) + def si8_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.si8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.si8(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + @pto.vkernel(op="ui8_zero_constructor_bridge_unique", dtypes=[(pto.ui8, pto.ui8)], advanced=True) + def ui8_kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.ui8(0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + tile_specs = dict( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + for dtype_name, raw_type, kernel in ( + ("si16", "i16", si16_kernel), + ("ui16", "i16", ui16_kernel), + ("si8", "i8", si8_kernel), + ("ui8", "i8", ui8_kernel), + ): + with self.subTest(dtype=dtype_name): + text = kernel.specialize(**tile_specs).mlir_text() + self.assertEqual(text.count(f"arith.constant 0 : {raw_type}"), 1) + self.assertEqual(text.count(f": {raw_type} to {dtype_name}"), 1) + self.assertNotIn(f"arith.constant 0 : {dtype_name}", text) + + def test_unsigned_pad_value_eval_broadcast_bitcasts_signless_literal(self) -> None: + @pto.vkernel(op="tile_pad_value_ui16_vbr_unique", dtypes=[(pto.ui16,)], advanced=True) + def kernel(tile: pto.Tile): + scalar = tile.pad_value.eval() + vec = pto.vbr(scalar) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "pad_value": pto.PadValue.MAX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("dtype=ui16", text) + self.assertIn("arith.constant 65535 : i16", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i16 to ui16", text) + self.assertIn("pto.vbr", text) + + def test_index_to_unsigned_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_ui32_constructor_unique", dtypes=[(pto.ui32,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i32 to ui32", text) + self.assertNotIn(": index to ui32", text) + + def test_index_to_ui16_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_ui16_constructor_unique", dtypes=[(pto.ui16,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui16(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui16, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i16 to ui16", text) + self.assertNotIn(": index to ui16", text) + + def test_index_to_ui8_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_ui8_constructor_unique", dtypes=[(pto.ui8,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui8(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui8, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i8", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i8 to ui8", text) + self.assertNotIn(": index to ui8", text) + + def test_index_to_si8_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_si8_constructor_unique", dtypes=[(pto.si8,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si8(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si8, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_cast", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i8", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i8 to si8", text) + self.assertNotIn(": index to si8", text) + + def test_index_to_i16_scalar_constructor_lowers_via_index_cast_then_trunci(self) -> None: + @pto.vkernel(op="index_to_i16_constructor_unique", dtypes=[(pto.i16,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.i16(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.i16, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_cast", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertNotIn("builtin.unrealized_conversion_cast", text) + self.assertNotIn(": index to i16", text) + + def test_index_to_si16_scalar_constructor_bridges_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_si16_constructor_unique", dtypes=[(pto.si16,)], advanced=True) + def kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si16(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si16, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_cast", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(": i16 to si16", text) + self.assertNotIn(": index to si16", text) + + def test_index_to_32bit_integer_scalar_constructors_bridge_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_i32_constructor_unique", dtypes=[(pto.i32,)], advanced=True) + def i32_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.i32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.i32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_si32_constructor_unique", dtypes=[(pto.si32,)], advanced=True) + def si32_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_ui32_constructor_bridge_unique", dtypes=[(pto.ui32,)], advanced=True) + def ui32_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui32(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui32, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + tile_spec = pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + + for dtype_name, op_name, kernel in ( + ("i32", "arith.index_cast", i32_kernel), + ("si32", "arith.index_cast", si32_kernel), + ("ui32", "arith.index_castui", ui32_kernel), + ): + with self.subTest(dtype=dtype_name): + text = kernel.specialize(tile=tile_spec).mlir_text() + self.assertIn(op_name, text) + self.assertIn(": index to i32", text) + if dtype_name == "i32": + self.assertNotIn("builtin.unrealized_conversion_cast", text) + else: + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(f": i32 to {dtype_name}", text) + self.assertNotIn(f": index to {dtype_name}", text) + + def test_index_to_64bit_integer_scalar_constructors_bridge_via_signless_integer(self) -> None: + @pto.vkernel(op="index_to_i64_constructor_unique", dtypes=[(pto.i64,)], advanced=True) + def i64_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.i64(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.i64, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_si64_constructor_unique", dtypes=[(pto.si64,)], advanced=True) + def si64_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.si64(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.si64, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + @pto.vkernel(op="index_to_ui64_constructor_unique", dtypes=[(pto.ui64,)], advanced=True) + def ui64_kernel(tile: pto.Tile): + cols = tile.valid_shape[1] + for col in range(0, cols, 1): + offset = pto.ui64(col) + vec = pto.vbr(offset) + mask, _ = pto.make_mask(pto.ui64, 1) + pto.vsts(vec, tile[col, 0:], mask) + return None + + tile_spec = pto.TileSpecialization( + shape=(8, 1), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": pto.BLayout.COL_MAJOR, + "s_layout": pto.SLayout.NONE_BOX, + } + ), + ) + + for dtype_name, op_name, kernel in ( + ("i64", "arith.index_cast", i64_kernel), + ("si64", "arith.index_cast", si64_kernel), + ("ui64", "arith.index_castui", ui64_kernel), + ): + with self.subTest(dtype=dtype_name): + text = kernel.specialize(tile=tile_spec).mlir_text() + self.assertIn(op_name, text) + self.assertIn(": index to i64", text) + if dtype_name == "i64": + self.assertNotIn("builtin.unrealized_conversion_cast", text) + else: + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn(f": i64 to {dtype_name}", text) + self.assertNotIn(f": index to {dtype_name}", text) + + + def test_make_mask_vlds_vsts_and_vector_families_lower_inside_strict_vecscope(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(tile: pto.Tile, scale: pto.f32): + with pto.strict_vecscope(tile, tile, scale, 0, 256, 64) as ( + src, + dst, + factor, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + biased = pto.vadds(vec, factor, mask) + summed = pto.vadd(biased, vec, mask) + activated = pto.vrelu(summed, mask) + pto.vsts(activated, dst, lane, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = semantic_kernel.body[0] + self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) + loop_stmt = vecscope.body[0] + self.assertIsInstance(loop_stmt, SemanticForStmt) + mask_assign = loop_stmt.body[0] + self.assertIsInstance(mask_assign, SemanticAssignStmt) + self.assertIsInstance(mask_assign.value, SemanticCallExpr) + self.assertEqual(mask_assign.value.name, "make_mask") + self.assertIsInstance(mask_assign.targets[0].type, SemanticMaskType) + self.assertIsInstance(loop_stmt.body[-1], SemanticVectorStoreStmt) + + text = specialized.mlir_text() + self.assertRegex(text, r'%mask_\d+ = pto\.pset_b32 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r"%vec_\d+ = pto\.vlds %src_\d+\[%lane_\d+\] : !pto\.ptr -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"%biased_\d+ = pto\.vadds %vec_\d+, %factor_\d+, %mask_\d+ : !pto\.vreg<64xf32>, f32, !pto\.mask -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"%summed_\d+ = pto\.vadd %biased_\d+, %vec_\d+, %mask_\d+ : !pto\.vreg<64xf32>, !pto\.vreg<64xf32>, !pto\.mask -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"%activated_\d+ = pto\.vrelu %summed_\d+, %mask_\d+ : !pto\.vreg<64xf32>, !pto\.mask -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"pto\.vsts %activated_\d+, %dst_\d+\[%lane_\d+\], %mask_\d+ : !pto\.vreg<64xf32>, !pto\.ptr, !pto\.mask") + + def test_vrelu_accepts_i32_inside_strict_vecscope(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.i32, pto.i32)], advanced=True) + def kernel(tile: pto.Tile, bias: pto.i32): + with pto.strict_vecscope(tile, tile, bias, 0, 256, 64) as ( + src, + dst, + offset, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + shifted = pto.vadds(vec, offset, mask) + activated = pto.vrelu(shifted, mask) + pto.vsts(activated, dst, lane, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r'%mask_\d+ = pto\.pset_b32 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r"%vec_\d+ = pto\.vlds %src_\d+\[%lane_\d+\] : !pto\.ptr -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"%shifted_\d+ = pto\.vadds %vec_\d+, %offset_\d+, %mask_\d+ : !pto\.vreg<64xi32>, i32, !pto\.mask -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"%activated_\d+ = pto\.vrelu %shifted_\d+, %mask_\d+ : !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"pto\.vsts %activated_\d+, %dst_\d+\[%lane_\d+\], %mask_\d+ : !pto\.vreg<64xi32>, !pto\.ptr, !pto\.mask") + + def test_tail_make_mask_lowers_to_typed_plt_and_updates_remaining(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.i32)], advanced=True) + def kernel(tile: pto.Tile, remaining: pto.i32): + with pto.strict_vecscope(tile, tile, remaining, 0, 64, 64) as (src, dst, rem_in, lb, ub, step): + mask, next_remaining = pto.make_mask(pto.f32, rem_in) + vec = pto.vlds(src, lb) + pto.vsts(vec, dst, lb, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope = semantic_kernel.body[0] + self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) + mask_assign = vecscope.body[0] + self.assertIsInstance(mask_assign, SemanticAssignStmt) + self.assertEqual(mask_assign.value.name, "make_mask") + self.assertEqual(len(mask_assign.targets), 2) + self.assertIsInstance(mask_assign.targets[0].type, SemanticMaskType) + self.assertIsInstance(mask_assign.targets[1].type, SemanticScalarType) + self.assertEqual(mask_assign.targets[1].type.dtype, pto.i32) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"%mask_\d+, %next_remaining_\d+ = pto\.plt_b32 %rem_in_\d+ : i32 -> !pto\.mask, i32", + ) + self.assertIn( + "pto.vsts %vec_", + text, + ) + + def test_nested_index_arithmetic_lowers_before_vector_accesses(self) -> None: + @pto.vkernel( + op="eltwise", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) + def kernel( + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, + ): + rows = lhs_tile.shape[0] + cols = lhs_tile.shape[1] + row_stride = lhs_tile.shape[1] + + with pto.strict_vecscope( + lhs_tile, + rhs_tile, + dst_tile, + rows, + cols, + row_stride, + 0, + rows, + 1, + ) as (lhs, rhs, dst, valid_rows, valid_cols, stride, row_lb, row_ub, row_step): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + offset = row * stride + lane + mask, next_remaining = pto.make_mask(pto.f32, valid_cols - lane) + summed = pto.vadd(pto.vlds(lhs, offset), pto.vlds(rhs, offset), mask) + pto.vsts(summed, dst, offset, mask) + return None + + specialized = kernel.specialize( + lhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + rhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"%tmp_\d+ = arith\.muli %row_\d+, %stride_\d+ : index") + self.assertRegex(text, r"%offset_\d+ = arith\.addi %tmp_\d+, %lane_\d+ : index") + self.assertRegex(text, r"%tmp_\d+ = arith\.subi %valid_cols_\d+, %lane_\d+ : index") + self.assertRegex(text, r"%tmp_\d+ = arith\.index_cast %tmp_\d+ : index to i32") + self.assertIn("pto.plt_b32", text) + self.assertIn("pto.vadd", text) + + def test_scalar_binary_arithmetic_supports_float_and_integer_paths(self) -> None: + @pto.vkernel( + op="scalar_binary_arithmetic_unique", + dtypes=[(pto.f32, pto.f32, pto.i32)], + advanced=True, + ) + def kernel(dst_tile: pto.Tile, src_tile: pto.Tile, gate: pto.i32): + rows = src_tile.shape[0] + cols = src_tile.shape[1] + with pto.strict_vecscope( + src_tile, + dst_tile, + gate, + rows, + cols, + 0, + rows, + 1, + ) as (src, dst, in_gate, valid_rows, valid_cols, row_lb, row_ub, row_step): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + half = in_gate // pto.i32(2) + remain = in_gate % pto.i32(7) + factor = pto.f32(half) + pto.f32(remain) * pto.f32(0.5) + mask, _ = pto.make_mask(pto.f32, valid_cols - lane) + vec = pto.vlds(src, lane) + vec = pto.vmuls(vec, factor, mask) + pto.vsts(vec, dst, lane, mask) + return None + + specialized = kernel.specialize( + dst_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"= arith\.floordivsi %in_gate_\d+, %c2_i32 : i32") + self.assertRegex(text, r"= arith\.remsi %in_gate_\d+, %c7_i32 : i32") + self.assertRegex(text, r"%c0_5_f32 = arith\.constant 0\.5 : f32") + self.assertRegex(text, r"= arith\.mulf %tmp_\d+, %c0_5_f32 : f32") + self.assertRegex(text, r"= arith\.addf %tmp_\d+, %tmp_\d+ : f32") + + def test_index_and_i32_scalar_binary_ops_bridge_index_literals(self) -> None: + @pto.vkernel( + op="index_i32_scalar_binary_bridge_unique", + dtypes=[(pto.f32, pto.AnyType, pto.f32)], + advanced=True, + ) + def kernel(src: pto.Tile, gate: pto.AnyType, dst: pto.Tile): + rows = src.shape[0] + cols = src.shape[1] + with pto.strict_vecscope( + src, + dst, + gate, + rows, + cols, + 0, + rows, + 1, + ) as (src_tile, dst_tile, in_gate, valid_rows, valid_cols, row_lb, row_ub, row_step): + for row in range(row_lb, row_ub, row_step): + if in_gate > 1: + for lane in range(0, valid_cols, 64): + lane_limit = in_gate + 1 + mask, _ = pto.make_mask(pto.f32, lane_limit) + vec = pto.vlds(src_tile, lane) + pto.vsts(vec, dst_tile, lane, mask) + return None + + selected = pto.select_kernel( + "a5", + "index_i32_scalar_binary_bridge_unique", + (pto.f32, pto.i32, pto.f32), + ) + specialized = selected.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("%c1_i32 = arith.constant 1 : i32", text) + self.assertRegex(text, r"%tmp_\d+ = arith\.cmpi sgt, %in_gate_\d+, %c1_i32 : i32") + self.assertRegex(text, r"%\w+_\d+ = arith\.addi %in_gate_\d+, %c1_i32 : i32") + + def test_binary_literals_follow_scalar_operand_types(self) -> None: + @pto.vkernel(op="index_i16_scalar_binary_infer_unique", dtypes=[(pto.i16,)], advanced=True) + def i16_kernel(gate: pto.i16): + _ = gate + 1 + _ = gate > 2 + return None + + @pto.vkernel(op="index_f32_scalar_binary_infer_unique", dtypes=[(pto.f32,)], advanced=True) + def f32_kernel(gate: pto.f32): + _ = gate + 1 + _ = gate > 2.5 + return None + + i16_text = i16_kernel.specialize().mlir_text() + self.assertIn("%c1_i16 = arith.constant 1 : i16", i16_text) + self.assertIn("%c2_i16 = arith.constant 2 : i16", i16_text) + self.assertRegex(i16_text, r"= arith\.addi %arg0, %c1_i16 : i16") + self.assertRegex(i16_text, r"= arith\.cmpi sgt, %arg0, %c2_i16 : i16") + + f32_text = f32_kernel.specialize().mlir_text() + self.assertRegex(f32_text, r"%c1(?:_0)?_f32 = arith\.constant 1(?:\.0+)? : f32") + self.assertRegex(f32_text, r"%c2_5_f32 = arith\.constant 2(?:\.5+)? : f32") + self.assertRegex(f32_text, r"= arith\.addf %arg0, %c1(?:_0)?_f32 : f32") + self.assertRegex(f32_text, r"= arith\.cmpf ogt, %arg0, %c2_5_f32 : f32") + + def test_index_floordiv_lowers_to_divui_instead_of_floordivsi(self) -> None: + @pto.vkernel( + op="index_floordiv_lowering_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) + def kernel( + lhs_tile: pto.Tile, + rhs_tile: pto.Tile, + dst_tile: pto.Tile, + ): + rows = lhs_tile.shape[0] + cols = lhs_tile.shape[1] + with pto.strict_vecscope( + lhs_tile, + rhs_tile, + dst_tile, + rows, + cols, + 0, + rows, + 1, + ) as (lhs, rhs, dst, valid_rows, valid_cols, row_lb, row_ub, row_step): + for row in range(row_lb, row_ub, row_step): + for lane in range(0, valid_cols, 64): + row_bucket = row // valid_cols + offset = row_bucket * valid_cols + lane + mask, _ = pto.make_mask(pto.f32, valid_cols - lane) + summed = pto.vadd(pto.vlds(lhs, offset), pto.vlds(rhs, offset), mask) + pto.vsts(summed, dst, offset, mask) + return None + + specialized = kernel.specialize( + lhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + rhs_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst_tile=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"= arith\.divui %row_\d+, %valid_cols_\d+ : index") + self.assertNotRegex(text, r"arith\.floordivsi .*: index") + + def test_scalar_bitwise_and_shift_ops_lower_for_signed_and_unsigned(self) -> None: + @pto.vkernel( + op="scalar_bitwise_shift_unique", + dtypes=[(pto.i32, pto.ui32)], + advanced=True, + ) + def kernel(signed_val: pto.i32, unsigned_val: pto.ui32): + signed_mix = (signed_val & pto.i32(15)) | pto.i32(1) + signed_mix = signed_mix ^ pto.i32(2) + signed_mix = signed_mix >> pto.i32(1) + signed_mix = signed_mix << pto.i32(3) + + unsigned_mix = unsigned_val & pto.ui32(31) + unsigned_mix = unsigned_mix >> pto.ui32(2) + unsigned_mix = unsigned_mix << pto.ui32(1) + unsigned_mix = unsigned_mix ^ pto.ui32(7) + return None + + specialized = kernel.specialize() + text = specialized.mlir_text() + + self.assertIn("arith.andi", text) + self.assertIn("arith.ori", text) + self.assertIn("arith.xori", text) + self.assertRegex(text, r"= arith\.shrsi %\w+_\d+, %c1_i32 : i32") + self.assertRegex(text, r"= arith\.shli %\w+_\d+, %c3_i32 : i32") + self.assertRegex(text, r"= arith\.shrui %\w+_\d+, %c2_ui32 : ui32") + self.assertRegex(text, r"= arith\.shli %\w+_\d+, %c1_ui32 : ui32") + + def test_scalar_bitwise_rejects_float_operands(self) -> None: + @pto.vkernel(op="scalar_bitwise_float_reject_unique", dtypes=[(pto.f32,)]) + def kernel(value: pto.f32): + _ = value & pto.f32(1.0) + return None + + specialized = kernel.specialize() + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + self.assertIn("mod/floordiv/bitwise/shift for integer", str(ctx.exception)) + + def test_integer_scalars_implicitly_cast_in_index_contexts(self) -> None: + @pto.vkernel( + op="index_context_integer_scalar_cast_unique", + dtypes=[(pto.f32, pto.si16, pto.ui8)], + advanced=True, + ) + def kernel(src: pto.Tile, row: pto.si16, col: pto.ui8): + _ = pto.vlds(src[row, col:]) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"= arith\.extsi %tmp_\d+ : i16 to i32") + self.assertRegex(text, r"= arith\.index_cast %tmp_\d+ : i32 to index") + self.assertRegex(text, r"= arith\.extui %tmp_\d+ : i8 to i32") + self.assertRegex(text, r"= arith\.index_cast %tmp_\d+ : i32 to index") + self.assertRegex(text, r"memref\.subview %tmp_\d+\[%tmp_\d+, %tmp_\d+\]") + + def test_stable_mode_lowers_tile_vector_sugar_without_frontend_vecscope(self) -> None: + @pto.vkernel(op="tadd_stable", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + all_mask = pto.make_mask(dtype, pto.PAT.ALL) + for row in range(0, rows, 1): + for col in range(0, cols, pto.get_lanes(dtype)): + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, all_mask) + pto.vsts(summed, dst[row, col:], all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + self.assertNotIn("pto.strict_vecscope(", text) + self.assertRegex(text, r"memref\.subview %tmp_\d+\[%row_\d+, %col_\d+\] \[%c1, %tmp_\d+\] \[%c1, %c1\]") + self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c0\]") + self.assertRegex(text, r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %(?:all_mask|mask)_\d+") + + def test_advanced_mode_lowers_tile_vector_sugar_without_frontend_vecscope(self) -> None: + @pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + rows, cols = dst.valid_shape + all_mask = pto.make_mask(dtype, pto.PAT.ALL) + for row in range(0, rows, 1): + for col in range(0, cols, pto.get_lanes(dtype)): + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, all_mask) + pto.vsts(summed, dst[row, col:], all_mask) + return None + + self.assertTrue(kernel.advanced_enabled) + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + outer_loop = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticForStmt)) + self.assertIsInstance(outer_loop, SemanticForStmt) + inner_loop = outer_loop.body[0] + self.assertIsInstance(inner_loop, SemanticForStmt) + self.assertTrue(inner_loop.body) + + text = specialized.mlir_text() + self.assertIn("// tilelang.advanced = True", text) + self.assertNotIn("pto.vecscope {", text) + self.assertNotIn("pto.strict_vecscope(", text) + self.assertIn("!pto.tile_buf> to memref<\?x\?xf32, strided<\[\?, \?\], offset: \?>, #pto\.address_space>") + self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c0\] : memref<\?x\?xf32, strided<\[\?, \?\], offset: \?>, #pto\.address_space> -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %(?:all_mask|mask)_\d+ : !pto\.vreg<64xf32>, memref<\?x\?xf32, strided<\[\?, \?\], offset: \?>, #pto\.address_space>, !pto\.mask") + self.assertNotRegex(text, r"arith\.muli %row_\d+, %c64 : index") + self.assertNotRegex(text, r"arith\.addi %tmp_\d+, %col_\d+ : index") + self.assertLess(text.index("pto.tile_buf_addr %arg1"), text.index("scf.for %row_")) + self.assertLess(text.index("pto.tile_buf_addr %arg2"), text.index("scf.for %row_")) + self.assertLess(text.index("pto.tile_buf_addr %arg0"), text.index("scf.for %row_")) + self.assertLess(text.index("pto.tile_valid_rows %arg0"), text.index("scf.for %row_")) + self.assertLess(text.index("pto.tile_valid_cols %arg0"), text.index("scf.for %row_")) + + def test_element_type_valid_shape_and_get_lanes_surface_lower_in_advanced_mode(self) -> None: + @pto.vkernel(op="tadd", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + remained = valid_cols + for row in range(0, valid_rows, 1): + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + summed = pto.vadd(pto.vlds(src0[row, col:]), pto.vlds(src1[row, col:]), mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("step %c64", text) + self.assertRegex(text, r"%mask_\d+, %remained_\d+ = pto\.plt_b32 %remained_iter_\d+ : i32 -> !pto\.mask, i32") + self.assertIn("pto.vadd", text) + self.assertIn("pto.vsts", text) + self.assertIn("pto.tile_valid_rows %arg0", text) + self.assertIn("pto.tile_valid_cols %arg0", text) + self.assertRegex(text, r"memref\.subview %tmp_\d+\[%row_\d+, %col_\d+\] \[%c1, %tmp_\d+\] \[%c1, %c1\]") + self.assertRegex(text, r"pto\.vlds %tmp_\d+\[%c0\]") + self.assertRegex(text, r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %mask_\d+") + + def test_bytewidth_surface_lowers_to_constant_index(self) -> None: + @pto.vkernel(op="bytewidth_query_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + elem_bytes = pto.bytewidth(dst.element_type) + rows, cols = dst.valid_shape + for col in range(0, cols, elem_bytes): + current = col + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("= arith.constant 4 : index", text) + self.assertRegex(text, r"scf\.for %col_\d+ = %c0 to %cols_\d+ step %elem_bytes_\d+") + self.assertIn("pto.tile_valid_cols %arg0", text) + + def test_elements_per_vreg_alias_surface_lowers_to_constant_index(self) -> None: + @pto.vkernel(op="elements_per_vreg_query_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + lanes = pto.elements_per_vreg(dst.element_type) + rows, cols = dst.valid_shape + for col in range(0, cols, lanes): + current = col + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("= arith.constant 64 : index", text) + self.assertRegex(text, r"scf\.for %col_\d+ = %c0 to %cols_\d+ step %lanes_\d+") + self.assertIn("pto.tile_valid_cols %arg0", text) + + def test_vreg_type_constructor_and_annotation_match_vector_value(self) -> None: + @pto.vkernel(op="vreg_type_annotation_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + dtype = dst.element_type + vec_ty = pto.vreg(dtype) + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec: pto.vreg(dtype) = pto.vlds(dst, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertFalse(any(isinstance(stmt, SemanticVecscopeStmt) for stmt in semantic_kernel.body)) + vec_assign = next( + stmt + for stmt in _walk_semantic_stmts(semantic_kernel.body) + if isinstance(stmt, SemanticAssignStmt) + and stmt.targets[0].name == "vec" + ) + self.assertIsInstance(vec_assign.targets[0].type, SemanticVRegType) + self.assertEqual(vec_assign.targets[0].type.element_dtype, pto.f32) + self.assertEqual(vec_assign.targets[0].type.lanes, 64) + self.assertTrue( + any( + isinstance(stmt, SemanticAssignStmt) + and stmt.targets[0].name == "vec_ty" + for stmt in semantic_kernel.body + ) + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"%vec_\d+ = pto\.vlds %tmp_\d+\[%c0\] : memref<8x64xf32, #pto\.address_space> -> !pto\.vreg<64xf32>") + self.assertRegex(text, r"pto\.vsts %vec_\d+, %tmp_\d+\[%c0\], %mask_\d+ : !pto\.vreg<64xf32>, memref<8x64xf32, #pto\.address_space>, !pto\.mask") + + def test_mask_type_annotation_matches_make_mask_result(self) -> None: + @pto.vkernel(op="mask_type_annotation_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + mask_ty = pto.mask_b32 + mask: pto.mask_b32 = pto.make_mask(pto.f32, pto.PAT.ALL) + alias_mask: mask_ty = mask + vec: pto.vreg(pto.f32) = pto.vlds(dst, 0) + pto.vsts(vec, dst, 0, alias_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r'%mask_\d+ = pto\.pset_b32 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r"pto\.vsts %vec_\d+, %tmp_\d+\[%c0\], %\w+ : !pto\.vreg<64xf32>, memref<8x64xf32, #pto\.address_space>, !pto\.mask") + + def test_extended_float_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="extended_float_vector_ops_unique", + dtypes=[(pto.f32, pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, alpha: pto.f32): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + vec2 = pto.vlds(src, 128) + vec3 = pto.vlds(src, 192) + + out = pto.vln(vec0, all_mask) + out = pto.vsqrt(out, all_mask) + out = pto.vrec(out, all_mask) + out = pto.vrsqrt(out, all_mask) + out = pto.vexpdif(out, vec1, all_mask, pto.VcvtPartMode.ODD) + out = pto.vcadd(out, all_mask) + out = pto.vcmax(out, all_mask) + out = pto.vcmin(out, all_mask) + out = pto.vmov(out, all_mask) + out = pto.vtrc(out, all_mask) + out = pto.vprelu(out, vec1, all_mask) + out = pto.vlrelu(out, alpha, all_mask) + out = pto.vcvt(out, pto.f32, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vln", text) + self.assertIn("pto.vsqrt", text) + self.assertIn("pto.vrec", text) + self.assertIn("pto.vrsqrt", text) + self.assertIn("pto.vexpdif", text) + self.assertIn("pto.vcadd", text) + self.assertIn("pto.vcmax", text) + self.assertIn("pto.vcmin", text) + self.assertIn("pto.vmov", text) + self.assertIn("pto.vtrc", text) + self.assertIn("pto.vprelu", text) + self.assertIn("pto.vlrelu", text) + self.assertIn("pto.vcvt", text) + + def test_vexpdif_f16_surface_lowers_to_f32_half_lanes(self) -> None: + @pto.vkernel( + op="vexpdif_f16_surface_unique", + dtypes=[(pto.f32, pto.f16, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, max_src: pto.Tile): + vec = pto.vlds(src, 0) + max_vec = pto.vlds(max_src, 0) + mask = pto.make_mask(pto.f16, pto.PAT.ALL) + out = pto.vexpdif(vec, max_vec, mask, pto.VcvtPartMode.ODD) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + max_src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex( + text, + r'pto\.vexpdif %\w+_\d+, %\w+_\d+, %\w+_\d+, "ODD" : !pto\.vreg<128xf16>, !pto\.vreg<128xf16>, !pto\.mask -> !pto\.vreg<64xf32>', + ) + + def test_vcvt_supports_keyword_attrs_with_enums(self) -> None: + @pto.vkernel( + op="vcvt_keyword_attrs_unique", + dtypes=[(pto.f16, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn('pto.vcvt', text) + self.assertIn('rnd = "R"', text) + self.assertIn('sat = "SAT"', text) + self.assertIn('part = "ODD"', text) + self.assertRegex( + text, + r"= pto\.vcvt %[^,\s]+, %[^,\s]+(?: \{[^}]+\})? : !pto\.vreg<[^>]+>, !pto\.mask -> !pto\.vreg<[^>]+>", + ) + + def test_vcvt_supports_part_t_modes_with_enum(self) -> None: + @pto.vkernel( + op="vcvt_part_t_enum_unique", + dtypes=[(pto.i8, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.i8, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.P0, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('rnd = "R"', text) + self.assertIn('sat = "SAT"', text) + self.assertIn('part = "P0"', text) + + def test_vcvt_supports_part_t_modes_with_canonical_string(self) -> None: + @pto.vkernel( + op="vcvt_part_t_string_unique", + dtypes=[(pto.i8, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.i8, + src_mask, + rnd="R", + sat="SAT", + part="P3", + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('part = "P3"', text) + + def test_vcvt_i32_to_i64_reuses_b32_mask_and_emits_i64_vreg(self) -> None: + @pto.vkernel( + op="vcvt_i32_to_i64_unique", + dtypes=[(pto.i64, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i64, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B32) + out = pto.vcvt( + vec, + pto.i64, + src_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 32), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + store_stmt = next(stmt for stmt in _walk_semantic_stmts(semantic_kernel.body) if isinstance(stmt, SemanticVectorStoreStmt)) + self.assertIsInstance(store_stmt.mask.type, SemanticMaskType) + self.assertEqual(store_stmt.mask.type.granularity, "b32") + + text = specialized.mlir_text() + self.assertIn("!pto.mask", text) + self.assertIn('dist = "UNPK_B32"', text) + self.assertRegex(text, r"!pto\.vreg<32xi64>") + self.assertIn('part = "EVEN"', text) + self.assertIn("pto.vsts", text) + + def test_vlds_dist_requires_vload_dist_enum(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vlds_dist_requires_enum_unique", + dtypes=[(pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist="UNPK_B32") + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("VLoadDist enum", str(ctx.exception)) + + def test_vsts_dist_requires_vstore_dist_enum(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vsts_dist_requires_enum_unique", + dtypes=[(pto.ui8, pto.ui8)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask, dist="NORM_B8") + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("VStoreDist enum", str(ctx.exception)) + + def test_vtrc_defaults_to_round_nearest(self) -> None: + @pto.vkernel( + op="vtrc_default_rnd_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vtrc(vec, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vtrc", text) + self.assertIn(', "R" :', text) + self.assertRegex( + text, + r"= pto\.vtrc %[^,\s]+, %[^,\s]+, \"R\" : !pto\.vreg<[^>]+>, !pto\.mask<[^>]+> -> !pto\.vreg<[^>]+>", + ) + + def test_vtrc_supports_keyword_rnd_with_enums(self) -> None: + @pto.vkernel( + op="vtrc_keyword_rnd_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vtrc(vec, all_mask, rnd=pto.VcvtRoundMode.F) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vtrc", text) + self.assertIn(', "F" :', text) + + def test_vtrc_rejects_round_mode_o(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="vtrc_round_mode_o_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vtrc(vec, all_mask, rnd=pto.VcvtRoundMode.O) + pto.vsts(out, dst, 0, all_mask) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ).mlir_text() + + self.assertIn("pto.vtrc rnd must be one of", str(ctx.exception)) + + def test_advanced_sort_memory_ops_surface_lower(self) -> None: + @pto.vkernel( + op="advanced_sort_memory_ops_unique", + dtypes=[(pto.f32, pto.f32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, idx: pto.Tile): + dst_ptr = dst.as_ptr() + src_ptr = src.as_ptr() + idx_ptr = idx.as_ptr() + + pto.vbitsort(dst_ptr, src_ptr, idx_ptr, 1) + pto.vmrgsort4( + dst_ptr, + src_ptr, + pto.addptr(src_ptr, 64), + pto.addptr(src_ptr, 128), + pto.addptr(src_ptr, 192), + pto.i64(64), + pto.i64(0), + ) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + idx=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"pto\.vbitsort %dst_ptr_\d+, %src_ptr_\d+, %idx_ptr_\d+, %c1 : !pto\.ptr, !pto\.ptr, !pto\.ptr, index", + ) + self.assertRegex( + text, + r"pto\.vmrgsort4 %dst_ptr_\d+, %src_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %c\d+_i64, %c\d+_i64 : " + r"!pto\.ptr, !pto\.ptr, !pto\.ptr, !pto\.ptr, !pto\.ptr, i64, i64", + ) + + def test_vbitsort_helper_lowers_without_frontend_vecscope(self) -> None: + @pto.vkernel( + op="vbitsort_no_frontend_vecscope_unique", + dtypes=[(pto.f32, pto.f32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, idx: pto.Tile): + dst_ptr = dst.as_ptr() + src_ptr = src.as_ptr() + idx_ptr = idx.as_ptr() + + pto.vbitsort(dst_ptr, src_ptr, idx_ptr, 1) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + idx=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(vecscope_stmts, []) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"pto\.vbitsort %dst_ptr_\d+, %src_ptr_\d+, %idx_ptr_\d+, %c1 : !pto\.ptr, !pto\.ptr, !pto\.ptr, index", + ) + self.assertNotIn("pto.vecscope {", text) + + def test_vcvt_rejects_legacy_string_spellings(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_keyword_attrs_legacy_unique", + dtypes=[(pto.f16, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd="ROUND_R", + sat="RS_ENABLE", + part="PART_EVEN", + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("pto.vcvt rnd must be a VcvtRoundMode enum", str(ctx.exception)) + + def test_vcvt_requires_explicit_required_attrs_for_type_pair(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_missing_required_attrs_unique", + dtypes=[(pto.i32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt(vec, pto.i32, src_mask, rnd=pto.VcvtRoundMode.R) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `sat=`", str(ctx.exception)) + + def test_vcvt_rejects_disallowed_attrs_for_type_pair(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_disallowed_attr_unique", + dtypes=[(pto.i32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.ODD, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("does not accept `part=`", str(ctx.exception)) + + def test_vcvt_f16_to_i32_requires_rnd_and_part(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_f16_to_i32_missing_rnd_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) + out = pto.vcvt( + vec, + pto.i32, + src_mask, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `rnd=`", str(ctx.exception)) + + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_f16_to_i32_missing_part_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `part=`", str(ctx.exception)) + + def test_vcvt_f16_to_i32_rejects_sat(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_f16_to_i32_sat_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("does not accept `sat=`", str(ctx.exception)) + + def test_vcvt_f16_to_i32_accepts_rnd_and_part(self) -> None: + @pto.vkernel( + op="vcvt_f16_to_i32_attrs_unique", + dtypes=[(pto.i32, pto.f16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0, dist=pto.VLoadDist.UNPK_B16) + out = pto.vcvt( + vec, + pto.i32, + src_mask, + rnd=pto.VcvtRoundMode.R, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('rnd = "R"', text) + self.assertIn('part = "EVEN"', text) + self.assertNotIn('sat = "SAT"', text) + + def test_vcvt_bf16_to_f16_requires_rnd_and_sat(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_bf16_to_f16_missing_rnd_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + sat=pto.VcvtSatMode.SAT, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `rnd=`", str(ctx.exception)) + + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_bf16_to_f16_missing_sat_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("requires explicit `sat=`", str(ctx.exception)) + + def test_vcvt_bf16_to_f16_rejects_part(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vcvt_bf16_to_f16_part_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + part=pto.VcvtPartMode.EVEN, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("does not accept `part=`", str(ctx.exception)) + + def test_vcvt_bf16_to_f16_accepts_rnd_and_sat(self) -> None: + @pto.vkernel( + op="vcvt_bf16_to_f16_attrs_unique", + dtypes=[(pto.f16, pto.bf16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + dst_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vcvt( + vec, + pto.f16, + src_mask, + rnd=pto.VcvtRoundMode.R, + sat=pto.VcvtSatMode.SAT, + ) + pto.vsts(out, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcvt", text) + self.assertIn('rnd = "R"', text) + self.assertIn('sat = "SAT"', text) + self.assertNotIn('part = "EVEN"', text) + + def test_vbitcast_supports_direct_interface(self) -> None: + @pto.vkernel( + op="vbitcast_direct_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + # Load float vector + fvec = pto.vlds(src, 0) # !pto.vreg<64xf32> + # Convert to integer via vbitcast + ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> + # Convert back to float + fvec2 = pto.vbitcast(ivec, pto.f32) # !pto.vreg<64xf32> + # Store result + pto.vsts(fvec2, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbitcast", text) + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<64xf32> -> !pto\.vreg<64xi32>") + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<64xi32> -> !pto\.vreg<64xf32>") + + def test_vbitcast_supports_astype_syntax_sugar(self) -> None: + @pto.vkernel( + op="vbitcast_astype_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + # Load float vector + fvec = pto.vlds(src, 0) # !pto.vreg<64xf32> + # Convert to integer via astype syntax sugar + ivec = fvec.astype(pto.i32) # !pto.vreg<64xi32> + # Convert back to float + fvec2 = ivec.astype(pto.f32) # !pto.vreg<64xf32> + # Store result + pto.vsts(fvec2, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbitcast", text) + # astype calls should be lowered to vbitcast + count = text.count("pto.vbitcast") + self.assertGreaterEqual(count, 2) + + def test_vbitcast_rejects_non_vreg_input(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="vbitcast_non_vreg_input_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + # Try to vbitcast a non-vector value + scalar = pto.f32(1.0) + ivec = pto.vbitcast(scalar, pto.i32) + pto.vsts(ivec, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("vector register value", str(ctx.exception)) + + def test_astype_requires_vector_register(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="astype_non_vreg_input_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + # Try to call astype on a non-vector value + scalar = pto.f32(1.0) + ivec = scalar.astype(pto.i32) + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vsts(ivec, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("vector register or mask value", str(ctx.exception)) + + def test_vbitcast_supports_element_size_change(self) -> None: + @pto.vkernel( + op="vbitcast_element_size_change_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + f32_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + f16_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + # Load f32 vector (64 elements) + f32_vec = pto.vlds(src, 0) # !pto.vreg<64xf32> + # Convert to f16 (128 elements) + f16_vec = pto.vbitcast(f32_vec, pto.f16) # !pto.vreg<128xf16> + # Convert back to f32 + f32_vec2 = pto.vbitcast(f16_vec, pto.f32) # !pto.vreg<64xf32> + # Store result + pto.vsts(f32_vec2, dst, 0, f32_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbitcast", text) + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<64xf32> -> !pto\.vreg<128xf16>") + self.assertRegex(text, r"= pto\.vbitcast %[^:]+ : !pto\.vreg<128xf16> -> !pto\.vreg<64xf32>") + + def test_pbitcast_supports_direct_interface(self) -> None: + @pto.vkernel( + op="pbitcast_direct_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = pto.pbitcast(src_mask, pto.mask_b32) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.pbitcast", text) + self.assertRegex(text, r'%src_mask_\d+ = pto\.pset_b16 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r'= pto\.pbitcast %[^:]+ : !pto\.mask -> !pto\.mask') + + def test_pbitcast_rejects_non_mask_input(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="pbitcast_non_mask_input_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + vec = pto.vlds(src, 0) + mask = pto.pbitcast(vec, pto.mask_b32) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("mask value", str(ctx.exception)) + + def test_mask_astype_lowers_to_pbitcast(self) -> None: + @pto.vkernel( + op="mask_astype_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + src_mask = pto.make_mask(pto.f16, pto.PAT.ALL) + dst_mask = src_mask.astype(pto.mask_b32) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, dst_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.pbitcast", text) + self.assertRegex(text, r'%src_mask_\d+ = pto\.pset_b16 "PAT_ALL" : !pto\.mask') + self.assertRegex(text, r'= pto\.pbitcast %[^:]+ : !pto\.mask -> !pto\.mask') + + def test_astype_rejects_non_vreg_or_mask_receiver(self) -> None: + with self.assertRaises(TypeError) as ctx: + @pto.vkernel( + op="astype_invalid_receiver_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + scalar = pto.f32(1.0) + mask = scalar.astype(pto.mask_b32) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("vector register or mask value", str(ctx.exception)) + + def test_index_to_float_scalar_cast_lowers_via_integer_bridge(self) -> None: + @pto.vkernel( + op="index_to_float_scalar_cast_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask, _ = pto.make_mask(pto.f32, 1) + vec = pto.vlds(src, 0) + for col in range(0, 1, 1): + scalar = pto.f32(col) + out = pto.vadds(vec, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("arith.index_castui", text) + self.assertRegex(text, r"arith\.uitofp %\w+ : i64 to f32") + self.assertNotRegex(text, r"arith\.uitofp %\w+ : index to f32") + + def test_extended_integer_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="extended_integer_vector_ops_unique", + dtypes=[(pto.i32, pto.i32, pto.i16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i16): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + + out = pto.vbcnt(vec0, all_mask) + out = pto.vneg(out, all_mask) + out = pto.vcls(out, all_mask) + pto.vsunpack(vec0, 0) + pto.vzunpack(vec0.astype(pto.ui32), 0) + pto.vusqz(vec0.astype(pto.ui32), pto.make_mask(pto.ui32, pto.PAT.ALL)) + pto.vsqz(vec0, all_mask) + out = pto.vshl(out, vec1, all_mask) + out = pto.vshr(out, vec1, all_mask) + out = pto.vshls(out, shift, all_mask) + out = pto.vshrs(out, shift, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbcnt", text) + self.assertIn("pto.vneg", text) + self.assertIn("pto.vcls", text) + self.assertIn("pto.vsunpack", text) + self.assertIn("pto.vzunpack", text) + self.assertIn("pto.vusqz", text) + self.assertIn("pto.vsqz", text) + self.assertIn("pto.vshl", text) + self.assertIn("pto.vshr", text) + self.assertIn("pto.vshls", text) + self.assertIn("pto.vshrs", text) + + def test_fused_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="fused_vector_ops_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + vec2 = pto.vlds(src, 128) + vec3 = pto.vlds(src, 192) + + out = pto.vaddrelu(vec0, vec1, all_mask) + out = pto.vaddreluconv(out, vec2, all_mask) + out = pto.vsubrelu(out, vec3, all_mask) + out = pto.vmulconv(out, vec1, all_mask) + out = pto.vaxpy(vec1, out, vec2, all_mask) + out = pto.vmula(vec1, vec2, out, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 256), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vaddrelu", text) + self.assertIn("pto.vaddreluconv", text) + self.assertIn("pto.vsubrelu", text) + self.assertIn("pto.vmulconv", text) + self.assertIn("pto.vaxpy", text) + self.assertIn("pto.vmula", text) + + def test_vmull_and_vector_scalar_bitwise_surface_lowers(self) -> None: + @pto.vkernel( + op="vmull_and_scalar_bitwise_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, scalar: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + + low, high = pto.vmull(vec0, vec1, all_mask) + out = pto.vadd(low, high, all_mask) + out = pto.vands(out, scalar, all_mask) + out = pto.vors(out, scalar, all_mask) + out = pto.vxors(out, scalar, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vmull", text) + self.assertIn("pto.vands", text) + self.assertIn("pto.vors", text) + self.assertIn("pto.vxors", text) + + def test_vci_typed_integer_inputs_lower_without_typed_arith(self) -> None: + @pto.vkernel( + op="vci_typed_integer_inputs_unique", + dtypes=[(pto.ui16, pto.si16, pto.i32)], + advanced=True, + ) + def kernel(dst_u: pto.Tile, dst_s: pto.Tile, seed: pto.i32): + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + + unsigned_idx = pto.vci(pto.ui16(0)) + signed_idx = pto.vci(pto.si16(seed), pto.OrderMode.ASC) + + pto.vsts(unsigned_idx, dst_u, 0, unsigned_mask) + pto.vsts(signed_idx, dst_s, 0, signed_mask) + return None + + specialized = kernel.specialize( + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vci", text) + self.assertIn(": i16 to ui16", text) + self.assertIn(": i16 to si16", text) + self.assertNotIn("arith.constant 0 : ui16", text) + self.assertNotRegex(text, r"arith\.(extsi|extui|trunci|bitcast) %\w+ : .* to (ui16|si16)") + + def test_vector_scalar_bitwise_typed_scalar_inputs_lower_without_typed_arith(self) -> None: + @pto.vkernel( + op="vector_scalar_bitwise_typed_scalar_inputs_unique", + dtypes=[(pto.ui16, pto.ui16, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): + mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + scalar = pto.ui16(seed) + out = pto.vands(vec, scalar, mask) + out = pto.vors(out, scalar, mask) + out = pto.vxors(out, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vands", text) + self.assertIn("pto.vors", text) + self.assertIn("pto.vxors", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotRegex(text, r"arith\.trunci %\w+ : i32 to ui16") + + def test_broadcast_and_index_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="broadcast_and_index_vector_ops_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + + broadcast = pto.vbr(seed) + dup_from_vec = pto.vdup(vec0, all_mask, pto.PositionMode.HIGHEST) + dup_from_scalar = pto.vdup(seed, all_mask) + idx0 = pto.vci(seed) + idx1 = pto.vci(seed, pto.OrderMode.ASC) + + out = pto.vadd(broadcast, dup_from_vec, all_mask) + out = pto.vadd(out, dup_from_scalar, all_mask) + out = pto.vadd(out, idx0, all_mask) + out = pto.vadd(out, idx1, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbr", text) + self.assertIn("pto.vdup", text) + self.assertIn("pto.vci", text) + self.assertRegex( + text, + r'pto\.vdup\s+%[^\s]+,\s+%[^\s]+\s+\{position = "HIGHEST"\}\s+:', + ) + self.assertRegex( + text, + r'pto\.vdup\s+%[^\s]+,\s+%[^\s]+\s+:', + ) + self.assertNotIn('position = "LOWEST"', text) + self.assertNotIn('position = "POS_LOWEST"', text) + self.assertRegex( + text, + r'pto\.vci\s+%[^\s]+\s+\{order = "ASC"\}\s+:', + ) + self.assertNotRegex( + text, + r'pto\.vci\s+%[^\s]+,\s*"ASC"\s+:', + ) + + def test_vci_desc_lowers_to_desc_order_attr(self) -> None: + @pto.vkernel( + op="vci_desc_order_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, seed: pto.i32): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + indices = pto.vci(seed, pto.OrderMode.DESC) + vec = pto.vlds(src, 0) + out = pto.vadd(vec, indices, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex( + text, + r'pto\.vci\s+%[^\s]+\s+\{order = "DESC"\}\s+:', + ) + + def test_vdup_scalar_input_rejects_position_argument(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="vdup_scalar_reject_position_unique", + dtypes=[(pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, seed: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + out = pto.vdup(seed, all_mask, pto.PositionMode.HIGHEST) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("pto.vdup scalar input does not accept `position`", str(ctx.exception)) + + def test_vbr_and_vdup_accept_narrow_typed_scalar_constructors_with_explicit_bridges(self) -> None: + @pto.vkernel( + op="narrow_typed_vbr_vdup_scalar_constructors_unique", + dtypes=[(pto.si16, pto.ui16)], + advanced=True, + ) + def kernel(dst_s: pto.Tile, dst_u: pto.Tile): + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + + signed = pto.vadd( + pto.vbr(pto.si16(0)), + pto.vdup(pto.si16(0), signed_mask), + signed_mask, + ) + unsigned = pto.vadd( + pto.vbr(pto.ui16(0)), + pto.vdup(pto.ui16(0), unsigned_mask), + unsigned_mask, + ) + + pto.vsts(signed, dst_s, 0, signed_mask) + pto.vsts(unsigned, dst_u, 0, unsigned_mask) + return None + + specialized = kernel.specialize( + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vbr", text) + self.assertIn("pto.vdup", text) + self.assertIn("arith.constant 0 : i16", text) + self.assertIn(": i16 to si16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotIn("arith.constant 0 : si16", text) + self.assertNotIn("arith.constant 0 : ui16", text) + + def test_signed_and_unsigned_integer_dtypes_lower_distinctly(self) -> None: + @pto.vkernel( + op="signed_unsigned_integer_types_unique", + dtypes=[(pto.si16, pto.si16, pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst_s: pto.Tile, src_s: pto.Tile, dst_u: pto.Tile, src_u: pto.Tile): + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + signed_vec = pto.vlds(src_s, 0) + unsigned_vec = pto.vlds(src_u, 0) + signed_out = pto.vadds(signed_vec, pto.si16(-1), signed_mask) + unsigned_out = pto.vadds(unsigned_vec, pto.ui16(1), unsigned_mask) + pto.vsts(signed_out, dst_s, 0, signed_mask) + pto.vsts(unsigned_out, dst_u, 0, unsigned_mask) + return None + + specialized = kernel.specialize( + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("dtype=si16", text) + self.assertIn("dtype=ui16", text) + self.assertIn("!pto.vreg<128xsi16>", text) + self.assertIn("!pto.vreg<128xui16>", text) + + def test_vcmps_literal_scalar_uses_signless_integer_bridge(self) -> None: + @pto.vkernel( + op="vcmps_literal_scalar_bridge_unique", + dtypes=[(pto.si16, pto.si16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + cmp_mask = pto.vcmps(vec, pto.si16(-1), all_mask, pto.CmpMode.GT) + selected = pto.vsel(vec, vec, cmp_mask) + pto.vsts(selected, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcmps", text) + self.assertIn("arith.constant -1 : i16", text) + self.assertIn(": i16 to si16", text) + self.assertNotIn("arith.constant -1 : si16", text) + + def test_vadds_index_constructor_scalar_uses_signless_integer_bridge(self) -> None: + @pto.vkernel( + op="vadds_index_constructor_bridge_unique", + dtypes=[(pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + cols = dst.valid_shape[1] + vec = pto.vlds(src, 0) + mask, _ = pto.make_mask(pto.ui16, 1) + for col in range(0, cols, 1): + scalar = pto.ui16(col) + out = pto.vadds(vec, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + valid_shape=(8, 1), + memory_space=pto.MemorySpace.UB, + ), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vadds", text) + self.assertIn("arith.index_castui", text) + self.assertIn(": index to i32", text) + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotIn(": index to ui16", text) + + def test_vshrs_cast_result_scalar_uses_signless_integer_bridge(self) -> None: + @pto.vkernel( + op="vshrs_cast_result_scalar_bridge_unique", + dtypes=[(pto.i32, pto.i32, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, shift_seed: pto.ui16): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + shift = pto.i16(shift_seed) + out = pto.vshrs(vec, shift, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vshrs", text) + self.assertIn(": ui16 to i16", text) + self.assertNotRegex(text, r"arith\.bitcast %\w+ : ui16 to i16") + self.assertNotRegex(text, r"arith\.trunci %\w+ : ui16 to i16") + + def test_vbr_accepts_float_literal_constant(self) -> None: + @pto.vkernel( + op="broadcast_float_literal_constant_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + bias = pto.vbr(0.0) + out = pto.vadd(vec0, bias, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("= arith.constant 0.0 : f32", text) + self.assertIn("pto.vbr", text) + + def test_kernel_accepts_module_level_literal_constant_reference(self) -> None: + @pto.vkernel( + op="module_level_literal_constant_reference_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + all_mask, _ = pto.make_mask(pto.f32, GLOBAL_TILELANG_LITERAL_BLOCK_SIZE) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, pto.f32(0.0), all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"arith\.constant 32 : (index|i64)") + self.assertIn("pto.plt_b32", text) + + def test_scalar_constructor_call_surfaces_lower(self) -> None: + @pto.vkernel( + op="scalar_constructor_call_surfaces_unique", + dtypes=[(pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + base = pto.i32(1) + idx = pto.i16(base) + idx = pto.i8(idx) + idx = pto.i64(idx) + flt = pto.f16(idx) + flt = pto.bf16(flt) + flt = pto.f32(flt) + gate = pto.i1(flt) + scalar = pto.i32(gate) + vec = pto.vlds(src, 0) + out = pto.vadds(vec, scalar, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("arith.trunci", text) + self.assertIn("arith.extsi", text) + self.assertIn("arith.sitofp", text) + self.assertIn("arith.fptosi", text) + self.assertIn("arith.extf", text) + self.assertIn("arith.truncf", text) + + def test_typed_integer_scalar_coercion_uses_signless_integer_carriers(self) -> None: + @pto.vkernel( + op="typed_integer_scalar_coercion_unique", + dtypes=[(pto.si16, pto.si16, pto.ui16, pto.ui16, pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel( + dst_s: pto.Tile, + src_s: pto.Tile, + dst_u: pto.Tile, + src_u: pto.Tile, + dst_i: pto.Tile, + src_i: pto.Tile, + seed: pto.i32, + ): + signed_mask = pto.make_mask(pto.si16, pto.PAT.ALL) + unsigned_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + scalar_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + + signed_scalar = pto.si16(seed) + unsigned_scalar = pto.ui16(seed) + + signed_vec = pto.vlds(src_s, 0) + unsigned_vec = pto.vlds(src_u, 0) + scalar_vec = pto.vlds(src_i, 0) + + signed_out = pto.vadds(signed_vec, signed_scalar, signed_mask) + unsigned_out = pto.vadds(unsigned_vec, unsigned_scalar, unsigned_mask) + scalar_out = pto.vadds(scalar_vec, pto.i32(signed_scalar), scalar_mask) + scalar_out = pto.vadds(scalar_out, pto.i32(unsigned_scalar), scalar_mask) + + pto.vsts(signed_out, dst_s, 0, signed_mask) + pto.vsts(unsigned_out, dst_u, 0, unsigned_mask) + pto.vsts(scalar_out, dst_i, 0, scalar_mask) + return None + + specialized = kernel.specialize( + dst_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_s=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src_u=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst_i=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src_i=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("arith.trunci", text) + self.assertIn(": i32 to i16", text) + self.assertIn(": i16 to si16", text) + self.assertIn(": i16 to ui16", text) + self.assertIn(": si16 to i16", text) + self.assertIn(": ui16 to i16", text) + self.assertIn("arith.extsi", text) + self.assertIn("arith.extui", text) + self.assertNotRegex(text, r"arith\.trunci %\w+ : i32 to (si16|ui16)") + self.assertNotRegex(text, r"arith\.extsi %\w+ : si16 to i32") + self.assertNotRegex(text, r"arith\.extui %\w+ : ui16 to i32") + + def test_typed_integer_float_scalar_coercion_uses_signless_integer_carriers(self) -> None: + @pto.vkernel( + op="typed_integer_float_scalar_coercion_unique", + dtypes=[(pto.ui16, pto.ui16)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + scalar = pto.ui16(1) + flt = pto.f32(scalar) + back = pto.ui16(flt) + out = pto.vadds(vec, back, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn(": ui16 to i16", text) + self.assertIn("arith.uitofp", text) + self.assertIn(": i16 to f32", text) + self.assertIn("arith.fptoui", text) + self.assertIn(": f32 to i16", text) + self.assertIn(": i16 to ui16", text) + self.assertNotRegex(text, r"arith\.uitofp %\w+ : ui16 to f32") + self.assertNotRegex(text, r"arith\.fptoui %\w+ : f32 to ui16") + + def test_scalar_constructor_accepts_signed_float_literals(self) -> None: + @pto.vkernel(op="scalar_constructor_signed_float_literals_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + a = pto.f16(-1.5) + b = pto.bf16(+2.5) + c = pto.f32(-3.5) + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant -1.5 : f16", text) + self.assertIn("= arith.constant 2.5 : bf16", text) + self.assertIn("= arith.constant -3.5 : f32", text) + + def test_scalar_constructor_accepts_special_float_string_literals(self) -> None: + @pto.vkernel(op="scalar_constructor_special_float_literals_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + a = pto.f16("-inf") + b = pto.bf16("inf") + c = pto.f32("nan") + d = pto.f16("0xFC00") + e = pto.bf16("0xFF80") + f = pto.f32("0xFF800000") + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant 0xFC00 : f16", text) + self.assertIn("= arith.constant 0x7F80 : bf16", text) + self.assertIn("= arith.constant 0x7FC00000 : f32", text) + self.assertIn("= arith.constant 0xFF80 : bf16", text) + self.assertIn("= arith.constant 0xFF800000 : f32", text) + + def test_scalar_constructor_emits_negative_zero_as_stable_bit_pattern(self) -> None: + @pto.vkernel(op="scalar_constructor_negative_zero_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + a = pto.f16(-0.0) + b = pto.bf16(-0.0) + c = pto.f32(-0.0) + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant 0x8000 : f16", text) + self.assertIn("= arith.constant 0x8000 : bf16", text) + self.assertIn("= arith.constant 0x80000000 : f32", text) + + def test_scalar_constructor_rejects_bad_arity(self) -> None: + @pto.vkernel(op="scalar_constructor_bad_arity_no_arg_unique", dtypes=[(pto.f32,)]) + def kernel_no_arg(inp: pto.TensorView): + x = pto.i32() + return None + + with self.assertRaises(TypeError) as no_arg_ctx: + kernel_no_arg.mlir_text() + + self.assertIn("pto.i32 expects exactly 1 positional argument", str(no_arg_ctx.exception)) + + @pto.vkernel(op="scalar_constructor_bad_arity_two_arg_unique", dtypes=[(pto.f32,)]) + def kernel_two_arg(inp: pto.TensorView): + x = pto.f32(1.0, 2.0) + return None + + with self.assertRaises(TypeError) as two_arg_ctx: + kernel_two_arg.mlir_text() + + self.assertIn("pto.f32 expects exactly 1 positional argument", str(two_arg_ctx.exception)) + + def test_scalar_constructor_rejects_non_scalar_operand(self) -> None: + @pto.vkernel(op="scalar_constructor_bad_operand_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i32(inp) + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("pto.i32 value must be a scalar or index value", str(ctx.exception)) + + def test_scalar_constructor_accepts_integer_hex_bit_pattern_strings(self) -> None: + @pto.vkernel(op="scalar_constructor_integer_hex_bit_patterns_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i16("0x7FFF") + y = pto.i32("0x7FFFFFFF") + z = pto.i16("0x8000") + a = pto.i32("0x80000000") + b = pto.ui16("0x8000") + return None + + text = kernel.mlir_text() + self.assertIn("= arith.constant 32767 : i16", text) + self.assertIn("= arith.constant 2147483647 : i32", text) + self.assertIn("= arith.constant -32768 : i16", text) + self.assertIn("= arith.constant -2147483648 : i32", text) + self.assertIn("= arith.constant 32768 : i16", text) + + def test_scalar_constructor_rejects_non_hex_integer_string_literals(self) -> None: + @pto.vkernel(op="scalar_constructor_non_hex_integer_strings_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i32("1024") + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("string literals must use hex bit-pattern form", str(ctx.exception)) + + def test_scalar_constructor_rejects_out_of_range_integer_literal(self) -> None: + @pto.vkernel(op="scalar_constructor_oob_int_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i8(1024) + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("out of range for i8", str(ctx.exception)) + + def test_scalar_constructor_rejects_out_of_range_integer_string_literal(self) -> None: + @pto.vkernel(op="scalar_constructor_oob_integer_string_unique", dtypes=[(pto.f32,)]) + def kernel(inp: pto.TensorView): + x = pto.i16("0x10000") + return None + + with self.assertRaises(TypeError) as ctx: + kernel.mlir_text() + + self.assertIn("exceeds 16-bit width for i16", str(ctx.exception)) + + def test_vector_bindings_propagate_through_constexpr_if_without_frontend_vecscope(self) -> None: + @pto.vkernel( + op="vector_binding_constexpr_if_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + acc = pto.vbr(0.0) + vec = pto.vlds(src, 0) + acc = pto.vadd(acc, vec, mask) + if pto.constexpr(True): + pto.vsts(acc, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vadd", text) + self.assertIn("pto.vsts", text) + self.assertIn("= arith.constant 0.0 : f32", text) + + def test_loop_lowering_supports_multiple_loop_carried_bindings(self) -> None: + @pto.vkernel( + op="loop_multi_carried_bindings_unique", + dtypes=[(pto.f32, pto.f32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + remained = 64 + acc = pto.vbr(0.0) + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + for col in range(0, 64, 64): + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(src, col) + acc = pto.vadd(acc, vec, mask) + pto.vsts(acc, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertRegex(text, r"%remained_\d+, %acc_\d+ = scf\.for") + self.assertRegex(text, r"iter_args\(%remained_iter_\d+_0 = [^,]+, %acc_iter_\d+_1 = [^)]+\)") + self.assertRegex(text, r"scf\.yield %remained_\d+, %acc_\d+ : i32, !pto\.vreg<64xf32>") + + def test_reduction_and_rearrangement_vector_ops_surface_lowers(self) -> None: + @pto.vkernel( + op="reduction_and_rearrangement_vector_ops_unique", + dtypes=[(pto.i32, pto.i32, pto.i32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile, shift: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec0 = pto.vlds(src, 0) + vec1 = pto.vlds(src, 64) + packed_mask = pto.make_mask(pto.ui16, pto.PAT.ALL) + + out = pto.vcgadd(vec0, all_mask) + out = pto.vcgmax(out, all_mask) + out = pto.vcgmin(out, all_mask) + out = pto.vcpadd(out, all_mask) + packed0 = pto.vpack(vec0, pto.PredicatePart.LOWER) + packed1 = pto.vpack(vec1, pto.PredicatePart.HIGHER) + indices = pto.vci(pto.i16(shift), pto.OrderMode.ASC) + packed0 = pto.vperm(packed0, indices, packed_mask) + packed0 = pto.vshift(packed0, pto.i16(shift), packed_mask) + packed0 = pto.vslide(packed0, pto.i16(shift), packed_mask) + packed0 = pto.vmrgsort(packed0, packed1, packed_mask) + out = pto.vsort32(out, all_mask) + pto.vsts(out, dst, 0, all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.vcgadd", text) + self.assertIn("pto.vcgmax", text) + self.assertIn("pto.vcgmin", text) + self.assertIn("pto.vcpadd", text) + self.assertIn("pto.vpack", text) + self.assertIn("pto.vperm", text) + self.assertIn("pto.vshift", text) + self.assertIn("pto.vslide", text) + self.assertIn("pto.vsort32", text) + self.assertIn("pto.vmrgsort", text) + + def test_scalar_loop_prologue_lowers_without_frontend_vecscope(self) -> None: + @pto.vkernel(op="tadd_outer_scope_unique", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + summed = pto.vadd(lhs, rhs, mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + outer_loop = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticForStmt)) + self.assertIsInstance(outer_loop, SemanticForStmt) + self.assertIsInstance(outer_loop.body[0], SemanticAssignStmt) + self.assertIsInstance(outer_loop.body[1], SemanticForStmt) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + self.assertRegex(text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + + def test_unused_tile_does_not_hoist_tile_buf_addr_or_valid_shape_intrinsics(self) -> None: + @pto.vkernel(op="tile_usage_scan_unique", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src: pto.Tile, scratch: pto.Tile): + rows, cols = dst.valid_shape + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + for row in range(0, rows, 1): + for col in range(0, cols, pto.get_lanes(dst.element_type)): + value = pto.vlds(src[row, col:]) + pto.vsts(value, dst[row, col:], mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + scratch=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn("pto.tile_buf_addr %arg0", text) + self.assertIn("pto.tile_buf_addr %arg1", text) + self.assertNotIn("pto.tile_buf_addr %arg2", text) + self.assertIn("pto.tile_valid_rows %arg0", text) + self.assertIn("pto.tile_valid_cols %arg0", text) + self.assertNotIn("pto.tile_valid_rows %arg1", text) + self.assertNotIn("pto.tile_valid_cols %arg1", text) + self.assertNotIn("pto.tile_valid_rows %arg2", text) + self.assertNotIn("pto.tile_valid_cols %arg2", text) + + def test_tile_dynamic_valid_shape_profile_lowers_to_runtime_bounds_in_advanced_mode(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel(op="tadd_dynamic_valid_shape_unique", dtypes=[(elem, elem, elem)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + remained = valid_cols + for row in range(0, valid_rows, 1): + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + summed = pto.vadd(pto.vlds(src0[row, col:]), pto.vlds(src1[row, col:]), mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + selected = pto.select_kernel( + "a5", + "tadd_dynamic_valid_shape_unique", + (pto.f16, pto.f16, pto.f16), + ) + specialized = selected.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + src0=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [ + ("dst", "tile"), + ("src0", "tile"), + ("src1", "tile"), + ("__valid_shape_dst_0", "tile_valid_shape"), + ("__valid_shape_dst_1", "tile_valid_shape"), + ], + ) + self.assertEqual(semantic_kernel.tile_bindings[0].valid_shape, (None, None)) + + text = specialized.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.tile_buf, %arg1: !pto.tile_buf, %arg2: !pto.tile_buf) attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {", + text, + ) + self.assertIn("valid_shape=(?, ?)", text) + self.assertNotIn("pto.vecscope {", text) + self.assertIn("step %c128", text) + self.assertIn("pto.tile_valid_rows %arg0", text) + self.assertIn("pto.tile_valid_cols %arg0", text) + self.assertNotIn("pto.tile_valid_rows %arg1", text) + self.assertNotIn("pto.tile_valid_cols %arg1", text) + self.assertNotIn("pto.tile_valid_rows %arg2", text) + self.assertNotIn("pto.tile_valid_cols %arg2", text) + self.assertLess(text.index("pto.tile_valid_rows %arg0"), text.index("scf.for %row_")) + self.assertLess(text.index("pto.tile_valid_cols %arg0"), text.index("scf.for %row_")) + self.assertRegex(text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + self.assertRegex(text, r"scf\.for %col_\d+ = %c0 to %valid_cols_\d+ step %c128") + self.assertRegex(text, r"%tmp_\d+ = arith\.index_cast %valid_cols_\d+ : index to i32") + self.assertRegex( + text, + r"pto\.tile_buf_addr %arg1 : !pto\.tile_buf> to memref<\?x\?xf16, strided<\[\?, \?\], offset: \?>, #pto\.address_space>", + ) + self.assertRegex( + text, + r"pto\.vlds %tmp_\d+\[%c0\] : memref<\?x\?xf16, strided<\[\?, \?\], offset: \?>, #pto\.address_space> -> !pto\.vreg<128xf16>", + ) + self.assertRegex( + text, + r"pto\.vsts %summed_\d+, %tmp_\d+\[%c0\], %mask_\d+ : !pto\.vreg<128xf16>, memref<\?x\?xf16, strided<\[\?, \?\], offset: \?>, #pto\.address_space>, !pto\.mask", + ) + + def test_tile_valid_shape_subscript_profile_lowers_to_runtime_bounds_in_advanced_mode(self) -> None: + @pto.vkernel(op="tile_valid_shape_subscript_unique", dtypes=[(pto.f16,)], advanced=True) + def kernel(dst: pto.Tile): + valid_rows = dst.valid_shape[0] + valid_cols = dst.valid_shape[1] + area = valid_rows * valid_cols + if area == 0: + area = 1 + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [ + ("dst", "tile"), + ("__valid_shape_dst_0", "tile_valid_shape"), + ("__valid_shape_dst_1", "tile_valid_shape"), + ], + ) + valid_rows_assign = semantic_kernel.body[0] + valid_cols_assign = semantic_kernel.body[1] + self.assertIsInstance(valid_rows_assign, SemanticAssignStmt) + self.assertIsInstance(valid_cols_assign, SemanticAssignStmt) + self.assertIsInstance(valid_rows_assign.targets[0].type, SemanticIndexType) + self.assertIsInstance(valid_cols_assign.targets[0].type, SemanticIndexType) + + text = specialized.mlir_text() + self.assertIn("valid_shape=(?, ?)", text) + self.assertRegex(text, r"%valid_rows_\d+ = pto\.tile_valid_rows %arg0") + self.assertRegex(text, r"%valid_cols_\d+ = pto\.tile_valid_cols %arg0") + + def test_tile_partial_dynamic_valid_shape_profile_tracks_dynamic_axes_only(self) -> None: + elem = pto.TypeVar("Elem") + + @pto.vkernel(op="tadd_partial_dynamic_valid_shape_unique", dtypes=[(elem, elem, elem)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + remained = valid_cols + for row in range(0, valid_rows, 1): + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + summed = pto.vadd(pto.vlds(src0[row, col:]), pto.vlds(src1[row, col:]), mask) + pto.vsts(summed, dst[row, col:], mask) + return None + + selected = pto.select_kernel( + "a5", + "tadd_partial_dynamic_valid_shape_unique", + (pto.f16, pto.f16, pto.f16), + ) + + rows_dynamic = selected.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", 128), + ), + src0=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + rows_dynamic_semantic = analyze_frontend_kernel(build_frontend_kernel_node(rows_dynamic)) + self.assertEqual( + [(param.name, param.kind) for param in rows_dynamic_semantic.parameters], + [ + ("dst", "tile"), + ("src0", "tile"), + ("src1", "tile"), + ("__valid_shape_dst_0", "tile_valid_shape"), + ], + ) + rows_dynamic_text = rows_dynamic.mlir_text() + self.assertIn("valid_shape=(?, 128)", rows_dynamic_text) + self.assertIn("pto.tile_valid_rows %arg0", rows_dynamic_text) + self.assertIn("pto.tile_valid_cols %arg0", rows_dynamic_text) + self.assertRegex(rows_dynamic_text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + self.assertRegex(rows_dynamic_text, r"scf\.for %col_\d+ = %c0 to %valid_cols_\d+ step %c128") + + cols_dynamic = selected.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=(8, "valid_cols"), + ), + src0=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + cols_dynamic_semantic = analyze_frontend_kernel(build_frontend_kernel_node(cols_dynamic)) + self.assertEqual( + [(param.name, param.kind) for param in cols_dynamic_semantic.parameters], + [ + ("dst", "tile"), + ("src0", "tile"), + ("src1", "tile"), + ("__valid_shape_dst_1", "tile_valid_shape"), + ], + ) + cols_dynamic_text = cols_dynamic.mlir_text() + self.assertIn("valid_shape=(8, ?)", cols_dynamic_text) + self.assertIn("pto.tile_valid_rows %arg0", cols_dynamic_text) + self.assertIn("pto.tile_valid_cols %arg0", cols_dynamic_text) + self.assertRegex(cols_dynamic_text, r"scf\.for %row_\d+ = %c0 to %valid_rows_\d+ step %c1") + self.assertRegex(cols_dynamic_text, r"scf\.for %col_\d+ = %c0 to %valid_cols_\d+ step %c128") + + def test_advanced_mode_scalar_assignments_lowers_without_frontend_vecscope(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + dtype = src.element_type + first_mask = pto.make_mask(dtype, pto.PAT.ALL) + first = pto.vlds(src[0, 0:]) + pto.vsts(first, dst[0, 0:], first_mask) + boundary = 1 + second_mask = pto.make_mask(dtype, pto.PAT.ALL) + second = pto.vlds(src[1, 0:]) + pto.vsts(second, dst[1, 0:], second_mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + boundary_index = text.index("%boundary_") + first_vsts = text.index("pto.vsts") + second_vsts = text.rindex("pto.vsts") + self.assertLess(first_vsts, boundary_index) + self.assertLess(boundary_index, second_vsts) + self.assertLess(boundary_index, text.index("return")) + + def test_explicit_vecscope_is_supported_in_stable_mode(self) -> None: + @pto.vkernel(op="explicit_vecscope_stable_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + with pto.vecscope(): + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertIsInstance(frontend_kernel.body[1], FrontendVecscopeStmt) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertIn("pto.vlds", text) + self.assertIn("pto.vsts", text) + + def test_explicit_vecscope_does_not_trigger_additional_frontend_inference(self) -> None: + @pto.vkernel(op="explicit_vecscope_disables_infer_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + with pto.vecscope(): + first = pto.vlds(src, 0) + pto.vsts(first, dst, 0, mask) + second = pto.vlds(src, 64) + pto.vsts(second, dst, 64, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 1) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.vecscope {"), 1) + self.assertIn("pto.vlds", text) + self.assertIn("pto.vsts", text) + + def test_constexpr_if_tail_store_lowers_without_frontend_vecscope(self) -> None: + @pto.vkernel(op="trowsum_like_vecscope_unique", dtypes=[(pto.f32, pto.f32, pto.f32)], advanced=True) + def kernel(dst: pto.Tile, src: pto.Tile, tmp: pto.Tile): + src_dtype = src.element_type + valid_rows, valid_cols = src.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + acc = pto.vbr(0.0) + for col in range(0, valid_cols, pto.get_lanes(src_dtype)): + mask, remained = pto.make_mask(src_dtype, remained) + vec = pto.vlds(src[row, col:]) + reduced = pto.vcadd(vec, mask) + one_mask, _ = pto.make_mask(src_dtype, 1) + acc = pto.vadd(acc, reduced, one_mask) + out_mask, _ = pto.make_mask(src_dtype, 1) + if pto.constexpr(src_dtype != dst.element_type): + casted = pto.vcvt(acc, dst.element_type, out_mask) + pto.vsts(casted, dst[row, 0:], out_mask) + else: + pto.vsts(acc, dst[row, 0:], out_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + tmp=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + self.assertRegex(text, r"scf\.for %row_\d+") + self.assertIn("pto.vsts", text) + + def test_advanced_mode_control_flow_lowers_without_frontend_vecscope_per_branch(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile, flag: pto.i32): + dtype = src.element_type + all_mask = pto.make_mask(dtype, pto.PAT.ALL) + if flag: + first = pto.vlds(src[0, 0:]) + pto.vsts(first, dst[0, 0:], all_mask) + else: + second = pto.vlds(src[1, 0:]) + pto.vsts(second, dst[1, 0:], all_mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual([type(stmt).__name__ for stmt in semantic_kernel.body[:-1]], [ + "SemanticAssignStmt", + "SemanticAssignStmt", + "SemanticIfStmt", + ]) + if_stmt = semantic_kernel.body[2] + self.assertIsInstance(if_stmt, SemanticIfStmt) + self.assertEqual(len(if_stmt.then_body), 2) + self.assertEqual(len(if_stmt.else_body), 2) + self.assertFalse(any(isinstance(stmt, SemanticVecscopeStmt) for stmt in if_stmt.then_body)) + self.assertFalse(any(isinstance(stmt, SemanticVecscopeStmt) for stmt in if_stmt.else_body)) + + text = specialized.mlir_text() + self.assertIn("scf.if", text) + self.assertNotIn("pto.vecscope {", text) + self.assertLess(text.index("scf.if"), text.index("return")) + + def test_advanced_mode_keeps_strict_vecscope_as_hard_boundary(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + all_mask = pto.make_mask(pto.f32, pto.PAT.ALL) + rows = src.shape[0] + for row in range(0, rows, 1): + vec = pto.vlds(src[row, 0:]) + pto.vsts(vec, dst[row, 0:], all_mask) + with pto.strict_vecscope(src, dst, all_mask, 0, 64, 64) as (vin, vout, mask, lb, ub, step): + for lane in range(lb, ub, step): + scoped = pto.vlds(vin, lane) + pto.vsts(scoped, vout, lane, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + self.assertEqual(text.count("pto.strict_vecscope("), 1) + + def test_advanced_mode_lowers_raw_pointer_and_low_level_dma_surface(self) -> None: + @pto.vkernel(op="ptr_dma", dtypes=[(pto.f32, pto.f32, pto.i64)], advanced=True) + def kernel( + src_gm: pto.ptr(pto.f32, pto.MemorySpace.GM), + dst_gm: pto.ptr(pto.f32, pto.MemorySpace.GM), + addr: pto.i64, + ): + ub_src = pto.castptr(addr, pto.ptr(pto.f32, pto.MemorySpace.UB)) + ub_dst = pto.addptr(ub_src, 64) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(ub_src, 0) + pto.vsts(vec, ub_dst, 0, mask) + + src_bytes = pto.castptr(src_gm, pto.ptr(pto.i8, pto.MemorySpace.GM)) + dst_bytes = pto.castptr(dst_gm, pto.ptr(pto.i8, pto.MemorySpace.GM)) + src_offset = pto.addptr(src_bytes, 0) + dst_offset = pto.addptr(dst_bytes, 0) + typed_src = pto.castptr(src_offset, pto.ptr(pto.f32, pto.MemorySpace.GM)) + typed_dst = pto.castptr(dst_offset, pto.ptr(pto.f32, pto.MemorySpace.GM)) + + pto.set_loop2_stride_outtoub(4096, 4096) + pto.set_loop1_stride_outtoub(4096, 4096) + pto.set_loop_size_outtoub(1, 1) + pto.copy_gm_to_ubuf(typed_src, ub_src, 0, 32, 128, 0, 0, False, 0, 128, 128) + + pto.set_loop2_stride_ubtoout(4096, 4096) + pto.set_loop1_stride_ubtoout(4096, 4096) + pto.set_loop_size_ubtoout(1, 1) + pto.copy_ubuf_to_ubuf(ub_src, ub_dst, 0, 32, 128, 128, 128) + pto.copy_ubuf_to_gm(ub_dst, typed_dst, 0, 32, 128, 0, 128, 128) + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.parameters[0].type, SemanticPtrType) + self.assertEqual(semantic_kernel.parameters[0].type.memory_space, "gm") + self.assertIsInstance(semantic_kernel.parameters[1].type, SemanticPtrType) + self.assertEqual(semantic_kernel.parameters[1].type.memory_space, "gm") + self.assertTrue(any(isinstance(stmt, SemanticDmaConfigStmt) for stmt in semantic_kernel.body)) + self.assertTrue(any(isinstance(stmt, SemanticLowLevelCopyStmt) for stmt in semantic_kernel.body)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + + text = kernel.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: i64) attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {", + text, + ) + self.assertRegex( + text, + r"%ub_src_\d+ = pto\.castptr %arg2 : i64 -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%ub_dst_\d+ = pto\.addptr %ub_src_\d+, %c64 : !pto\.ptr -> !pto\.ptr", + ) + self.assertNotIn("pto.vecscope {", text) + self.assertRegex( + text, + r"%vec_\d+ = pto\.vlds %ub_src_\d+\[%c0\] : !pto\.ptr -> !pto\.vreg<64xf32>", + ) + self.assertRegex( + text, + r"pto\.vsts %vec_\d+, %ub_dst_\d+\[%c0\], %mask_\d+ : !pto\.vreg<64xf32>, !pto\.ptr, !pto\.mask", + ) + self.assertRegex( + text, + r"%src_bytes_\d+ = pto\.castptr %arg0 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%dst_bytes_\d+ = pto\.castptr %arg1 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%src_offset_\d+ = pto\.addptr %src_bytes_\d+, %c0 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%dst_offset_\d+ = pto\.addptr %dst_bytes_\d+, %c0 : !pto\.ptr -> !pto\.ptr", + ) + self.assertRegex( + text, + r"pto\.set_loop2_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop1_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop_size_outtoub %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.copy_gm_to_ubuf %typed_src_\d+, %ub_src_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %false, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + self.assertIn( + ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64", + text, + ) + self.assertRegex( + text, + r"pto\.set_loop2_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop1_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.set_loop_size_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64", + ) + self.assertRegex( + text, + r"pto\.copy_ubuf_to_ubuf %ub_src_\d+, %ub_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + self.assertIn( + ": !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64", + text, + ) + self.assertRegex( + text, + r"pto\.copy_ubuf_to_gm %ub_dst_\d+, %typed_dst_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_as_ptr_method_and_keyword_low_level_dma_surface_lower_in_advanced_mode(self) -> None: + @pto.vkernel(op="tensorview_tile_as_ptr_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(inp: pto.TensorView, dst: pto.Tile): + gm_ptr = inp.as_ptr() + ub_ptr = dst.as_ptr() + + pto.set_loop2_stride_outtoub(src_stride=4096, dst_stride=2048) + pto.set_loop1_stride_outtoub(src_stride=1024, dst_stride=512) + pto.set_loop_size_outtoub(loop1=1, loop2=1) + pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=1, + len_burst=64, + gm_stride=128, + ub_stride=128, + enable_ub_pad=False, + ) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertTrue(any(isinstance(stmt, SemanticDmaConfigStmt) for stmt in semantic_kernel.body)) + self.assertTrue(any(isinstance(stmt, SemanticLowLevelCopyStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"%gm_ptr_\d+ = pto\.tensor_view_addr %arg0 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%ub_ptr_\d+ = pto\.tile_buf_addr %arg1 : !pto\.tile_buf -> !pto\.ptr", + ) + self.assertRegex(text, r"pto\.set_loop2_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop1_stride_outtoub %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop_size_outtoub %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex( + text, + r"pto\.copy_gm_to_ubuf %gm_ptr_\d+, %ub_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %false, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_tile_constructor_binds_body_local_tile_with_default_ub_config(self) -> None: + @pto.vkernel(op="body_local_tile_ctor_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(inp: pto.TensorView): + buf = pto.Tile([8, 64], pto.f32, pto.MemorySpace.UB) + ptr = buf.as_ptr() + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + assign_stmt = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticAssignStmt)) + self.assertIsInstance(assign_stmt.value, SemanticCallExpr) + self.assertEqual(assign_stmt.value.name, "alloc_tile") + self.assertIsInstance(assign_stmt.targets[0].type, SemanticTileType) + tile_type = assign_stmt.targets[0].type + self.assertEqual(tile_type.shape, (8, 64)) + self.assertEqual(tile_type.valid_shape, (8, 64)) + self.assertEqual(tile_type.memory_space, "ub") + self.assertIsNotNone(tile_type.config) + self.assertEqual(tile_type.config.b_layout, pto.BLayout.ROW_MAJOR) + self.assertEqual(tile_type.config.s_layout, pto.SLayout.NONE_BOX) + self.assertEqual(tile_type.config.s_fractal_size, 512) + + text = kernel.mlir_text() + self.assertRegex( + text, + r"%buf_\d+ = pto\.alloc_tile : !pto\.tile_buf", + ) + self.assertRegex( + text, + r"%ptr_\d+ = pto\.tile_buf_addr %buf_\d+ : !pto\.tile_buf -> !pto\.ptr", + ) + + def test_tile_constructor_uses_cube_memory_space_default_layouts(self) -> None: + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16,)], name="tile_ctor_defaults_unique") + def kernel(a: pto.PartitionTensorView): + l1 = pto.Tile((16, 32), pto.f16, pto.MemorySpace.MAT) + left = pto.Tile((16, 32), pto.f16, pto.MemorySpace.LEFT) + right = pto.Tile((32, 16), pto.f16, pto.MemorySpace.RIGHT) + acc = pto.Tile((16, 16), pto.f32, pto.MemorySpace.ACC) + bias = pto.Tile((1, 16), pto.f32, pto.MemorySpace.BIAS) + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + tile_assigns = [ + stmt for stmt in semantic_kernel.body + if isinstance(stmt, SemanticAssignStmt) and isinstance(stmt.targets[0].type, SemanticTileType) + ] + self.assertEqual(len(tile_assigns), 5) + + configs_by_name = {stmt.targets[0].name: stmt.targets[0].type.config for stmt in tile_assigns} + self.assertEqual(configs_by_name["l1"].b_layout, pto.BLayout.COL_MAJOR) + self.assertEqual(configs_by_name["l1"].s_layout, pto.SLayout.ROW_MAJOR) + self.assertEqual(configs_by_name["l1"].s_fractal_size, 512) + self.assertEqual(configs_by_name["left"].b_layout, pto.BLayout.COL_MAJOR) + self.assertEqual(configs_by_name["left"].s_layout, pto.SLayout.ROW_MAJOR) + self.assertEqual(configs_by_name["right"].b_layout, pto.BLayout.ROW_MAJOR) + self.assertEqual(configs_by_name["right"].s_layout, pto.SLayout.COL_MAJOR) + self.assertEqual(configs_by_name["acc"].b_layout, pto.BLayout.COL_MAJOR) + self.assertEqual(configs_by_name["acc"].s_layout, pto.SLayout.ROW_MAJOR) + self.assertEqual(configs_by_name["acc"].s_fractal_size, 1024) + self.assertEqual(configs_by_name["bias"].b_layout, pto.BLayout.ROW_MAJOR) + self.assertEqual(configs_by_name["bias"].s_layout, pto.SLayout.NONE_BOX) + + def test_tile_constructor_lowers_cube_alloc_tile_locations(self) -> None: + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16,)], name="tile_ctor_cube_alloc_unique") + def kernel(a: pto.PartitionTensorView): + l1 = pto.Tile((16, 32), pto.f16, pto.MemorySpace.MAT) + left = pto.Tile((16, 32), pto.f16, pto.MemorySpace.LEFT) + right = pto.Tile((32, 16), pto.f16, pto.MemorySpace.RIGHT) + acc = pto.Tile((16, 16), pto.f32, pto.MemorySpace.ACC) + bias = pto.Tile((1, 16), pto.f32, pto.MemorySpace.BIAS) + return None + + text = kernel.mlir_text() + self.assertIn("pto.alloc_tile : !pto.tile_buf", text) + self.assertIn("pto.alloc_tile : !pto.tile_buf", text) + self.assertIn("pto.alloc_tile : !pto.tile_buf", text) + self.assertIn("pto.alloc_tile : !pto.tile_buf", text) + self.assertIn("pto.alloc_tile : !pto.tile_buf", text) + + def test_set_mov_pad_val_lowers_in_advanced_mode(self) -> None: + @pto.vkernel(op="set_mov_pad_val_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(inp: pto.TensorView, dst: pto.Tile): + gm_ptr = inp.as_ptr() + ub_ptr = dst.as_ptr() + + pto.set_mov_pad_val(pad_value=pto.f32(0.0)) + pto.set_loop2_stride_outtoub(src_stride=4096, dst_stride=2048) + pto.set_loop1_stride_outtoub(src_stride=1024, dst_stride=512) + pto.set_loop_size_outtoub(loop1=1, loop2=1) + pto.copy_gm_to_ubuf( + src=gm_ptr, + dst=ub_ptr, + n_burst=1, + len_burst=64, + gm_stride=128, + ub_stride=128, + enable_ub_pad=True, + ) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertTrue(any(isinstance(stmt, SemanticDmaUnaryConfigStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() + self.assertRegex(text, r"pto\.set_mov_pad_val %[^ ]+ : f32") + self.assertRegex( + text, + r"pto\.copy_gm_to_ubuf %gm_ptr_\d+, %ub_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %true, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_set_mov_pad_val_automatically_bitcasts_unsigned_tile_pad_value_to_signless_scalar(self) -> None: + @pto.vkernel(op="set_mov_pad_val_tile_pad_bitcast_unique", dtypes=[(pto.ui16,)], advanced=True) + def kernel(dst: pto.Tile): + pto.set_mov_pad_val(dst.pad_value.eval()) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(260, 32), + memory_space=pto.MemorySpace.UB, + config=pto.TileConfig.from_mapping( + { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x2", + } + ), + valid_shape=(260, 7), + ) + ) + + text = specialized.mlir_text() + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertRegex(text, r"pto\.set_mov_pad_val %[^ ]+ : i16") + + def test_copy_ubuf_to_gm_keyword_surface_lowers_in_advanced_mode(self) -> None: + @pto.vkernel(op="tile_to_tensorview_dma_unique", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.TensorView): + ub_ptr = src.as_ptr() + gm_ptr = dst.as_ptr() + + pto.set_loop2_stride_ubtoout(src_stride=4096, dst_stride=2048) + pto.set_loop1_stride_ubtoout(src_stride=1024, dst_stride=512) + pto.set_loop_size_ubtoout(loop1=1, loop2=1) + pto.copy_ubuf_to_gm( + src=ub_ptr, + dst=gm_ptr, + n_burst=1, + len_burst=64, + gm_stride=128, + ub_stride=128, + ) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertTrue(any(isinstance(stmt, SemanticDmaConfigStmt) for stmt in semantic_kernel.body)) + self.assertTrue(any(isinstance(stmt, SemanticLowLevelCopyStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"%ub_ptr_\d+ = pto\.tile_buf_addr %arg0 : !pto\.tile_buf -> !pto\.ptr", + ) + self.assertRegex( + text, + r"%gm_ptr_\d+ = pto\.tensor_view_addr %arg1 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> !pto\.ptr", + ) + self.assertRegex(text, r"pto\.set_loop2_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop1_stride_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex(text, r"pto\.set_loop_size_ubtoout %tmp_\d+, %tmp_\d+ : i64, i64") + self.assertRegex( + text, + r"pto\.copy_ubuf_to_gm %ub_ptr_\d+, %gm_ptr_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+, %tmp_\d+", + ) + + def test_castptr_rejects_tensorview_or_tile_inputs_in_advanced_mode(self) -> None: + @pto.vkernel(op="castptr_tensorview_reject_unique", dtypes=[(pto.f32,)], advanced=True) + def tensorview_kernel(inp: pto.TensorView): + tmp = pto.castptr(inp, pto.ptr(pto.f32, pto.MemorySpace.GM)) + return None + + with self.assertRaises(TypeError) as tensorview_ctx: + analyze_frontend_kernel(build_frontend_kernel_node(tensorview_kernel)) + self.assertIn("pto.castptr input must be an index/i64, pointer, or memref-backed address value", str(tensorview_ctx.exception)) + + @pto.vkernel(op="castptr_tile_reject_unique", dtypes=[(pto.f32,)], advanced=True) + def tile_kernel(inp: pto.Tile): + tmp = pto.castptr(inp, pto.ptr(pto.f32, pto.MemorySpace.UB)) + return None + + specialized = tile_kernel.specialize( + inp=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + with self.assertRaises(TypeError) as tile_ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("pto.castptr input must be an index/i64, pointer, or memref-backed address value", str(tile_ctx.exception)) + + def test_constexpr_if_folds_static_dtype_condition_without_scf_if(self) -> None: + @pto.vkernel(op="constexpr_if_dtype_fold", dtypes=[(pto.f16, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + step = 64 + if pto.constexpr(dst.element_type != src.element_type): + step = 128 + else: + step = 64 + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertFalse(any(isinstance(stmt, SemanticIfStmt) for stmt in semantic_kernel.body)) + + text = specialized.mlir_text() + self.assertNotIn("scf.if", text) + self.assertNotIn("arith.cmpi ne", text) + self.assertRegex(text, r"%step_\d+ = arith\.constant 128 : index") + + def test_constexpr_if_rejects_non_static_condition(self) -> None: + @pto.vkernel(op="constexpr_if_dynamic_reject", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + step = 64 + if pto.constexpr(src.shape[0] != 1): + step = 128 + return None + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIn( + "if pto.constexpr(...) condition must be a compile-time bool", + str(ctx.exception), + ) + + def test_if_compare_or_condition_lowers_to_cmp_and_bool_ops(self) -> None: + @pto.vkernel(op="if_compare_or", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + loop1 = src.shape[3] + loop2 = src.shape[4] + step = 64 + if loop1 != 1 or loop2 != 1: + step = 128 + else: + step = 64 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("src", "tensorview")], + ) + self.assertIsInstance(semantic_kernel.body[3], SemanticIfStmt) + condition = semantic_kernel.body[3].condition + self.assertIsInstance(condition, SemanticBinaryExpr) + self.assertEqual(condition.op, "or") + self.assertIsInstance(condition.lhs, SemanticBinaryExpr) + self.assertEqual(condition.lhs.op, "ne") + self.assertIsInstance(condition.rhs, SemanticBinaryExpr) + self.assertEqual(condition.rhs.op, "ne") + + text = kernel.mlir_text() + self.assertEqual(text.count("arith.cmpi ne"), 2) + self.assertRegex(text, r"%loop1_\d+ = pto\.get_tensor_view_dim %arg0, %c3 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> index") + self.assertRegex(text, r"%loop2_\d+ = pto\.get_tensor_view_dim %arg0, %c4 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> index") + self.assertRegex(text, r"arith\.cmpi ne, %loop1_\d+, %c1 : index") + self.assertRegex(text, r"arith\.cmpi ne, %loop2_\d+, %c1 : index") + self.assertRegex(text, r"arith\.ori %tmp_\d+, %tmp_\d+ : i1") + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") + + def test_if_ordered_index_comparisons_lower_to_signed_cmp_predicates(self) -> None: + @pto.vkernel(op="if_compare_ordered_index", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + dim0 = src.shape[0] + dim1 = src.shape[1] + step = 64 + if dim0 > 1 and dim1 <= 8: + step = 128 + else: + step = 32 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.body[3], SemanticIfStmt) + condition = semantic_kernel.body[3].condition + self.assertIsInstance(condition, SemanticBinaryExpr) + self.assertEqual(condition.op, "and") + + text = kernel.mlir_text() + self.assertRegex(text, r"arith\.cmpi sgt, %dim0_\d+, %c1 : index") + self.assertRegex(text, r"arith\.cmpi sle, %dim1_\d+, %c8 : index") + self.assertRegex(text, r"arith\.andi %tmp_\d+, %tmp_\d+ : i1") + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") + + def test_if_ordered_float_comparison_lowers_to_cmpf_predicate(self) -> None: + @pto.vkernel(op="if_compare_ordered_float", dtypes=[(pto.f32, pto.f32, pto.f32)]) + def kernel(src: pto.TensorView, lhs: pto.f32, rhs: pto.f32): + step = 64 + if lhs > rhs: + step = 128 + else: + step = 64 + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIsInstance(semantic_kernel.body[1], SemanticIfStmt) + + text = kernel.mlir_text() + self.assertRegex(text, r"arith\.cmpf ogt, %arg1, %arg2 : f32") + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") + + def test_shape_and_stride_tuple_unpacking_lower_cleanly(self) -> None: + @pto.vkernel(op="shape_stride_unpack", dtypes=[(pto.f32, pto.f32)], advanced=True) + def kernel(src: pto.TensorView, dst: pto.Tile): + g0, g1, g2, g3, g4 = src.shape + s0, s1, s2, s3, s4 = src.strides + ub_rows, ub_cols = dst.shape + total = g0 + g1 + g2 + g3 + g4 + stride_total = s0 + s1 + s2 + s3 + s4 + area = ub_rows * ub_cols + if total != 0 or stride_total != area: + total = area + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual( + [(param.name, param.kind) for param in semantic_kernel.parameters], + [("src", "tensorview"), ("dst", "tile")], + ) + + text = specialized.mlir_text() + self.assertEqual(text.count("pto.get_tensor_view_dim"), 5) + self.assertEqual(text.count("pto.get_tensor_view_stride"), 5) + self.assertRegex(text, r"%ub_rows_\d+ = arith\.constant 8 : index") + self.assertRegex(text, r"%ub_cols_\d+ = arith\.constant 64 : index") + + def test_shape_subscript_rejects_non_literal_index_in_semantic(self) -> None: + @pto.vkernel(op="shape_dynamic_subscript_reject_unique", dtypes=[(pto.f32,)]) + def kernel(src: pto.TensorView): + axis = src.shape[0] + value = src.shape[axis] + return None + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + self.assertIn( + "shape/stride/valid_shape subscript index must be an integer literal in TileLang DSL v1", + str(ctx.exception), + ) + + def test_valid_shape_subscript_rejects_non_literal_index_in_semantic(self) -> None: + @pto.vkernel(op="valid_shape_dynamic_subscript_reject_unique", dtypes=[(pto.f16,)], advanced=True) + def kernel(dst: pto.Tile): + axis = 0 + value = dst.valid_shape[axis] + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization( + shape=(8, 128), + memory_space=pto.MemorySpace.UB, + valid_shape=("valid_rows", "valid_cols"), + ) + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn( + "tuple subscript index must be an integer literal in TileLang DSL v1", + str(ctx.exception), + ) + + def test_tuple_call_result_subscript_rejects_in_semantic(self) -> None: + @pto.vkernel(op="tuple_call_result_subscript_reject_unique", dtypes=[(pto.f16,)], advanced=True) + def kernel(dst: pto.Tile): + mask = pto.make_mask(dst.element_type, pto.i32(64))[0] + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 128), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn( + "tuple subscripting currently requires a shape-like tuple expression in TileLang DSL v1", + str(ctx.exception), + ) + + def test_advanced_mode_lowers_compare_predicate_carry_and_rearrangement_families(self) -> None: + @pto.vkernel(op="advanced_family", dtypes=[(pto.i32, pto.i32, pto.i32, pto.i32)], advanced=True) + def kernel(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, scalar: pto.i32): + all_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + lhs = pto.vlds(src0[0, 0:]) + rhs = pto.vlds(src1[0, 0:]) + cmp_mask = pto.vcmp(lhs, rhs, all_mask, pto.CmpMode.LT) + cmp_scalar_mask = pto.vcmps(lhs, scalar, all_mask, pto.CmpMode.GT) + negated = pto.pnot(cmp_mask, all_mask) + picked = pto.psel(cmp_mask, negated, cmp_scalar_mask) + packed = pto.ppack(picked, pto.PredicatePart.LOWER) + unpacked = pto.punpack(packed, pto.PredicatePart.HIGHER) + sum_vec, carry_mask = pto.vaddc(lhs, rhs, all_mask) + diff_vec, borrow_mask = pto.vsubc(lhs, rhs, all_mask) + sum_with_carry, carry_mask2 = pto.vaddcs(sum_vec, diff_vec, carry_mask, all_mask) + diff_with_borrow, borrow_mask2 = pto.vsubcs(sum_with_carry, diff_vec, borrow_mask, all_mask) + low, high = pto.vintlv(sum_with_carry, diff_with_borrow) + dlow, dhigh = pto.vdintlv(low, high) + even = pto.vintlvv2(dlow, dhigh, "PART_EVEN") + odd = pto.vdintlvv2(dlow, dhigh, "PART_ODD") + selected = pto.vsel(even, odd, unpacked) + selected_r = pto.vselr(selected, sum_with_carry) + final = pto.vselrv2(selected_r, diff_with_borrow) + pto.vsts(final, dst[0, 0:], all_mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src0=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + src1=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + self.assertIn('pto.vcmp ', text) + self.assertIn(', "lt" : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask', text) + self.assertIn('pto.vcmps ', text) + self.assertIn(', "gt" : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask', text) + self.assertIn(" = pto.pnot ", text) + self.assertIn(" = pto.psel ", text) + self.assertIn(' = pto.ppack ', text) + self.assertIn('"LOWER"', text) + self.assertIn(' = pto.punpack ', text) + self.assertIn('"HIGHER"', text) + self.assertRegex( + text, + r"%sum_vec_\d+, %carry_mask_\d+ = pto\.vaddc %lhs_\d+, %rhs_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%diff_vec_\d+, %borrow_mask_\d+ = pto\.vsubc %lhs_\d+, %rhs_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%sum_with_carry_\d+, %carry_mask2_\d+ = pto\.vaddcs %sum_vec_\d+, %diff_vec_\d+, %carry_mask_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%diff_with_borrow_\d+, %borrow_mask2_\d+ = pto\.vsubcs %sum_with_carry_\d+, %diff_vec_\d+, %borrow_mask_\d+, %all_mask_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32>, !pto\.mask, !pto\.mask -> !pto\.vreg<64xi32>, !pto\.mask", + ) + self.assertRegex( + text, + r"%low_\d+, %high_\d+ = pto\.vintlv %sum_with_carry_\d+, %diff_with_borrow_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32> -> !pto\.vreg<64xi32>, !pto\.vreg<64xi32>", + ) + self.assertRegex( + text, + r"%dlow_\d+, %dhigh_\d+ = pto\.vdintlv %low_\d+, %high_\d+ : !pto\.vreg<64xi32>, !pto\.vreg<64xi32> -> !pto\.vreg<64xi32>, !pto\.vreg<64xi32>", + ) + self.assertIn(" = pto.vintlvv2 ", text) + self.assertIn(" = pto.vdintlvv2 ", text) + self.assertIn(" = pto.vsel ", text) + self.assertIn(" = pto.vselr ", text) + self.assertIn(" = pto.vselrv2 ", text) + self.assertIn("pto.vsts ", text) + + def test_vbitcast_and_mem_bar_with_vector_users_lower_without_frontend_vecscope(self) -> None: + @pto.vkernel(op="issue_217_vecscope", dtypes=[(pto.i32, pto.ui8)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + full_mask = pto.make_mask(pto.i32, pto.PAT.ALL) + idx_mask = pto.make_mask(pto.i16, pto.PAT.ALL) + v_idx = pto.vci(pto.i8(0), pto.OrderMode.ASC) + v_idx_i16 = pto.vbitcast(v_idx, pto.i16) + v_idx_i16 = pto.vmuls(v_idx_i16, pto.i16(4), idx_mask) + v_idx_ui8 = pto.vbitcast(v_idx_i16, pto.ui8) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(pto.i32)): + store_mask, remained = pto.make_mask(pto.ui8, remained) + vec = pto.vlds(src[row, col:]) + converted = pto.vcvt( + vec, + pto.ui8, + full_mask, + sat=pto.VcvtSatMode.NOSAT, + part=pto.VcvtPartMode.P0, + ) + result = pto.vselr(converted, v_idx_ui8) + pto.mem_bar(pto.BarrierType.VST_VST) + pto.vsts(result, dst[row, col:], store_mask, dist=pto.VStoreDist.NORM_B8) + + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + self.assertIn("pto.vbitcast", text) + self.assertIn('pto.mem_bar "VST_VST"', text) + self.assertIn("pto.vselr", text) + self.assertIn("pto.vsts", text) + + def test_scalar_get_lanes_between_vector_def_and_use_lowers_without_frontend_vecscope(self) -> None: + @pto.vkernel(op="issue_240_vecscope", dtypes=[(pto.si8, pto.i32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + b8_mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + v_zero = pto.vdup(pto.ui8(0), b8_mask) + lanes_i32 = pto.get_lanes(pto.i32) + lanes_i16 = pto.get_lanes(pto.i16) + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, lanes_i16): + mask_b16_cur, remained = pto.make_mask(pto.i16, remained) + mask_b16_next, remained2 = pto.make_mask(pto.i16, remained) + mask_b32_cur = pto.punpack(mask_b16_cur, pto.PredicatePart.LOWER) + mask_b32_next = pto.punpack(mask_b16_next, pto.PredicatePart.LOWER) + vec_si8 = pto.vlds(src[row, col:], dist=pto.VLoadDist.UNPK_B8) + vec_ui8 = pto.vbitcast(vec_si8, pto.ui8) + vec_ui8_lo, vec_ui8_hi = pto.vintlv(vec_ui8, v_zero) + vec_si8_lo = pto.vbitcast(vec_ui8_lo, pto.si8) + vec_si8_hi = pto.vbitcast(vec_ui8_hi, pto.si8) + out_lo = pto.vcvt(vec_si8_lo, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + out_hi = pto.vcvt(vec_si8_hi, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + pto.vsts(out_lo, dst[row, col:], mask_b32_cur, dist=pto.VStoreDist.NORM_B32) + pto.vsts( + out_hi, + dst[row, col + lanes_i32:], + mask_b32_next, + dist=pto.VStoreDist.NORM_B32, + ) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + vecscope_stmts = [stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticVecscopeStmt)] + self.assertEqual(len(vecscope_stmts), 0) + + text = specialized.mlir_text() + self.assertNotIn("pto.vecscope {", text) + self.assertIn(" = arith.constant 64 : index", text) + self.assertIn(" = arith.constant 128 : index", text) + self.assertIn(" = pto.vdup ", text) + self.assertIn(" = pto.vintlv ", text) + + def test_punpack_widens_b16_mask_for_norm_b32_store_in_advanced_mode(self) -> None: + @pto.vkernel(op="punpack_widen_b16_to_b32_unique", dtypes=[(pto.si8, pto.i32)], advanced=True) + def kernel(src: pto.Tile, dst: pto.Tile): + valid_rows, valid_cols = dst.valid_shape + lanes_i32 = pto.get_lanes(pto.i32) + for row in range(0, valid_rows, 1): + b8_mask = pto.make_mask(pto.i8, pto.PAT.ALL) + mask_b16, _ = pto.make_mask(pto.i16, valid_cols) + mask_b32 = pto.punpack(mask_b16, pto.PredicatePart.LOWER) + vec_si8 = pto.vlds(src[row, 0:], dist=pto.VLoadDist.UNPK_B8) + vec_ui8 = pto.vbitcast(vec_si8, pto.ui8) + v_zero_i8 = pto.vdup(pto.i8(0), b8_mask) + v_zero = pto.vbitcast(v_zero_i8, pto.ui8) + wide_lo, _ = pto.vintlv(vec_ui8, v_zero) + narrowed = pto.vbitcast(wide_lo, pto.si8) + converted = pto.vcvt(narrowed, pto.i32, b8_mask, part=pto.VcvtPartMode.P0) + pto.vsts(converted, dst[row, 0:], mask_b32, dist=pto.VStoreDist.NORM_B32) + pto.vsts(converted, dst[row, lanes_i32:], mask_b32, dist=pto.VStoreDist.NORM_B32) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn(' = pto.punpack ', text) + self.assertRegex( + text, + r"pto\.punpack %mask_b16_\d+, \"LOWER\" : !pto\.mask -> !pto\.mask", + ) + self.assertRegex( + text, + r"pto\.vsts %converted_\d+, %tmp_\d+\[%c0\], %mask_b32_\d+ \{dist = \"NORM_B32\"\} : !pto\.vreg<64xi32>, memref<\?x\?xi32, strided<\[\?, \?\], offset: \?>, #pto\.address_space>, !pto\.mask", + ) + + def test_elementwise_kernel_positive_regression_covers_vecscope_tail_mask_and_dynamic_loop_bound(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) + def kernel(inp: pto.TensorView, tile: pto.Tile, remaining: pto.i32): + rows = inp.shape[0] + with pto.strict_vecscope(tile, tile, remaining, 0, rows, 64) as ( + src, + dst, + rem, + lb, + ub, + step, + ): + for lane in range(lb, ub, step): + mask, rem = pto.make_mask(pto.f32, rem) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertEqual(len(semantic_kernel.body), 3) + self.assertIsInstance(semantic_kernel.body[1], SemanticStrictVecscopeStmt) + + vecscope = semantic_kernel.body[1] + self.assertIsInstance(vecscope, SemanticStrictVecscopeStmt) + loop_stmt = vecscope.body[0] + self.assertIsInstance(loop_stmt, SemanticForStmt) + self.assertEqual(len(loop_stmt.loop_carried), 1) + self.assertEqual(loop_stmt.loop_carried[0].name, "rem") + + text = specialized.mlir_text() + self.assertIn( + "func.func @kernel(%arg0: !pto.tensor_view, %arg1: !pto.tile_buf, %arg2: i32) attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } {", + text, + ) + self.assertRegex( + text, + r"%rows_\d+ = pto\.get_tensor_view_dim %arg0, %c0 : !pto\.tensor_view<\?x\?x\?x\?x\?xf32> -> index", + ) + self.assertRegex( + text, + r"pto\.strict_vecscope\(%tmp_\d+, %tmp_\d+, %arg2, %c0, %rows_\d+, %c64\)", + ) + self.assertRegex( + text, + r"scf\.for %lane_\d+ = %lb_\d+ to %ub_\d+ step %step_\d+ iter_args\(%rem_iter_\d+ = %rem_\d+\) -> \(i32\) \{", + ) + self.assertRegex( + text, + r"%mask_\d+, %rem_\d+ = pto\.plt_b32 %rem_iter_\d+ : i32 -> !pto\.mask, i32", + ) + + def test_if_else_and_sync_ops_lower_to_scf_if_and_authoring_sync_ops(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f32, pto.i32)], advanced=True) + def kernel(inp: pto.TensorView, tile: pto.Tile, flag: pto.i32): + pto.set_flag(pto.PIPE.MTE2, pto.PIPE.V, pto.EVENT.ID0) + pto.wait_flag(pto.PIPE.MTE2, pto.PIPE.V, pto.EVENT.ID0) + step = 64 + if flag: + step = 64 + pto.set_flag(pto.PIPE.V, pto.PIPE.MTE3, pto.EVENT.ID0) + else: + step = 128 + pto.wait_flag(pto.PIPE.V, pto.PIPE.MTE3, pto.EVENT.ID0) + with pto.strict_vecscope(tile, tile, 0, 256, step) as (src, dst, lb, ub, vec_step): + for lane in range(lb, ub, vec_step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + pto.pipe_barrier(pto.PIPE.ALL) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticSetFlagStmt) + self.assertIsInstance(semantic_kernel.body[1], SemanticWaitFlagStmt) + self.assertIsInstance(semantic_kernel.body[3], SemanticIfStmt) + self.assertIsInstance(semantic_kernel.body[5], SemanticPipeBarrierStmt) + + text = specialized.mlir_text() + self.assertIn('pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]', text) + self.assertIn('pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"]', text) + self.assertIn("= arith.cmpi ne, %arg2, %c0_i32 : i32", text) + self.assertRegex(text, r"%step_\d+ = scf\.if %tmp_\d+ -> \(index\) \{") + self.assertIn('pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]', text) + self.assertIn('pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"]', text) + self.assertRegex(text, r"scf\.yield %step_\d+ : index") + self.assertIn("%step_2 = arith.constant 128 : index", text) + self.assertRegex( + text, + r"pto\.strict_vecscope\(%tmp_\d+, %tmp_\d+, %c0, %c256, %step_\d+\)", + ) + self.assertIn("scf.for %lane_", text) + self.assertIn("pto.barrier #pto.pipe", text) + + def test_if_else_with_two_merged_bindings_lowers_to_multi_result_scf_if(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.i32)], advanced=True) + def kernel(tile: pto.Tile, flag: pto.i32): + step = 64 + upper = 256 + if flag: + step = 32 + upper = upper - step + else: + step = 64 + upper = 128 + with pto.strict_vecscope(tile, tile, 0, upper, step) as (src, dst, lb, ub, vec_step): + for lane in range(lb, ub, vec_step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(16, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + if_stmt = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticIfStmt)) + self.assertIsInstance(if_stmt, SemanticIfStmt) + self.assertEqual([result.result_binding.name for result in if_stmt.results], ["step", "upper"]) + + text = specialized.mlir_text() + self.assertRegex( + text, + r"%step_\d+, %upper_\d+ = scf\.if %tmp_\d+ -> \(index, index\) \{", + ) + self.assertRegex( + text, + r"scf\.yield %step_\d+, %upper_\d+ : index, index", + ) + self.assertRegex( + text, + r"pto\.strict_vecscope\(%tmp_\d+, %tmp_\d+, %c0, %upper_\d+, %step_\d+\)", + ) + + def test_extended_sync_buffer_ops_lower_to_authoring_surface(self) -> None: + Pipe = pto.Pipe + Event = pto.Event + BarrierType = pto.BarrierType + + @pto.vkernel( + op="extended_sync_surface", + dtypes=[(pto.f32, pto.i64, pto.i64, pto.i64, pto.i64, pto.i32)], + advanced=True, + ) + def kernel( + tile: pto.Tile, + buf_id: pto.i64, + mode: pto.i64, + core_id: pto.i64, + block_id: pto.i64, + config: pto.i32, + ): + pto.get_buf(Pipe.MTE2, buf_id, mode) + pto.rls_buf(Pipe.V, buf_id) + pto.mem_bar(BarrierType.VST_VLD) + pto.set_cross_core(core_id, Event.ID7) + pto.set_intra_block(block_id, Event.ID16) + pto.set_intra_core(config) + pto.wait_flag_dev(core_id, Event.ID8) + pto.wait_intra_core(block_id, Event.ID31) + with pto.strict_vecscope(tile, tile, 0, 128, 64) as (src, dst, lb, ub, step): + for lane in range(lb, ub, step): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticGetBufStmt) + self.assertIsInstance(semantic_kernel.body[1], SemanticRlsBufStmt) + self.assertIsInstance(semantic_kernel.body[2], SemanticMemBarStmt) + self.assertIsInstance(semantic_kernel.body[3], SemanticSetCrossCoreStmt) + self.assertIsInstance(semantic_kernel.body[4], SemanticSetIntraBlockStmt) + self.assertIsInstance(semantic_kernel.body[5], SemanticSetIntraCoreStmt) + self.assertIsInstance(semantic_kernel.body[6], SemanticWaitFlagDevStmt) + self.assertIsInstance(semantic_kernel.body[7], SemanticWaitIntraCoreStmt) + + text = specialized.mlir_text() + self.assertIn('pto.get_buf "PIPE_MTE2", %arg1, %arg2 : i64, i64', text) + self.assertIn('pto.rls_buf "PIPE_V", %arg1, %c0_i64 : i64, i64', text) + self.assertIn('pto.mem_bar "VST_VLD"', text) + self.assertIn("pto.set_cross_core %arg3, %c7_i64 : i64, i64", text) + self.assertIn("pto.set_intra_block %arg4, %c16_i64 : i64, i64", text) + self.assertIn("pto.set_intra_core %arg5 : i32", text) + self.assertIn("pto.wait_flag_dev %arg3, %c8_i64 : i64, i64", text) + self.assertIn("pto.wait_intra_core %arg4, %c31_i64 : i64, i64", text) + + def test_mem_bar_accepts_extended_barrier_type_enum(self) -> None: + BarrierType = pto.BarrierType + + @pto.vkernel( + op="mem_bar_extended_enum_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + pto.mem_bar(BarrierType.ST_VST) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIsInstance(semantic_kernel.body[0], SemanticMemBarStmt) + + text = specialized.mlir_text() + self.assertIn('pto.mem_bar "ST_VST"', text) + + def test_mem_bar_accepts_extended_barrier_type_enum_vst_st(self) -> None: + BarrierType = pto.BarrierType + + @pto.vkernel( + op="mem_bar_extended_enum_vst_st_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + pto.mem_bar(BarrierType.VST_ST) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn('pto.mem_bar "VST_ST"', text) + + def test_mem_bar_rejects_unknown_barrier_string(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="mem_bar_invalid_string_unique", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel(dst: pto.Tile, src: pto.Tile): + pto.mem_bar("NOT_A_BARRIER") + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + pto.vsts(vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + specialized.mlir_text() + + self.assertIn("canonical barrier string", str(ctx.exception)) + + def test_runtime_block_queries_and_scalar_pointer_helpers_lower_to_v0_3_surface(self) -> None: + @pto.vkernel( + op="runtime_block_queries_and_scalar_helpers", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel( + src: pto.ptr(pto.f32, pto.MemorySpace.UB), + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + ): + block = pto.get_block_idx() + block_num = pto.get_block_num() + subblock = pto.get_subblock_idx() + subblock_num = pto.get_subblock_num() + value = pto.load_scalar(src, 0) + pto.store_scalar(dst, 0, value) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + store_stmt = next(stmt for stmt in semantic_kernel.body if isinstance(stmt, SemanticScalarStoreStmt)) + self.assertIsInstance(store_stmt, SemanticScalarStoreStmt) + self.assertEqual(store_stmt.destination.type.element_dtype, pto.f32) + + text = specialized.mlir_text() + self.assertIn("= pto.get_block_idx", text) + self.assertIn("= pto.get_block_num", text) + self.assertIn("= pto.get_subblock_idx", text) + self.assertIn("= pto.get_subblock_num", text) + self.assertIn("= pto.load_scalar %arg0[%c0] : !pto.ptr -> f32", text) + self.assertIn("pto.store_scalar", text) + + def test_vldsx2_and_vstsx2_tile_sugar_lower_with_normalized_dist_tokens(self) -> None: + @pto.vkernel(op="vldsx2_vstsx2_tile_sugar", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + low, high = pto.vldsx2(src[0, 0:], pto.DeinterleaveDist.B32) + pto.vstsx2(low, high, dst[0, 0:], pto.InterleaveDist.B32, mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + pair_store = next(stmt for stmt in _walk_semantic_stmts(semantic_kernel.body) if isinstance(stmt, SemanticVectorPairStoreStmt)) + self.assertIsInstance(pair_store, SemanticVectorPairStoreStmt) + + text = specialized.mlir_text() + self.assertIn("pto.vldsx2", text) + self.assertIn("pto.vstsx2", text) + self.assertIn('"DINTLV"', text) + self.assertIn('"INTLV"', text) + self.assertNotIn("DINTLV_B32", text) + self.assertNotIn("INTLV_B32", text) + + def test_vldsx2_and_vstsx2_still_accept_legacy_string_tokens_for_compatibility(self) -> None: + @pto.vkernel(op="vldsx2_vstsx2_legacy_tokens", dtypes=[(pto.f32, pto.f32)]) + def kernel(src: pto.Tile, dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + low, high = pto.vldsx2(src[0, 0:], "DINTLV_B32") + pto.vstsx2(low, high, dst[0, 0:], "INTLV_B32", mask) + return None + + specialized = kernel.specialize( + src=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + dst=pto.TileSpecialization(shape=(1, 128), memory_space=pto.MemorySpace.UB), + ) + + text = specialized.mlir_text() + self.assertIn('"DINTLV"', text) + self.assertIn('"INTLV"', text) + + def test_vscatter_lowers_from_advanced_pointer_surface(self) -> None: + @pto.vkernel( + op="vscatter_pointer_surface", + dtypes=[(pto.i32, pto.f32)], + advanced=True, + ) + def kernel( + offsets_src: pto.ptr(pto.i32, pto.MemorySpace.UB), + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + ): + vec = pto.vbr(1.0) + offsets = pto.vlds(offsets_src, 0) + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vscatter(vec, dst, offsets, mask) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + scatter_stmt = next(stmt for stmt in _walk_semantic_stmts(semantic_kernel.body) if isinstance(stmt, SemanticVScatterStmt)) + + self.assertIsInstance(scatter_stmt, SemanticVScatterStmt) + self.assertEqual(scatter_stmt.destination.type.memory_space, "ub") + self.assertEqual(scatter_stmt.value.type.element_dtype, pto.f32) + self.assertEqual(scatter_stmt.offsets.type.element_dtype, pto.i32) + self.assertEqual(scatter_stmt.mask.type.granularity, "b32") + + text = specialized.mlir_text() + self.assertIn("pto.vscatter", text) + self.assertIn("!pto.vreg<64xf32>", text) + self.assertIn("!pto.vreg<64xi32>", text) + self.assertIn("!pto.mask", text) + + def test_align_load_and_stateful_store_ops_lower_to_current_vpto_surface(self) -> None: + @pto.vkernel( + op="align_load_and_stateful_store_ops", + dtypes=[(pto.f32, pto.f32)], + advanced=True, + ) + def kernel( + src: pto.ptr(pto.f32, pto.MemorySpace.UB), + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + ): + load_align = pto.vldas(src) + vec, load_align = pto.vldus(src, load_align) + store_align = pto.init_align() + store_align = pto.vstus(store_align, 0, vec, dst) + store_align = pto.vstur(store_align, vec, dst) + pto.vstas(store_align, dst, 0) + post_align = pto.vstur(pto.init_align(), vec, dst, pto.PostUpdateMode.POST_UPDATE) + pto.vstar(post_align, dst) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + all_stmts = tuple(_walk_semantic_stmts(semantic_kernel.body)) + align_store_stmts = [stmt for stmt in all_stmts if isinstance(stmt, SemanticAlignStoreStmt)] + + self.assertTrue(any(isinstance(stmt, SemanticAssignStmt) and isinstance(stmt.value.type, SemanticAlignType) for stmt in all_stmts)) + self.assertEqual(len(align_store_stmts), 2) + self.assertEqual([stmt.op_name for stmt in align_store_stmts], ["vstas", "vstar"]) + + text = specialized.mlir_text() + self.assertIn("pto.vldas", text) + self.assertIn("pto.vldus", text) + self.assertIn("pto.init_align", text) + self.assertIn("pto.vstus", text) + self.assertIn("pto.vstur", text) + self.assertIn("pto.vstas", text) + self.assertIn("pto.vstar", text) + self.assertIn('"POST_UPDATE"', text) + self.assertIn('"NO_POST_UPDATE"', text) + self.assertIn("!pto.align", text) + + def test_predicate_store_and_compatibility_store_sugar_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_store_and_store_sugar", + dtypes=[(pto.f32, pto.ui32)], + advanced=True, + ) + def kernel( + dst: pto.ptr(pto.f32, pto.MemorySpace.UB), + mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB), + ): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.psts(mask, mask_dst, 0) + align = pto.init_align() + align, mask_base = pto.pstu(align, mask, mask_dst) + pto.vsta(align, mask_base, 0) + pto.vsst(1.0, dst, 0, mask) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + all_stmts = tuple(_walk_semantic_stmts(semantic_kernel.body)) + + self.assertTrue(any(isinstance(stmt, SemanticPredicateStoreStmt) for stmt in all_stmts)) + self.assertTrue(any(isinstance(stmt, SemanticAlignStoreStmt) and stmt.op_name == "vstas" for stmt in all_stmts)) + + text = specialized.mlir_text() + self.assertIn("pto.psts", text) + self.assertIn('"NORM"', text) + self.assertIn("pto.pstu", text) + self.assertIn("pto.vbr", text) + self.assertIn("pto.vsts", text) + self.assertIn("pto.vstas", text) + self.assertNotIn("pto.vsst", text) + self.assertNotIn("pto.vsta ", text) + + def test_psts_rejects_tile_indexing_surface(self) -> None: + @pto.vkernel( + op="predicate_store_tile_indexing_reject", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.psts(mask, mask_dst[0, 0:]) + return None + + specialized = kernel.specialize( + mask_dst=pto.TileSpecialization(shape=(16, 64), memory_space=pto.MemorySpace.UB), + ) + with self.assertRaises(TypeError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("does not support Tile element-indexing syntax", str(ctx.exception)) + self.assertIn("pto.psts(mask, buf, offset", str(ctx.exception)) + + def test_plds_load_lower_to_supported_op(self) -> None: + @pto.vkernel( + op="predicate_load_from_ub_buffer", + dtypes=[(pto.ui32, pto.ui32)], + advanced=True, + ) + def kernel( + mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB), + mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB), + ): + mask = pto.plds(mask_src, 0) + pto.psts(mask, mask_dst, 0) + return None + + specialized = kernel.specialize() + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + load_assign = next( + stmt + for stmt in _walk_semantic_stmts(semantic_kernel.body) + if isinstance(stmt, SemanticAssignStmt) + and isinstance(stmt.value, SemanticCallExpr) + and stmt.value.name == "plds" + ) + self.assertIsInstance(load_assign.value.type, SemanticMaskType) + self.assertEqual(load_assign.value.type.granularity, "b32") + + text = specialized.mlir_text() + self.assertIn("pto.plds", text) + self.assertIn('"NORM"', text) + self.assertIn("pto.psts", text) + + def test_plds_rejects_unsupported_dist_token(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="predicate_load_invalid_dist", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + _mask = pto.plds(mask_src, 0, pto.PredicateDist.PK) + return None + + kernel.specialize().mlir_text() + + self.assertIn("predicate load dist must be one of", str(ctx.exception)) + self.assertIn("pto.PredicateDist.DS", str(ctx.exception)) + + def test_predicate_generation_and_logic_families_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_generation_and_logic_families", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + mask8 = pto.pset_b8(pto.PAT.ALL) + mask16 = pto.pge_b16(pto.PAT.VL16) + mask32, _next = pto.plt_b32(64) + and_mask = pto.pand(mask32, mask32, mask32) + or_mask = pto.por(and_mask, mask32, mask32) + xor_mask = pto.pxor(or_mask, mask32, mask32) + pto.psts(xor_mask, mask_dst, 0) + _ = mask8 + _ = mask16 + return None + + text = kernel.specialize().mlir_text() + self.assertIn("pto.pset_b8", text) + self.assertIn("pto.pge_b16", text) + self.assertIn("pto.plt_b32", text) + self.assertIn("pto.pand", text) + self.assertIn("pto.por", text) + self.assertIn("pto.pxor", text) + + def test_predicate_load_store_alias_and_immediate_forms_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_load_store_alias_and_immediate_forms", + dtypes=[(pto.ui32, pto.ui32, pto.ui32, pto.si32)], + advanced=True, + ) + def kernel( + mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB), + mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB), + off_u: pto.ui32, + off_s: pto.si32, + ): + mask0 = pto.pld(mask_src, 0, pto.PredicateDist.NORM) + mask1 = pto.pldi(mask_src, pto.i32(off_u), pto.PredicateDist.US) + pto.pst(mask0, mask_dst, 0) + pto.psti(mask1, mask_dst, pto.i32(off_s), pto.PredicateDist.PK) + return None + + text = kernel.specialize().mlir_text() + self.assertIn("pto.plds", text) + self.assertIn("pto.pldi", text) + self.assertIn("pto.psts", text) + self.assertIn("pto.psti", text) + self.assertIn("builtin.unrealized_conversion_cast", text) + self.assertIn("arith.index_cast", text) + self.assertNotRegex(text, r"arith\.extsi %\w+ : si32 to i32") + self.assertNotRegex(text, r"arith\.extui %\w+ : ui32 to i32") + + def test_predicate_reorder_families_lower_to_supported_ops(self) -> None: + @pto.vkernel( + op="predicate_reorder_families", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_dst: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + mask8 = pto.pset_b8(pto.PAT.ALL) + mask16 = pto.pset_b16(pto.PAT.ALL) + mask32 = pto.pset_b32(pto.PAT.ALL) + low8, high8 = pto.pdintlv_b8(mask8, mask8) + low8i, high8i = pto.pintlv_b8(mask8, mask8) + low16d, high16d = pto.pdintlv_b16(mask16, mask16) + low16, high16 = pto.pintlv_b16(mask16, mask16) + low32, high32 = pto.pdintlv_b32(mask32, mask32) + low32i, high32i = pto.pintlv_b32(mask32, mask32) + all32 = pto.make_mask(pto.ui32, pto.PAT.ALL) + pto.psts(all32, mask_dst, 0) + return None + + text = kernel.specialize().mlir_text() + self.assertIn("pto.pdintlv_b8", text) + self.assertIn("pto.pintlv_b8", text) + self.assertIn("pto.pdintlv_b16", text) + self.assertIn("pto.pintlv_b16", text) + self.assertIn("pto.pdintlv_b32", text) + self.assertIn("pto.pintlv_b32", text) + + def test_pdintlv_b8_rejects_wrong_mask_granularity(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel( + op="predicate_reorder_wrong_mask_granularity", + dtypes=[(pto.ui32,)], + advanced=True, + ) + def kernel(mask_src: pto.ptr(pto.ui32, pto.MemorySpace.UB)): + mask32 = pto.plds(mask_src, 0) + _low, _high = pto.pdintlv_b8(mask32, mask32) + return None + + kernel.specialize().mlir_text() + + self.assertIn("expects !pto.mask operands", str(ctx.exception)) + + def test_strict_vecscope_rejects_implicit_capture_during_semantic_analysis(self) -> None: + @pto.vkernel(op="eltwise", dtypes=[(pto.f32, pto.f16, pto.i32)], advanced=True) + def kernel(inp: pto.TensorView, tile: pto.Tile, scale: pto.i32): + with pto.strict_vecscope(inp, tile) as (vin, vtmp): + leaked = scale + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization( + shape=(8, 16), + memory_space=pto.MemorySpace.UB, + ) + ) + + with self.assertRaises(ValueError) as ctx: + analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + self.assertIn("implicit capture of 'scale' is not allowed", str(ctx.exception)) + + +class TileLangDSLInlineProcTests(unittest.TestCase): + @pto.inline_proc + def _inline_copy_row(dst: pto.Tile, src: pto.Tile, lane: pto.i32): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + return None + + @pto.inline_proc + def _inline_recur(dst: pto.Tile): + _inline_recur(dst) + return None + + @pto.inline_proc + def _inline_capture(dst: pto.Tile): + pto.vlds(dst, lane) + return None + + @pto.inline_proc + def _inline_capture_global_literal(dst: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(dst, INLINE_PROC_GLOBAL_LANE) + pto.vsts(vec, dst, INLINE_PROC_GLOBAL_LANE, mask) + return None + + def test_inline_proc_exports_from_package_surface(self) -> None: + self.assertTrue(hasattr(pto, "inline_proc")) + self.assertTrue(hasattr(pto, "InlineProcDescriptor")) + + def test_inline_proc_call_keeps_call_in_frontend_and_mlir_text(self) -> None: + @pto.vkernel(op="inline_proc_backend_call_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + _inline_copy_row(dst, src, 0) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertEqual(len(frontend_kernel.body), 2) + self.assertIsInstance(frontend_kernel.body[0], FrontendExprStmt) + self.assertIsInstance(frontend_kernel.body[0].expr, FrontendCallExpr) + self.assertEqual(frontend_kernel.body[0].expr.name, "_inline_copy_row") + self.assertGreaterEqual(len(frontend_kernel.inline_procs), 1) + self.assertIn("_inline_copy_row", {proc.name for proc in frontend_kernel.inline_procs}) + + text = specialized.mlir_text() + self.assertIn("func.call", text) + self.assertRegex(text, r"func\.call @__tl_inline_") + + def test_inline_proc_supports_default_parameters_and_keyword_call(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile, lane: pto.i32 = 0): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, lane) + pto.vsts(vec, dst, lane, mask) + return None + + @pto.vkernel(op="inline_proc_keyword_default_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst=dst, src=src) + return None + + text = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn("func.call", text) + self.assertRegex(text, r"func\.func private @__tl_inline_") + + def test_inline_proc_supports_return_expression_in_expression_position(self) -> None: + @pto.inline_proc + def inline_load(src: pto.Tile, lane: pto.i32 = 0): + return pto.vlds(src, lane) + + @pto.vkernel(op="inline_proc_expr_return_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = inline_load(src) + pto.vsts(vec, dst, 0, mask) + return None + + text = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn("func.call", text) + self.assertRegex(text, r"= func\.call @__tl_inline_") + self.assertIn("pto.vsts", text) + + def test_vdiv_integer_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vdiv_i16_dtype_support_unique", dtypes=[(pto.i16, pto.i16)]) + def kernel_i16(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vdiv_i32_dtype_support_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel_i32(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i16 = kernel_i16.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i16 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i16)) + assign_i16 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i16.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i16.value.namespace) + self.assertRegex(assign_i16.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_i16.value.type, SemanticVRegType(element_dtype=pto.i16, lanes=128)) + self.assertGreaterEqual(len(semantic_i16.inline_helpers), 1) + self.assertRegex(specialized_i16.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + text_i32 = kernel_i32.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i32 = analyze_frontend_kernel(build_frontend_kernel_node(text_i32)) + assign_i32 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i32.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i32.value.namespace) + self.assertRegex(assign_i32.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_i32.value.type, SemanticVRegType(element_dtype=pto.i32, lanes=64)) + self.assertGreaterEqual(len(semantic_i32.inline_helpers), 1) + self.assertRegex(text_i32.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + def test_vdiv_f16_and_f32_vector_types_keep_authoring_form_vpto_path(self) -> None: + @pto.vkernel( + op="vdiv_float_dtype_support_unique", + dtypes=[(pto.f16, pto.f16), (pto.f32, pto.f32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + cases = [ + (pto.f16, 128), + (pto.f32, 64), + ] + + for dtype, lanes in cases: + with self.subTest(dtype=dtype): + selected = pto.select_kernel("a5", "vdiv_float_dtype_support_unique", (dtype, dtype)) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + assign_stmt = next( + stmt + for stmt in _walk_semantic_stmts(semantic_kernel.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertEqual(assign_stmt.value.namespace, "pto") + self.assertEqual(assign_stmt.value.name, "vdiv") + self.assertEqual( + assign_stmt.value.type, + SemanticVRegType(element_dtype=dtype, lanes=lanes), + ) + self.assertEqual(len(semantic_kernel.inline_helpers), 0) + + text = lower_semantic_kernel(semantic_kernel).render() + self.assertEqual(text, specialized.mlir_text()) + self.assertIn("= pto.vdiv ", text) + self.assertNotIn("__tl_inline__tl_soft_vdiv_", text) + + def test_vdiv_i8_and_ui8_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vdiv_i8_dtype_support_unique", dtypes=[(pto.i8, pto.i8)]) + def kernel_i8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vdiv_ui8_dtype_support_unique", dtypes=[(pto.ui8, pto.ui8)]) + def kernel_ui8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i8 = kernel_i8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i8)) + assign_i8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i8.value.namespace) + self.assertRegex(assign_i8.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_i8.value.type, SemanticVRegType(element_dtype=pto.i8, lanes=256)) + self.assertGreaterEqual(len(semantic_i8.inline_helpers), 1) + self.assertRegex(specialized_i8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + specialized_ui8 = kernel_ui8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_ui8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_ui8)) + assign_ui8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_ui8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_ui8.value.namespace) + self.assertRegex(assign_ui8.value.name, r"^__tl_inline__tl_soft_vdiv_") + self.assertEqual(assign_ui8.value.type, SemanticVRegType(element_dtype=pto.ui8, lanes=256)) + self.assertGreaterEqual(len(semantic_ui8.inline_helpers), 1) + self.assertRegex(specialized_ui8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vdiv_") + + def test_vdiv_rejects_bf16_vector_type(self) -> None: + @pto.vkernel(op="vdiv_bf16_reject_unique", dtypes=[(pto.bf16, pto.bf16)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.bf16, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + self.assertIn( + "pto.vdiv only supports 8/16/32-bit integer families and f16/f32 in TileLang DSL v1", + str(ctx.exception), + ) + + def test_vmod_integer_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vmod_i16_dtype_support_unique", dtypes=[(pto.i16, pto.i16)]) + def kernel_i16(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vmod_i32_dtype_support_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel_i32(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i16 = kernel_i16.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i16 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i16)) + assign_i16 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i16.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i16.value.namespace) + self.assertRegex(assign_i16.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_i16.value.type, SemanticVRegType(element_dtype=pto.i16, lanes=128)) + self.assertGreaterEqual(len(semantic_i16.inline_helpers), 1) + self.assertRegex(specialized_i16.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + specialized_i32 = kernel_i32.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i32 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i32)) + assign_i32 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i32.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i32.value.namespace) + self.assertRegex(assign_i32.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_i32.value.type, SemanticVRegType(element_dtype=pto.i32, lanes=64)) + self.assertGreaterEqual(len(semantic_i32.inline_helpers), 1) + self.assertRegex(specialized_i32.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + def test_vmod_i8_and_ui8_vector_types_rewrite_to_internal_helper(self) -> None: + @pto.vkernel(op="vmod_i8_dtype_support_unique", dtypes=[(pto.i8, pto.i8)]) + def kernel_i8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + @pto.vkernel(op="vmod_ui8_dtype_support_unique", dtypes=[(pto.ui8, pto.ui8)]) + def kernel_ui8(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.ui8, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized_i8 = kernel_i8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_i8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_i8)) + assign_i8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_i8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_i8.value.namespace) + self.assertRegex(assign_i8.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_i8.value.type, SemanticVRegType(element_dtype=pto.i8, lanes=256)) + self.assertGreaterEqual(len(semantic_i8.inline_helpers), 1) + self.assertRegex(specialized_i8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + specialized_ui8 = kernel_ui8.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_ui8 = analyze_frontend_kernel(build_frontend_kernel_node(specialized_ui8)) + assign_ui8 = next( + stmt + for stmt in _walk_semantic_stmts(semantic_ui8.body) + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "out" + and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_ui8.value.namespace) + self.assertRegex(assign_ui8.value.name, r"^__tl_inline__tl_soft_vmod_") + self.assertEqual(assign_ui8.value.type, SemanticVRegType(element_dtype=pto.ui8, lanes=256)) + self.assertGreaterEqual(len(semantic_ui8.inline_helpers), 1) + self.assertRegex(specialized_ui8.mlir_text(), r"func\.call @__tl_inline__tl_soft_vmod_") + + def test_vmod_rejects_f32_vector_type(self) -> None: + @pto.vkernel(op="vmod_f32_reject_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + self.assertIn( + "pto.vmod only supports 8/16/32-bit integer families in TileLang DSL v1", + str(ctx.exception), + ) + + def test_integer_divmod_helpers_lock_zero_divisor_sentinel_convention(self) -> None: + @pto.vkernel( + op="integer_divmod_zero_divisor_contract_unique", + dtypes=[ + (pto.i8, pto.i8), + (pto.ui8, pto.ui8), + (pto.i16, pto.i16), + (pto.ui16, pto.ui16), + (pto.i32, pto.i32), + (pto.ui32, pto.ui32), + ], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + quot = pto.vdiv(vec, vec, mask) + rem = pto.vmod(vec, vec, mask) + pto.vsts(quot, dst, 0, mask) + pto.vsts(rem, dst, 0, mask) + return None + + cases = [ + ("vdiv", pto.i8, "__tl_inline__tl_soft_vdiv_i8_", -1), + ("vdiv", pto.ui8, "__tl_inline__tl_soft_vdiv_u8_", 0xFF), + ("vdiv", pto.i16, "__tl_inline__tl_soft_vdiv_i16_", -1), + ("vdiv", pto.ui16, "__tl_inline__tl_soft_vdiv_u16_", 0xFFFF), + ("vdiv", pto.i32, "__tl_inline__tl_soft_vdiv_i32_", -1), + ("vdiv", pto.ui32, "__tl_inline__tl_soft_vdiv_u32_", 0xFFFFFFFF), + ("vmod", pto.i8, "__tl_inline__tl_soft_vmod_i8_", -1), + ("vmod", pto.ui8, "__tl_inline__tl_soft_vmod_u8_", 0xFF), + ("vmod", pto.i16, "__tl_inline__tl_soft_vmod_i16_", -1), + ("vmod", pto.ui16, "__tl_inline__tl_soft_vmod_u16_", 0xFFFF), + ("vmod", pto.i32, "__tl_inline__tl_soft_vmod_i32_", -1), + ("vmod", pto.ui32, "__tl_inline__tl_soft_vmod_u32_", 0xFFFFFFFF), + ] + + for op_name, dtype, helper_prefix, expected_sentinel in cases: + with self.subTest(op=op_name, dtype=dtype): + selected = pto.select_kernel( + "a5", + "integer_divmod_zero_divisor_contract_unique", + (dtype, dtype), + ) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + zero_mask_assign = _find_last_helper_assign_by_name(helper, "zero_mask") + self.assertIsInstance(zero_mask_assign.value, SemanticCallExpr) + self.assertEqual(zero_mask_assign.value.namespace, "pto") + self.assertEqual(zero_mask_assign.value.name, "vcmps") + self.assertEqual(zero_mask_assign.value.args[3].value, "eq") + self.assertEqual( + _resolve_helper_broadcast_scalar_literal(helper, zero_mask_assign.value.args[1]), + 0, + ) + + return_stmt = _find_helper_return_stmt(helper) + self.assertIsInstance(return_stmt.value, SemanticCallExpr) + self.assertEqual(return_stmt.value.namespace, "pto") + self.assertEqual(return_stmt.value.name, "vsel") + self.assertIsInstance(return_stmt.value.args[2], SemanticBindingRef) + self.assertEqual(return_stmt.value.args[2].binding.name, "zero_mask") + self.assertEqual( + _resolve_helper_broadcast_scalar_literal(helper, return_stmt.value.args[0]), + expected_sentinel, + ) + + def test_signed_vdiv_helpers_derive_result_sign_from_operand_signs(self) -> None: + @pto.vkernel( + op="signed_vdiv_sign_contract_unique", + dtypes=[(pto.i16, pto.i16), (pto.i32, pto.i32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vdiv(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + cases = [ + (pto.i16, "__tl_inline__tl_soft_vdiv_i16_", "i16"), + (pto.i32, "__tl_inline__tl_soft_vdiv_i32_", "i32"), + ] + + for dtype, helper_prefix, dtype_name in cases: + with self.subTest(dtype=dtype): + selected = pto.select_kernel("a5", "signed_vdiv_sign_contract_unique", (dtype, dtype)) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + xor_assign = _find_last_helper_assign_by_name(helper, "x_xor_y") + self.assertIsInstance(xor_assign.value, SemanticCallExpr) + self.assertEqual(xor_assign.value.namespace, "pto") + self.assertEqual(xor_assign.value.name, "vxor") + self.assertEqual(xor_assign.value.args[0].binding.name, "vec") + self.assertEqual(xor_assign.value.args[1].binding.name, "scalar_vec") + self.assertEqual(xor_assign.value.args[2].binding.name, "active_mask") + + p_pos_assign = _find_last_helper_assign_by_name(helper, "p_pos") + self.assertIsInstance(p_pos_assign.value, SemanticCallExpr) + self.assertEqual(p_pos_assign.value.namespace, "pto") + self.assertEqual(p_pos_assign.value.name, "vcmps") + self.assertEqual(p_pos_assign.value.args[0].binding.name, "x_xor_y") + self.assertEqual(p_pos_assign.value.args[1].binding.name, "zero") + self.assertEqual(p_pos_assign.value.args[2].binding.name, "active_mask") + self.assertEqual(p_pos_assign.value.args[3].value, "ge") + + q_assign = _find_last_helper_assign_by_name(helper, "q") + self.assertIsInstance(q_assign.value, SemanticCallExpr) + self.assertEqual(q_assign.value.namespace, "pto") + self.assertEqual(q_assign.value.name, "vsel") + self.assertEqual(q_assign.value.args[1].binding.name, "neg_q") + self.assertEqual(q_assign.value.args[2].binding.name, "p_pos") + self.assertIsInstance(q_assign.value.args[0], SemanticCallExpr) + self.assertEqual(q_assign.value.args[0].namespace, "pto") + self.assertEqual(q_assign.value.args[0].name, "vbitcast") + self.assertIsInstance(q_assign.value.args[0].args[0], SemanticBindingRef) + self.assertEqual(q_assign.value.args[0].args[1].name, dtype_name) + + def test_signed_vmod_helpers_apply_floor_fix_when_signs_differ(self) -> None: + @pto.vkernel( + op="signed_vmod_sign_contract_unique", + dtypes=[(pto.i16, pto.i16), (pto.i32, pto.i32)], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = pto.vmod(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + cases = [ + (pto.i16, "__tl_inline__tl_soft_vmod_i16_"), + (pto.i32, "__tl_inline__tl_soft_vmod_i32_"), + ] + + for dtype, helper_prefix in cases: + with self.subTest(dtype=dtype): + selected = pto.select_kernel("a5", "signed_vmod_sign_contract_unique", (dtype, dtype)) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + nonzero_assign = _find_last_helper_assign_by_name(helper, "nonzero_remainder") + self.assertIsInstance(nonzero_assign.value, SemanticCallExpr) + self.assertEqual(nonzero_assign.value.namespace, "pto") + self.assertEqual(nonzero_assign.value.name, "vcmps") + self.assertEqual(nonzero_assign.value.args[1].binding.name, "zero") + self.assertEqual(nonzero_assign.value.args[2].binding.name, "active_mask") + self.assertEqual(nonzero_assign.value.args[3].value, "ne") + + sign_diff_assign = _find_last_helper_assign_by_name(helper, "sign_diff") + self.assertIsInstance(sign_diff_assign.value, SemanticCallExpr) + self.assertEqual(sign_diff_assign.value.namespace, "pto") + self.assertEqual(sign_diff_assign.value.name, "pxor") + self.assertEqual(sign_diff_assign.value.args[0].binding.name, "sign_x") + self.assertEqual(sign_diff_assign.value.args[1].binding.name, "sign_y") + self.assertEqual(sign_diff_assign.value.args[2].binding.name, "active_mask") + + need_fix_assign = _find_last_helper_assign_by_name(helper, "need_floor_fix") + self.assertIsInstance(need_fix_assign.value, SemanticCallExpr) + self.assertEqual(need_fix_assign.value.namespace, "pto") + self.assertEqual(need_fix_assign.value.name, "pand") + self.assertEqual(need_fix_assign.value.args[0].binding.name, "sign_diff") + self.assertEqual(need_fix_assign.value.args[1].binding.name, "nonzero_remainder") + self.assertEqual(need_fix_assign.value.args[2].binding.name, "active_mask") + + amended_assign = _find_last_helper_assign_by_name(helper, "amended_remainder") + self.assertIsInstance(amended_assign.value, SemanticCallExpr) + self.assertEqual(amended_assign.value.namespace, "pto") + self.assertEqual(amended_assign.value.name, "vadd") + self.assertEqual(amended_assign.value.args[0].binding.name, "scalar_vec") + self.assertEqual(amended_assign.value.args[1].binding.name, "remainder") + self.assertEqual(amended_assign.value.args[2].binding.name, "active_mask") + + remainder_assign = _find_last_helper_assign_by_name(helper, "remainder") + self.assertIsInstance(remainder_assign.value, SemanticCallExpr) + self.assertEqual(remainder_assign.value.namespace, "pto") + self.assertEqual(remainder_assign.value.name, "vsel") + self.assertEqual(remainder_assign.value.args[0].binding.name, "amended_remainder") + self.assertEqual(remainder_assign.value.args[2].binding.name, "need_floor_fix") + + def test_i8_divmod_helpers_use_explicit_widen_narrow_profile(self) -> None: + @pto.vkernel( + op="i8_divmod_widen_narrow_contract_unique", + dtypes=[ + (pto.i8, pto.i8), + (pto.ui8, pto.ui8), + ], + ) + def kernel(dst: pto.Tile, src: pto.Tile): + dtype = dst.element_type + mask = pto.make_mask(dtype, pto.PAT.ALL) + vec = pto.vlds(src, 0) + quot = pto.vdiv(vec, vec, mask) + rem = pto.vmod(vec, vec, mask) + pto.vsts(quot, dst, 0, mask) + pto.vsts(rem, dst, 1, mask) + return None + + cases = [ + ( + pto.i8, + "__tl_inline__tl_soft_vdiv_i8_", + "__tl_inline__tl_soft_vdiv_i16_", + "vsunpack", + "q", + "q_low", + "q_high", + "vbitcast", + ), + ( + pto.ui8, + "__tl_inline__tl_soft_vdiv_u8_", + "__tl_inline__tl_soft_vdiv_u16_", + "vzunpack", + "q", + "q_low", + "q_high", + "vor", + ), + ( + pto.i8, + "__tl_inline__tl_soft_vmod_i8_", + "__tl_inline__tl_soft_vmod_i16_", + "vsunpack", + "r", + "r_low", + "r_high", + "vbitcast", + ), + ( + pto.ui8, + "__tl_inline__tl_soft_vmod_u8_", + "__tl_inline__tl_soft_vmod_u16_", + "vzunpack", + "r", + "r_low", + "r_high", + "vor", + ), + ] + + for ( + dtype, + helper_prefix, + widened_helper_prefix, + unpack_name, + packed_result_name, + lower_result_name, + higher_result_name, + packed_result_op, + ) in cases: + with self.subTest(dtype=dtype, helper=helper_prefix): + selected = pto.select_kernel( + "a5", + "i8_divmod_widen_narrow_contract_unique", + (dtype, dtype), + ) + specialized = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + helper = _find_inline_helper(semantic_kernel, helper_prefix) + + active_low_assign = _find_last_helper_assign_by_name(helper, "active_low") + self.assertIsInstance(active_low_assign.value, SemanticCallExpr) + self.assertEqual(active_low_assign.value.namespace, "pto") + self.assertEqual(active_low_assign.value.name, "punpack") + self.assertEqual(active_low_assign.value.args[0].binding.name, "active_mask") + + active_high_assign = _find_last_helper_assign_by_name(helper, "active_high") + self.assertIsInstance(active_high_assign.value, SemanticCallExpr) + self.assertEqual(active_high_assign.value.namespace, "pto") + self.assertEqual(active_high_assign.value.name, "punpack") + self.assertEqual(active_high_assign.value.args[0].binding.name, "active_mask") + + for name, expected_half in ( + ("vec_low", 0), + ("vec_high", 1), + ("scalar_low", 0), + ("scalar_high", 1), + ): + assign = _find_last_helper_assign_by_name(helper, name) + self.assertIsInstance(assign.value, SemanticCallExpr) + self.assertEqual(assign.value.namespace, "pto") + self.assertEqual(assign.value.name, unpack_name) + self.assertEqual(assign.value.args[1].value, expected_half) + + lower_assign = _find_last_helper_assign_by_name(helper, lower_result_name) + self.assertIsInstance(lower_assign.value, SemanticCallExpr) + self.assertIsNone(lower_assign.value.namespace) + self.assertRegex(lower_assign.value.name, rf"^{widened_helper_prefix}") + self.assertEqual(lower_assign.value.args[2].binding.name, "active_low") + + higher_assign = _find_last_helper_assign_by_name(helper, higher_result_name) + self.assertIsInstance(higher_assign.value, SemanticCallExpr) + self.assertIsNone(higher_assign.value.namespace) + self.assertRegex(higher_assign.value.name, rf"^{widened_helper_prefix}") + self.assertEqual(higher_assign.value.args[2].binding.name, "active_high") + + packed_low_assign = _find_last_helper_assign_by_name(helper, "packed_low") + self.assertIsInstance(packed_low_assign.value, SemanticCallExpr) + self.assertEqual(packed_low_assign.value.namespace, "pto") + self.assertEqual(packed_low_assign.value.name, "vpack") + self.assertEqual(packed_low_assign.value.args[0].binding.name, lower_result_name) + + packed_high_assign = _find_last_helper_assign_by_name(helper, "packed_high") + self.assertIsInstance(packed_high_assign.value, SemanticCallExpr) + self.assertEqual(packed_high_assign.value.namespace, "pto") + self.assertEqual(packed_high_assign.value.name, "vpack") + self.assertEqual(packed_high_assign.value.args[0].binding.name, higher_result_name) + + packed_result_assign = _find_last_helper_assign_by_name(helper, packed_result_name) + self.assertIsInstance(packed_result_assign.value, SemanticCallExpr) + self.assertEqual(packed_result_assign.value.namespace, "pto") + self.assertEqual(packed_result_assign.value.name, packed_result_op) + if packed_result_op == "vor": + combined_expr = packed_result_assign.value + else: + self.assertIsInstance(packed_result_assign.value.args[0], SemanticCallExpr) + combined_expr = packed_result_assign.value.args[0] + self.assertEqual(combined_expr.namespace, "pto") + self.assertEqual(combined_expr.name, "vor") + self.assertEqual(combined_expr.args[0].binding.name, "packed_low") + self.assertEqual(combined_expr.args[1].binding.name, "packed_high") + self.assertEqual(combined_expr.args[2].binding.name, "full_mask_b8") + + def test_integer_divmod_rewrite_uses_injected_internal_helpers(self) -> None: + @pto.vkernel(op="divmod_internal_helper_injection_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + quot = pto.vdiv(vec, vec, mask) + rem = pto.vmod(vec, vec, mask) + pto.vsts(quot, dst, 0, mask) + pto.vsts(rem, dst, 1, mask) + return None + + self.assertNotIn("vdiv", kernel.py_fn.__globals__) + self.assertNotIn("vmod", kernel.py_fn.__globals__) + self.assertNotIn("_tl_soft_vdiv_i32", kernel.inline_procs) + self.assertNotIn("_tl_soft_vmod_i32", kernel.inline_procs) + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + frontend_kernel = build_frontend_kernel_node(specialized) + + self.assertEqual({proc.name for proc in frontend_kernel.inline_procs}, set()) + internal_names = {proc.name for proc in frontend_kernel.internal_inline_procs} + self.assertIn("_tl_soft_vdiv", internal_names) + self.assertIn("_tl_soft_vmod", internal_names) + self.assertIn("_tl_soft_vdiv_i32", internal_names) + self.assertIn("_tl_soft_vmod_i32", internal_names) + + semantic_kernel = analyze_frontend_kernel(frontend_kernel) + helper_symbols = {helper.symbol_name for helper in semantic_kernel.inline_helpers} + self.assertTrue(any(name.startswith("__tl_inline__tl_soft_vdiv_") for name in helper_symbols)) + self.assertTrue(any(name.startswith("__tl_inline__tl_soft_vmod_") for name in helper_symbols)) + + def test_internal_vdiv_helper_name_is_not_public_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="internal_vdiv_helper_reject_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = _tl_soft_vdiv_i32(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + self.assertIn( + "arbitrary external call `_tl_soft_vdiv_i32` is not supported in TileLang DSL v1", + str(ctx.exception), + ) + + def test_internal_vmod_helper_name_is_not_public_surface(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="internal_vmod_helper_reject_unique", dtypes=[(pto.i32, pto.i32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.i32, pto.PAT.ALL) + vec = pto.vlds(src, 0) + out = _tl_soft_vmod_i32(vec, vec, mask) + pto.vsts(out, dst, 0, mask) + return None + + self.assertIn( + "arbitrary external call `_tl_soft_vmod_i32` is not supported in TileLang DSL v1", + str(ctx.exception), + ) + + def test_inline_proc_and_pto_surface_can_share_basename(self) -> None: + @pto.inline_proc + def vdiv(src: pto.Tile, lane: pto.i32 = 0): + return pto.vlds(src, lane) + + @pto.vkernel(op="inline_proc_same_basename_as_pto_surface_unique", + dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + helper_vec = vdiv(src, 0) + raw_vec = pto.vdiv(helper_vec, helper_vec, mask) + pto.vsts(raw_vec, dst, 0, mask) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertGreaterEqual(len(frontend_kernel.inline_procs), 1) + self.assertIn("vdiv", {proc.name for proc in frontend_kernel.inline_procs}) + call_values = [ + stmt.value + for stmt in frontend_kernel.body + if isinstance(stmt, FrontendAssignStmt) + and isinstance(stmt.value, FrontendCallExpr) + ] + helper_call = next( + value for value in call_values if value.namespace is None and value.name == "vdiv" + ) + raw_call = next( + value for value in call_values if value.namespace == "pto" and value.name == "vdiv" + ) + self.assertEqual(len(helper_call.args), 2) + self.assertEqual(len(raw_call.args), 3) + self.assertIsInstance(raw_call, FrontendCallExpr) + self.assertIsNone(helper_call.namespace) + self.assertEqual(helper_call.name, "vdiv") + self.assertEqual(raw_call.namespace, "pto") + self.assertEqual(raw_call.name, "vdiv") + + text = specialized.mlir_text() + self.assertRegex(text, r"func\.call @__tl_inline_vdiv_") + self.assertIn("= pto.vdiv ", text) + + def test_inline_proc_rejects_non_trailing_return(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.inline_proc + def bad_inline(flag: pto.i32): + if flag: + return flag + return flag + + self.assertIn("optional trailing `return`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_recursive_calls(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_recursive_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + _inline_recur(dst) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn("recursive inline_proc call `_inline_recur` is not supported", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_implicit_capture(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_capture_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane = pto.i32(0) + _inline_capture(dst) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn("implicit capture of 'lane' is not allowed in inline_proc", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_allows_module_level_literal_capture(self) -> None: + @pto.vkernel(op="inline_proc_global_literal_capture_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + _inline_capture_global_literal(dst) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ) + + frontend_kernel = build_frontend_kernel_node(specialized) + self.assertIn( + "_inline_capture_global_literal", + {proc.name for proc in frontend_kernel.inline_procs}, + ) + + text = specialized.mlir_text() + self.assertIn("func.call", text) + self.assertIn("arith.constant 0 : index", text) + + def test_inline_proc_rejects_kw_only_vararg_and_kwargs(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as kw_only_ctx: + + @pto.inline_proc + def bad_kw_only(dst: pto.Tile, *, lane: pto.i32): + return None + + self.assertIn("keyword-only parameters", str(kw_only_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as vararg_ctx: + + @pto.inline_proc + def bad_vararg(dst: pto.Tile, *lanes: pto.i32): + return None + + self.assertIn("does not support *args", str(vararg_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as kwargs_ctx: + + @pto.inline_proc + def bad_kwargs(dst: pto.Tile, **opts: pto.i32): + return None + + self.assertIn("does not support **kwargs", str(kwargs_ctx.exception)) + + def test_inline_proc_rejects_invalid_keyword_binding(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_invalid_keyword_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst=dst, src=src, lane=0) + return None + + self.assertIn("unexpected keyword argument 'lane'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_missing_required_argument(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_missing_required_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst=dst) + return None + + self.assertIn("missing a required argument: 'src'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_rejects_multiple_values_for_single_parameter(self) -> None: + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_multiple_values_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + inline_store(dst, src, src=src) + return None + + self.assertIn("multiple values for argument 'src'", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_inline_proc_semantic_emits_controlled_namespace_none_call(self) -> None: + @pto.vkernel(op="inline_proc_semantic_call_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + _inline_copy_row(dst, src, 0) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + call_stmts = [ + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticExprStmt) and isinstance(stmt.expr, SemanticCallExpr) + ] + self.assertGreaterEqual(len(call_stmts), 1) + inline_call = call_stmts[0].expr + self.assertIsNone(inline_call.namespace) + self.assertRegex(inline_call.name, r"^__tl_inline_") + self.assertGreaterEqual(len(semantic_kernel.inline_helpers), 1) + self.assertRegex(semantic_kernel.inline_helpers[0].symbol_name, r"^__tl_inline_") + + def test_inline_proc_semantic_keeps_expression_call_return_type(self) -> None: + @pto.inline_proc + def inline_const_i32(): + return pto.i32(1) + + @pto.vkernel(op="inline_proc_semantic_expr_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane = inline_const_i32() + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + assign_stmt = next( + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticAssignStmt) and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertIsNone(assign_stmt.value.namespace) + self.assertRegex(assign_stmt.value.name, r"^__tl_inline_") + self.assertIsInstance(assign_stmt.value.type, SemanticScalarType) + + def test_inline_proc_lowering_renders_private_helpers_and_call_bindings(self) -> None: + @pto.inline_proc + def inline_const_i32(): + return 1 + + @pto.inline_proc + def inline_store(dst: pto.Tile, src: pto.Tile): + lane = inline_const_i32() + _inline_copy_row(dst, src, lane) + return None + + @pto.vkernel(op="inline_proc_lowering_helpers_unique", dtypes=[(pto.f32, pto.f32)]) + def kernel(dst: pto.Tile, src: pto.Tile): + lane = inline_const_i32() + inline_store(dst, src) + _inline_copy_row(dst, src, lane) + return None + + text = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn("func.func private @__tl_inline_", text) + self.assertGreaterEqual(text.count("func.func"), 3) + self.assertRegex(text, r"= func\.call @__tl_inline_[A-Za-z0-9_]+\(.*\) : \([^\)]*\) -> index") + self.assertRegex(text, r"func\.call @__tl_inline_[A-Za-z0-9_]+\(.*\) : \([^\)]*\) -> \(\)") + + def test_inline_proc_supports_constexpr_dtype_dispatch(self) -> None: + @pto.inline_proc + def inline_pick_lane(dtype): + if pto.constexpr(dtype == pto.ui16): + lane = 1 + elif pto.constexpr(dtype == pto.i16): + lane = 2 + elif pto.constexpr(dtype == pto.ui32): + lane = 3 + else: + lane = 4 + return lane + + @pto.vkernel(op="inline_proc_constexpr_dtype_dispatch_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane = inline_pick_lane(dst.element_type) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + assign_stmt = next( + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticAssignStmt) and isinstance(stmt.value, SemanticCallExpr) + ) + self.assertRegex(assign_stmt.value.name, r"^__tl_inline_") + self.assertEqual(len(semantic_kernel.inline_helpers), 1) + helper_assign = next( + stmt + for stmt in semantic_kernel.inline_helpers[0].body + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "lane" + ) + self.assertIsInstance(helper_assign.value, SemanticLiteralExpr) + self.assertEqual(helper_assign.value.value, 4) + + def test_inline_proc_specializes_same_type_with_different_static_values(self) -> None: + @pto.inline_proc + def inline_scale(lane: pto.i32): + if pto.constexpr(lane == 1): + value = 2 + else: + value = 4 + return value + + @pto.vkernel(op="inline_proc_static_value_specialization_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + lane0 = inline_scale(1) + lane1 = inline_scale(2) + return None + + specialized = kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(specialized)) + + self.assertEqual(len(semantic_kernel.inline_helpers), 2) + literal_values = [] + for helper in semantic_kernel.inline_helpers: + helper_assign = next( + stmt + for stmt in helper.body + if isinstance(stmt, SemanticAssignStmt) + and len(stmt.targets) == 1 + and stmt.targets[0].name == "value" + ) + self.assertIsInstance(helper_assign.value, SemanticLiteralExpr) + literal_values.append(helper_assign.value.value) + self.assertEqual(sorted(literal_values), [2, 4]) + + def test_inline_proc_rejects_mutual_recursion(self) -> None: + @pto.inline_proc + def inline_a(dst: pto.Tile): + inline_b(dst) + return None + + @pto.inline_proc + def inline_b(dst: pto.Tile): + inline_a(dst) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="inline_proc_mutual_recursion_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + inline_a(dst) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn("recursive inline_proc call", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + +class TileLangDSLDiagnosticsTests(unittest.TestCase): + def test_matcher_feature_validation_rejects_invalid_constraints_and_priority(self) -> None: + def kernel(x: pto.TensorView): + return None + + with self.assertRaises(TypeError) as constraints_ctx: + pto.vkernel(op="x", dtypes=[(pto.f32,)], constraints=[123])(kernel) + self.assertIn("constraints[0] must be callable", str(constraints_ctx.exception)) + + with self.assertRaises(TypeError) as priority_ctx: + pto.vkernel(op="x", dtypes=[(pto.f32,)], priority=True)(kernel) + self.assertIn("priority must be an int", str(priority_ctx.exception)) + + def test_ckernel_rejects_vector_only_surface_in_body(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16,)], name="cube_reject_vadd_unique") + def kernel(tile: pto.Tile): + vec = pto.vadd(1, 2, 3) + return None + + self.assertIn("vector-only surface `pto.vadd` is not part of the @pto.ckernel contract", str(ctx.exception)) + + def test_ckernel_rejects_vector_load_surface_in_body(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16,)], name="cube_reject_vlds_unique") + def kernel(tile: pto.Tile): + vec = pto.vlds(tile, 0) + return None + + self.assertIn("vector-only surface `pto.vlds` is not part of the @pto.ckernel contract", str(ctx.exception)) + + def test_ckernel_rejects_vecscope_in_body(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16,)], name="cube_reject_vecscope_unique") + def kernel(tile: pto.Tile): + with pto.vecscope(): + return None + return None + + self.assertIn("@pto.ckernel does not support pto.vecscope()/pto.strict_vecscope()", str(ctx.exception)) + + def test_ckernel_rejects_strict_vecscope_in_body(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16,)], name="cube_reject_strict_vecscope_unique") + def kernel(tile: pto.Tile): + with pto.strict_vecscope(tile, tile, 0, 16, 16) as (src, dst, lb, ub, step): + return None + return None + + self.assertIn("@pto.ckernel does not support pto.vecscope()/pto.strict_vecscope()", str(ctx.exception)) + + def test_ckernel_rejects_schema_form_matcher_surface(self) -> None: + with self.assertRaises(ValueError) as ctx: + + @pto.ckernel( + op="pto.mad ins(lhs: f16, rhs: f16) -> outs(acc: f32)", + dtypes=[(pto.f16, pto.f16, pto.f32)], + name="cube_schema_form_reject_unique", + ) + def kernel(lhs: pto.Tile, rhs: pto.Tile, acc: pto.Tile): + return None + + self.assertIn("@pto.ckernel does not support schema-form op matching", str(ctx.exception)) + + def test_ckernel_rejects_vector_only_inline_helper_surface(self) -> None: + @pto.inline_proc + def cube_bad_helper(tile: pto.Tile): + return pto.vlds(tile, 0) + + @pto.ckernel( + op="cube_inline_helper_reject_query_unique", + dtypes=[(pto.f16,)], + name="cube_inline_helper_reject_unique", + ) + def kernel(tile: pto.Tile): + cube_bad_helper(tile) + return None + + specialized = kernel.specialize( + tile=pto.TileSpecialization(shape=(16, 32), memory_space=pto.MemorySpace.LEFT) + ) + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(specialized) + + self.assertIn("vector-only surface `pto.vlds` is not part of the @pto.ckernel contract", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_vkernel_rejects_cube_only_inline_helper_surface(self) -> None: + @pto.inline_proc + def vector_bad_helper(dst: pto.TensorView): + return pto.mad(1, 2, 3, 16, 16, 32) + + @pto.vkernel( + op="vector_inline_helper_reject_query_unique", + dtypes=[(pto.f32,)], + name="vector_inline_helper_reject_unique", + ) + def kernel(dst: pto.TensorView): + vector_bad_helper(dst) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("cube-only surface `pto.mad` is not part of the @pto.vkernel contract", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_vkernel_rejects_cube_only_surface_in_body(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="vector_reject_mad_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.TensorView): + acc = pto.mad(1, 2, 3, 16, 16, 32) + return None + + self.assertIn("cube-only surface `pto.mad` is not part of the @pto.vkernel contract", str(ctx.exception)) + + def test_vkernel_template_slot_rejects_cube_only_surface(self) -> None: + @pto.vkernel( + op="template_slot_cube_surface_unique", + dtypes=[(pto.f32,)], + templates={"compute": {"template_slot_cube_surface_unique": "mad"}}, + ) + def kernel(dst: pto.TensorView): + out = pto.tpl("compute", 1, 2, 3, 16, 16, 32) + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + build_frontend_kernel_node(kernel) + + self.assertIn("cube-only surface `pto.mad` is not part of the @pto.vkernel contract", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_ckernel_cube_ops_semantic_validation_accepts_structured_surface(self) -> None: + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16, pto.f16, pto.f32)], name="cube_semantic_success_unique") + def kernel(inp: pto.TensorView, bias_src: pto.Tile): + gm = inp.as_ptr() + l1 = pto.Tile((16, 32), pto.f16, pto.MemorySpace.MAT) + left = pto.Tile((16, 32), pto.f16, pto.MemorySpace.LEFT) + right = pto.Tile((32, 16), pto.f16, pto.MemorySpace.RIGHT) + acc = pto.Tile((16, 16), pto.f32, pto.MemorySpace.ACC) + bias = pto.Tile((1, 16), pto.f32, pto.MemorySpace.BIAS) + ub = pto.Tile((16, 16), pto.f32, pto.MemorySpace.UB) + + pto.cube_load(gm, l1.as_ptr(), 16, nburst=(1, 0, 0)) + pto.bias_load(l1.as_ptr(), bias.as_ptr(), 16, nburst=(1, 0, 0)) + pto.left_load(l1.as_ptr(), left.as_ptr(), 16, 32) + pto.right_load(l1.as_ptr(), right.as_ptr(), 32, 16) + pto.mad(left.as_ptr(), right.as_ptr(), acc.as_ptr(), 16, 16, 32, unit_flag_ctrl=2, disable_gemv=pto.i1(True)) + pto.cube_load_frac( + gm, + l1.as_ptr(), + pto.FractalMode.ND2NZ, + shape=(16, 16), + src_layout=(4,), + dst_group=(1, 0, 0, 0), + ctrl=(0, False), + ) + pto.acc_store(acc.as_ptr(), l1.as_ptr(), 16, 16, 16, 16, mode=pto.FractalMode.NZ2ND) + pto.acc_store_gm( + acc.as_ptr(), + gm, + 16, + 16, + 16, + 16, + mode=pto.FractalMode.NZ2NZ, + split=0, + sid=0, + l2_cache_ctrl=0, + ) + pto.acc_store_ub( + acc.as_ptr(), + ub.as_ptr(), + 16, + 16, + 16, + 16, + mode=pto.FractalMode.NZ2NZ, + channel_split_en=0, + dual_dst_mode=0, + sub_blockid=0, + ) + return None + + semantic_kernel = analyze_frontend_kernel(build_frontend_kernel_node(kernel)) + cube_calls = [ + stmt + for stmt in semantic_kernel.body + if isinstance(stmt, SemanticExprStmt) and isinstance(stmt.expr, SemanticCallExpr) + and stmt.expr.namespace == "pto" + and stmt.expr.name in { + "cube_load", + "bias_load", + "left_load", + "right_load", + "mad", + "cube_load_frac", + "acc_store", + "acc_store_gm", + "acc_store_ub", + } + ] + self.assertGreaterEqual(len(cube_calls), 8) + mad_stmt = next(stmt for stmt in cube_calls if stmt.expr.name == "mad") + self.assertEqual(mad_stmt.expr.args[-2].value, 2) + self.assertIsInstance(mad_stmt.expr.args[-1], SemanticLiteralExpr) + self.assertTrue(mad_stmt.expr.args[-1].value) + frac_stmt = next(stmt for stmt in cube_calls if stmt.expr.name == "cube_load_frac") + self.assertIsInstance(frac_stmt.expr.args[3], SemanticTupleExpr) + self.assertIsInstance(frac_stmt.expr.args[4], SemanticTupleExpr) + self.assertIsInstance(frac_stmt.expr.args[5], SemanticTupleExpr) + self.assertIsInstance(frac_stmt.expr.args[6], SemanticTupleExpr) + + def test_ckernel_cube_ops_reject_invalid_mode_and_address_space(self) -> None: + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16, pto.f16, pto.f32)], name="cube_bad_mode_unique") + def mode_kernel(inp: pto.TensorView, tile: pto.Tile): + gm = inp.as_ptr() + mat = pto.Tile((16, 16), pto.f16, pto.MemorySpace.MAT) + pto.cube_load_frac( + gm, + mat.as_ptr(), + "bad", + shape=(16, 16), + src_layout=(4,), + dst_group=(1, 0, 0, 0), + ctrl=(0, False), + ) + + with self.assertRaises(TypeError) as mode_ctx: + analyze_frontend_kernel(build_frontend_kernel_node(mode_kernel)) + + self.assertIn("pto.cube_load_frac mode must be", str(mode_ctx.exception)) + + @pto.ckernel(op="pto.mad", dtypes=[(pto.f16, pto.f16, pto.f32)], name="cube_bad_addr_unique") + def addr_kernel(inp: pto.TensorView, tile: pto.Tile): + gm = inp.as_ptr() + mat = pto.Tile((16, 16), pto.f16, pto.MemorySpace.MAT) + left = pto.Tile((16, 16), pto.f16, pto.MemorySpace.LEFT) + pto.left_load(gm, left.as_ptr(), 16, 16) + + with self.assertRaises(TypeError) as addr_ctx: + analyze_frontend_kernel(build_frontend_kernel_node(addr_kernel)) + + self.assertIn("pto.left_load source requires MemorySpace.MAT pointers in TileLang DSL", str(addr_ctx.exception)) + + def test_advanced_mode_keeps_vreduce_rejected_until_authoring_op_exists(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.i32,)], advanced=True) + def kernel(x: pto.Tile): + pto.vreduce(x) + return None + + self.assertIn("advanced family surface `pto.vreduce`", str(ctx.exception)) + + def test_set_mov_pad_val_rejects_unsupported_scalar_dtype(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="set_mov_pad_val_bad_dtype_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + pto.set_mov_pad_val(pto.i64(0)) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB) + ).mlir_text() + + self.assertIn( + "pto.set_mov_pad_val pad_value must be an 8/16/32-bit integer or f16/bf16/f32", + str(ctx.exception), + ) + + def test_unsupported_python_syntax_reports_source_location(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + while True: + return None + + self.assertIn("unsupported Python syntax `while`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_pass_statement_builds_frontend_noop_and_compiles(self) -> None: + @pto.vkernel(op="pass_statement_frontend_noop_unique", dtypes=[(pto.f32,)]) + def kernel(dst: pto.Tile): + pass + if pto.constexpr(True): + pass + else: + pass + return None + + selected = pto.select_kernel( + "a5", + "pass_statement_frontend_noop_unique", + (pto.f32,), + ) + frontend_kernel = build_frontend_kernel_node(selected) + self.assertIsInstance(frontend_kernel.body[0], FrontendNoOpStmt) + self.assertIsInstance(frontend_kernel.body[1], FrontendIfStmt) + if_stmt = frontend_kernel.body[1] + self.assertTrue(if_stmt.is_constexpr) + self.assertIsInstance(if_stmt.then_body[0], FrontendNoOpStmt) + self.assertIsInstance(if_stmt.else_body[0], FrontendNoOpStmt) + + text = selected.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ).mlir_text() + self.assertIn("return", text) + self.assertNotIn("scf.if", text) + + def test_vreg_annotated_assignment_rejects_mismatched_dtype(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="vreg_annotation_mismatch_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + vec: pto.vreg(pto.f16) = pto.vlds(dst, 0) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ).mlir_text() + + self.assertIn("annotated vector type `vreg(f16)` does not match inferred !pto.vreg<64xf32>", str(ctx.exception)) + + def test_mask_annotated_assignment_rejects_mismatched_granularity(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="mask_annotation_mismatch_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.Tile): + mask: pto.mask_b16 = pto.make_mask(pto.f32, pto.PAT.ALL) + return None + + kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 64), memory_space=pto.MemorySpace.UB), + ).mlir_text() + + self.assertIn("annotated mask type `mask_b16` does not match inferred !pto.mask", str(ctx.exception)) + + def test_arbitrary_external_call_reports_source_location(self) -> None: + def helper(): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + helper() + return None + + self.assertIn("arbitrary external call `helper`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_vstur_rejects_raw_string_mode_and_requires_enum(self) -> None: + with self.assertRaises(TypeError) as ctx: + + @pto.vkernel(op="vstur_raw_string_mode_unique", dtypes=[(pto.f32,)], advanced=True) + def kernel(dst: pto.ptr(pto.f32, pto.MemorySpace.UB)): + align = pto.init_align() + vec = pto.vbr(1.0) + pto.vstur(align, vec, dst, "POST_UPDATE") + return None + + kernel.specialize().mlir_text() + + self.assertIn("pto.vstur mode must be a PostUpdateMode enum", str(ctx.exception)) + + def test_unsupported_pto_surface_reports_source_location(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32,)]) + def kernel(x: pto.TensorView): + pto.not_a_real_surface(x) + return None + + self.assertIn("unsupported op surface `pto.not_a_real_surface`", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_strict_vecscope_requires_advanced_mode(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f32)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + with pto.strict_vecscope(tile, tile, 0, 256, 64) as (lhs, rhs, lb, ub, step): + pass + return None + + self.assertIn("surface `pto.strict_vecscope` requires advanced=True", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_advanced_family_requires_advanced_mode(self) -> None: + with self.assertRaises(pto.TileLangFrontendError) as ctx: + + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f32)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + mask = pto.make_mask(pto.f32, pto.PAT.ALL) + pto.vcmp(tile, tile, mask, "lt") + return None + + self.assertIn("surface `pto.vcmp` requires advanced=True", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_missing_specialization_reports_source_location(self) -> None: + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f16)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as ctx: + kernel.mlir_text() + + self.assertIn("requires specialize() bindings for bare Tile parameters", str(ctx.exception)) + self.assertIn(f"{__file__}:", str(ctx.exception)) + + def test_dynamic_shape_and_illegal_profile_report_source_location(self) -> None: + @pto.vkernel(op="x", dtypes=[(pto.f32, pto.f16)]) + def kernel(x: pto.TensorView, tile: pto.Tile): + return None + + with self.assertRaises(pto.TileLangFrontendError) as dynamic_ctx: + kernel.specialize(tile={"shape": (16, "n"), "memory_space": "ub"}) + self.assertIn("dynamic physical Tile shape is not supported", str(dynamic_ctx.exception)) + self.assertIn(f"{__file__}:", str(dynamic_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as rank_ctx: + kernel.specialize(tile={"shape": (4, 4, 4), "memory_space": "ub"}) + self.assertIn("v1 only supports rank-1 or rank-2 Tile shapes", str(rank_ctx.exception)) + self.assertIn(f"{__file__}:", str(rank_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as space_ctx: + kernel.specialize(tile={"shape": (4, 4), "memory_space": "gm"}) + self.assertIn("v1 only supports MemorySpace.UB", str(space_ctx.exception)) + self.assertIn(f"{__file__}:", str(space_ctx.exception)) + + with self.assertRaises(pto.TileLangFrontendError) as valid_shape_ctx: + kernel.specialize(tile={"shape": (4, 4), "memory_space": "ub", "valid_shape": (5, 4)}) + self.assertIn("valid_shape axis 0=5 must be <= shape axis 0=4", str(valid_shape_ctx.exception)) + self.assertIn(f"{__file__}:", str(valid_shape_ctx.exception)) + + def test_slice_index_type_error_reports_template_source_location(self) -> None: + source = """ +import tilelang_dsl as pto + +@pto.inline_proc +def store_row(dst: pto.Tile, src: pto.Tile, row: pto.f32): + vec = pto.vlds(src[row, 0:]) + mask = pto.make_mask(dst.element_type, pto.PAT.ALL) + pto.vsts(vec, dst[row, 0:], mask) + return None + +@pto.vkernel(op="diag_index_type_unique", dtypes=[(pto.f32, pto.f32, pto.f32)]) +def kernel(dst: pto.Tile, src: pto.Tile, row: pto.f32): + store_row(dst, src, row) + return None +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = Path(tmpdir) / "diag_index_type_kernel.py" + module_path.write_text(source, encoding="utf-8") + spec = util.spec_from_file_location("diag_index_type_kernel", module_path) + self.assertIsNotNone(spec) + self.assertIsNotNone(spec.loader) + module = util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + + specialized = module.kernel.specialize( + dst=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + src=pto.TileSpecialization(shape=(8, 16), memory_space=pto.MemorySpace.UB), + ) + + with self.assertRaises(TypeError) as ctx: + specialized.mlir_text() + + message = str(ctx.exception) + self.assertIn(str(module_path), message) + self.assertIn(":6:", message) + self.assertIn("slice bounds and vector offsets must be index-typed", message) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/ptoas/CMakeLists.txt b/tools/ptoas/CMakeLists.txt index 96d841ca6..4e1aa7fb0 100644 --- a/tools/ptoas/CMakeLists.txt +++ b/tools/ptoas/CMakeLists.txt @@ -20,6 +20,8 @@ set(LLVM_LINK_COMPONENTS # 原因:LLVM 构建目录里已经有一个 ptoas 了,重名会导致 CMake 报错。 add_llvm_executable(pto-opt ptoas.cpp + VPTOHostStubEmission.cpp + VPTOFatobjEmission.cpp ) # ========================================================= # [新增] 魔法操作:修改最终输出的文件名为 "ptoas" @@ -27,6 +29,12 @@ add_llvm_executable(pto-opt set_target_properties(pto-opt PROPERTIES OUTPUT_NAME "ptoas") target_compile_definitions(pto-opt PRIVATE PTOAS_RELEASE_VERSION="${PTOAS_CLI_VERSION}" + # Source-tree defaults for TileLang DSL expansion. These let ptoas run + # directly from the build tree without passing --tilelang-path / + # --tilelang-pkg-path. Installed layouts that move these directories + # still need to override the flags explicitly. + PTOAS_DEFAULT_TILELANG_PATH="${CMAKE_SOURCE_DIR}/lib/TileOps" + PTOAS_DEFAULT_TILELANG_PKG_PATH="${CMAKE_SOURCE_DIR}/tilelang-dsl/python" ) # [修改 2] 更新链接库名称 # 原因:In-tree 时你的库叫 MLIRPTODialect,但现在 Out-of-tree 它们是你自己定义的 @@ -69,3 +77,8 @@ target_link_libraries(pto-opt PRIVATE add_dependencies(pto-opt PTOOpsIncGen ) + +install(TARGETS pto-opt + RUNTIME DESTINATION bin + COMPONENT PTOAS_Runtime +) diff --git a/tools/ptoas/VPTOFatobjEmission.cpp b/tools/ptoas/VPTOFatobjEmission.cpp new file mode 100644 index 000000000..0fb060a12 --- /dev/null +++ b/tools/ptoas/VPTOFatobjEmission.cpp @@ -0,0 +1,582 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "VPTOFatobjEmission.h" + +#include "PTO/Transforms/VPTOLLVMEmitter.h" + +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Host.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +namespace { + +using llvm::StringRef; + +static bool runCommandWithStderr(llvm::StringRef program, + llvm::ArrayRef ownedArgs, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS, + llvm::StringRef what, + std::optional stdinPath = + std::nullopt); + +class TempFileRegistry { +public: + ~TempFileRegistry() { cleanup(); } + + void cleanup() { + for (const std::string &path : paths) + llvm::sys::fs::remove(path); + paths.clear(); + } + + bool create(StringRef prefix, StringRef suffix, std::string &path, + llvm::raw_ostream &diagOS) { + llvm::SmallString<128> tempPath; + int fd = -1; + std::error_code ec = llvm::sys::fs::createTemporaryFile(prefix, suffix, fd, + tempPath); + if (ec) { + diagOS << "Error: failed to create temporary file for " << prefix + << suffix << ": " << ec.message() << "\n"; + return false; + } + llvm::sys::Process::SafelyCloseFileDescriptor(fd); + path = tempPath.str().str(); + paths.push_back(path); + return true; + } + +private: + llvm::SmallVector paths; +}; + +static bool writeTextFile(StringRef path, StringRef content, + llvm::raw_ostream &diagOS) { + std::error_code ec; + llvm::raw_fd_ostream os(path, ec, llvm::sys::fs::OF_Text); + if (ec) { + diagOS << "Error: failed to open " << path << " for write: " + << ec.message() << "\n"; + return false; + } + os << content; + os.flush(); + return true; +} + +static bool writeLLVMModuleFile(llvm::Module &module, StringRef path, + llvm::raw_ostream &diagOS) { + std::error_code ec; + llvm::raw_fd_ostream os(path, ec, llvm::sys::fs::OF_Text); + if (ec) { + diagOS << "Error: failed to open " << path << " for write: " + << ec.message() << "\n"; + return false; + } + module.print(os, nullptr); + os.flush(); + return true; +} + +static std::string sanitizeModuleId(llvm::StringRef raw) { + std::string out; + out.reserve(raw.size()); + for (char c : raw) { + if (std::isalnum(static_cast(c)) || c == '_') + out.push_back(c); + else + out.push_back('_'); + } + if (out.empty()) + out = "ptoas_fatobj"; + return out; +} + +static std::optional getAscendHomePath() { + const char *env = std::getenv("ASCEND_HOME_PATH"); + if (!env || !*env) + return std::nullopt; + return std::string(env); +} + +static std::string joinPath(llvm::StringRef lhs, llvm::StringRef rhs) { + llvm::SmallString<256> joined(lhs); + llvm::sys::path::append(joined, rhs); + return std::string(joined.str()); +} + +static std::optional locateProgram(llvm::StringRef envPath, + llvm::StringRef fallbackName) { + if (!envPath.empty() && llvm::sys::fs::exists(envPath)) + return envPath.str(); + if (auto found = llvm::sys::findProgramByName(fallbackName)) + return *found; + return std::nullopt; +} + +class VPTOFatobjToolchain; + +static bool compileDeviceLLVMToObject(llvm::StringRef llPath, + llvm::StringRef outObjPath, + llvm::StringRef targetCPU, + llvm::StringRef bishengPath, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS); +static bool compileHostStubToFatobj(llvm::StringRef stubPath, + llvm::StringRef outObjPath, + llvm::StringRef moduleId, + llvm::StringRef targetCPU, + const VPTOFatobjToolchain &toolchain, + llvm::StringRef deviceObjPath, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS); +static bool mergeDeviceObjects(llvm::ArrayRef deviceObjPaths, + llvm::StringRef outObjPath, + llvm::StringRef ldLldPath, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS); + +class VPTOFatobjToolchain { +public: + static std::optional + create(llvm::raw_ostream &diagOS) { + std::optional ascendHome = getAscendHomePath(); + if (!ascendHome) { + diagOS << "Error: ASCEND_HOME_PATH is required for VPTO fatobj emission.\n"; + return std::nullopt; + } + + VPTOFatobjToolchain toolchain(*ascendHome); + if (!toolchain.validate(diagOS)) + return std::nullopt; + return toolchain; + } + + const std::string &ascendHome() const { return ascendHomePath; } + const std::string &bisheng() const { return bishengPath; } + const std::string &bishengCc1() const { return bishengCc1Path; } + const std::string &cceLd() const { return cceLdPath; } + const std::string &ldLld() const { return ldLldPath; } + const std::string &resourceDir() const { return resourceDirPath; } + const std::string &resourceIncludeDir() const { + return resourceIncludeDirPath; + } + const std::string &cceStubDir() const { return cceStubDirPath; } + const std::string &bishengCompilerBinDir() const { + return bishengCompilerBinDirPath; + } + +private: + explicit VPTOFatobjToolchain(llvm::StringRef ascendHome) + : ascendHomePath(ascendHome.str()), + bishengPath(joinPath(ascendHomePath, "bin/bisheng")), + bishengCc1Path( + joinPath(ascendHomePath, "tools/bisheng_compiler/bin/bisheng")), + cceLdPath(joinPath(ascendHomePath, "bin/cce-ld")), + ldLldPath( + locateProgram(joinPath(ascendHomePath, "bin/ld.lld"), "ld.lld") + .value_or(std::string())), + resourceDirPath(joinPath( + ascendHomePath, "tools/bisheng_compiler/lib/clang/15.0.5")), + resourceIncludeDirPath(joinPath(resourceDirPath, "include")), + cceStubDirPath(joinPath(resourceIncludeDirPath, "cce_stub")), + bishengCompilerBinDirPath( + joinPath(ascendHomePath, "tools/bisheng_compiler/bin")) {} + + bool validate(llvm::raw_ostream &diagOS) const { + if (!llvm::sys::fs::exists(bishengPath)) { + diagOS << "Error: unable to locate bisheng: " << bishengPath << "\n"; + return false; + } + if (!llvm::sys::fs::exists(bishengCc1Path)) { + diagOS << "Error: unable to locate bisheng cc1 frontend: " + << bishengCc1Path << "\n"; + return false; + } + if (!llvm::sys::fs::exists(cceLdPath)) { + diagOS << "Error: unable to locate cce-ld: " << cceLdPath << "\n"; + return false; + } + if (ldLldPath.empty() || !llvm::sys::fs::exists(ldLldPath)) { + diagOS << "Error: unable to locate ld.lld.\n"; + return false; + } + return true; + } + + std::string ascendHomePath; + std::string bishengPath; + std::string bishengCc1Path; + std::string cceLdPath; + std::string ldLldPath; + std::string resourceDirPath; + std::string resourceIncludeDirPath; + std::string cceStubDirPath; + std::string bishengCompilerBinDirPath; +}; + +class VPTOFatobjArtifacts { +public: + explicit VPTOFatobjArtifacts(TempFileRegistry &tempFiles) + : tempFiles(tempFiles) {} + + bool emitStubSource(StringRef stubSource, llvm::raw_ostream &diagOS) { + if (!tempFiles.create("ptoas-host-stub", ".cpp", stubPath, diagOS)) + return false; + if (!writeTextFile(stubPath, stubSource, diagOS)) + return false; + return true; + } + + bool initCommandLogs(llvm::raw_ostream &diagOS) { + if (!tempFiles.create("ptoas-stderr", ".log", stderrPath, diagOS)) + return false; + return true; + } + + bool emitCubeObject(llvm::Module *module, + const VPTOFatobjToolchain &toolchain, + llvm::raw_ostream &diagOS) { + if (!module) + return true; + if (!tempFiles.create("ptoas-device", ".ll", cubeLLPath, diagOS)) + return false; + if (!writeLLVMModuleFile(*module, cubeLLPath, diagOS)) + return false; + if (!tempFiles.create("ptoas-device", ".o", cubeObjPath, diagOS)) + return false; + return compileDeviceLLVMToObject(cubeLLPath, cubeObjPath, + "dav-c310-cube", toolchain.bisheng(), + stderrPath, diagOS); + } + + bool emitVectorObject(llvm::Module *module, + const VPTOFatobjToolchain &toolchain, + llvm::raw_ostream &diagOS) { + if (!module) + return true; + if (!tempFiles.create("ptoas-device", ".ll", vectorLLPath, diagOS)) + return false; + if (!writeLLVMModuleFile(*module, vectorLLPath, diagOS)) + return false; + if (!tempFiles.create("ptoas-device", ".o", vectorObjPath, diagOS)) + return false; + return compileDeviceLLVMToObject(vectorLLPath, vectorObjPath, + "dav-c310-vec", toolchain.bisheng(), + stderrPath, diagOS); + } + + bool mergeDeviceObjects(const VPTOFatobjToolchain &toolchain, + llvm::raw_ostream &diagOS) { + llvm::SmallVector deviceObjPaths; + if (!cubeObjPath.empty()) + deviceObjPaths.push_back(cubeObjPath); + if (!vectorObjPath.empty()) + deviceObjPaths.push_back(vectorObjPath); + if (deviceObjPaths.empty()) { + diagOS << "Error: VPTO fatobj emission requires at least one device module.\n"; + return false; + } + if (!tempFiles.create("ptoas-device-merged", ".o", mergedDeviceObjPath, + diagOS)) + return false; + return ::mergeDeviceObjects(deviceObjPaths, mergedDeviceObjPath, + toolchain.ldLld(), stderrPath, diagOS); + } + + bool compileHostStub(const VPTOFatobjToolchain &toolchain, + llvm::StringRef moduleId, + llvm::StringRef targetCPU, + llvm::raw_ostream &diagOS) { + if (!tempFiles.create("ptoas-host-stub", ".o", hostStubObjPath, diagOS)) + return false; + return compileHostStubToFatobj(stubPath, hostStubObjPath, moduleId, + targetCPU, toolchain, mergedDeviceObjPath, + stderrPath, diagOS); + } + + bool repackFatObj(const VPTOFatobjToolchain &toolchain, + llvm::StringRef moduleId, llvm::StringRef targetCPU, + llvm::StringRef outPath, llvm::raw_ostream &diagOS) { + llvm::SmallVector args = { + toolchain.cceLd(), + toolchain.ldLld(), + "-x", + "-cce-lite-bin-module-id", + moduleId.str(), + std::string("-cce-aicore-arch=") + targetCPU.str(), + "-r", + "-o", + outPath.str(), + "-cce-stub-dir", + toolchain.cceStubDir(), + "-cce-install-dir", + toolchain.bishengCompilerBinDir(), + "-cce-inputs-number", + "1", + hostStubObjPath, + }; + return runCommandWithStderr(toolchain.cceLd(), args, stderrPath, diagOS, + "fatobj repack"); + } + +private: + TempFileRegistry &tempFiles; + std::string cubeLLPath; + std::string cubeObjPath; + std::string vectorLLPath; + std::string vectorObjPath; + std::string mergedDeviceObjPath; + std::string stderrPath; + std::string stubPath; + std::string hostStubObjPath; +}; + +static bool runCommandWithStderr(llvm::StringRef program, + llvm::ArrayRef ownedArgs, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS, + llvm::StringRef what, + std::optional stdinPath) { + llvm::SmallVector args; + args.reserve(ownedArgs.size()); + for (const std::string &arg : ownedArgs) + args.push_back(arg); + llvm::SmallVector, 3> redirects = { + stdinPath, std::nullopt, stderrPath}; + + std::string execErr; + bool execFailed = false; + int rc = llvm::sys::ExecuteAndWait(program, args, std::nullopt, redirects, 0, + 0, &execErr, &execFailed); + if (!execFailed && rc == 0) + return true; + + diagOS << "Error: " << what << " failed\n"; + diagOS << "Command:"; + for (llvm::StringRef arg : args) + diagOS << " " << arg; + diagOS << "\n"; + if (!execErr.empty()) + diagOS << execErr << "\n"; + if (auto buffer = llvm::MemoryBuffer::getFile(stderrPath)) + diagOS << buffer.get()->getBuffer() << "\n"; + return false; +} + +static bool compileDeviceLLVMToObject(llvm::StringRef llPath, + llvm::StringRef outObjPath, + llvm::StringRef targetCPU, + llvm::StringRef bishengPath, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS) { + llvm::SmallVector args = { + bishengPath.str(), + "--target=hiipu64-hisilicon-cce", + std::string("-march=") + targetCPU.str(), + std::string("--cce-aicore-arch=") + targetCPU.str(), + "--cce-aicore-only", + "-O2", + "-c", + "-x", + "ir", + "-", + "-o", + outObjPath.str(), + }; + return runCommandWithStderr(bishengPath, args, stderrPath, diagOS, + "device LLVM compilation", llPath); +} + +static bool compileHostStubToFatobj(llvm::StringRef stubPath, + llvm::StringRef outObjPath, + llvm::StringRef moduleId, + llvm::StringRef targetCPU, + const VPTOFatobjToolchain &toolchain, + llvm::StringRef deviceObjPath, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS) { + std::string coverageDir = "."; + std::string debugDir = "."; + std::string hostTriple = llvm::sys::getProcessTriple(); + + llvm::SmallVector args = { + toolchain.bishengCc1(), + "-cc1", + "-triple", + hostTriple, + "-target-cpu", + llvm::sys::getHostCPUName().str(), + "-fcce-aicpu-legacy-launch", + "-fcce-is-host", + "-cce-enable-mix", + "-mllvm", + "-enable-mix=true", + "-cce-launch-with-flagv2-impl", + "-fcce-aicore-arch", + targetCPU.str(), + "-fcce-fatobj-compile", + "-emit-obj", + "--mrelax-relocations", + "-disable-free", + "-clear-ast-before-backend", + "-disable-llvm-verifier", + "-discard-value-names", + "-main-file-name", + "stub.cpp", + "-mrelocation-model", + "pic", + "-pic-level", + "2", + "-fhalf-no-semantic-interposition", + "-mframe-pointer=none", + "-fmath-errno", + "-ffp-contract=on", + "-fno-rounding-math", + "-mconstructor-aliases", + "-funwind-tables=2", + "-fallow-half-arguments-and-returns", + "-mllvm", + "-treat-scalable-fixed-error-as-warning", + std::string("-fcoverage-compilation-dir=") + coverageDir, + "-resource-dir", + toolchain.resourceDir(), + "-internal-isystem", + toolchain.resourceIncludeDir(), + "-include", + "__clang_cce_runtime_wrapper.h", + "-D", + "_FORTIFY_SOURCE=2", + "-D", + "REGISTER_BASE", + "-O2", + "-Wno-macro-redefined", + "-Wno-ignored-attributes", + "-std=c++17", + "-fdeprecated-macro", + std::string("-fdebug-compilation-dir=") + debugDir, + "-ferror-limit", + "19", + "-stack-protector", + "2", + "-fno-signed-char", + "-fgnuc-version=4.2.1", + "-fcxx-exceptions", + "-fexceptions", + "-vectorize-loops", + "-vectorize-slp", + "-mllvm", + "-cce-aicore-stack-size=0x8000", + "-mllvm", + "-cce-aicore-function-stack-size=0x8000", + "-mllvm", + "-cce-aicore-record-overflow=true", + "-mllvm", + "-cce-aicore-addr-transform", + "-mllvm", + "-cce-aicore-dcci-insert-for-scalar=false", + "-fcce-include-aibinary", + deviceObjPath.str(), + "-fcce-device-module-id", + moduleId.str(), + "-faddrsig", + "-D__GCC_HAVE_DWARF2_CFI_ASM=1", + "-o", + outObjPath.str(), + "-x", + "cce", + stubPath.str(), + }; + return runCommandWithStderr(toolchain.bishengCc1(), args, stderrPath, diagOS, + "host stub compilation"); +} + +static bool mergeDeviceObjects(llvm::ArrayRef deviceObjPaths, + llvm::StringRef outObjPath, + llvm::StringRef ldLldPath, + llvm::StringRef stderrPath, + llvm::raw_ostream &diagOS) { + if (deviceObjPaths.empty()) + return false; + + llvm::SmallVector args = { + ldLldPath.str(), + "-m", + "aicorelinux", + "-Ttext", + "0", + }; + for (const std::string &path : deviceObjPaths) + args.push_back(path); + args.push_back("-o"); + args.push_back(outObjPath.str()); + args.push_back("-r"); + args.push_back("--allow-multiple-definition"); + return runCommandWithStderr(ldLldPath, args, stderrPath, diagOS, + "device object merge"); +} + +} // namespace + +mlir::LogicalResult mlir::pto::emitVPTOFatobj(llvm::Module *cubeModule, + llvm::Module *vectorModule, + llvm::StringRef stubSource, + llvm::ToolOutputFile &outputFile, + llvm::raw_ostream &diagOS) { + if (!cubeModule && !vectorModule) { + diagOS << "Error: VPTO fatobj emission requires at least one LLVM module.\n"; + return failure(); + } + + std::optional toolchain = + VPTOFatobjToolchain::create(diagOS); + if (!toolchain) + return failure(); + + TempFileRegistry tempFiles; + VPTOFatobjArtifacts artifacts(tempFiles); + if (!artifacts.emitStubSource(stubSource, diagOS)) + return failure(); + if (!artifacts.initCommandLogs(diagOS)) + return failure(); + + if (!artifacts.emitCubeObject(cubeModule, *toolchain, diagOS)) + return failure(); + if (!artifacts.emitVectorObject(vectorModule, *toolchain, diagOS)) + return failure(); + + if (!artifacts.mergeDeviceObjects(*toolchain, diagOS)) + return failure(); + + std::string moduleId = sanitizeModuleId(outputFile.getFilename()); + constexpr llvm::StringLiteral hostTargetCPU = "dav-c310"; + if (!artifacts.compileHostStub(*toolchain, moduleId, hostTargetCPU, diagOS)) + return failure(); + + if (!artifacts.repackFatObj(*toolchain, moduleId, hostTargetCPU, + outputFile.getFilename(), diagOS)) + return failure(); + outputFile.keep(); + return success(); +} diff --git a/tools/ptoas/VPTOFatobjEmission.h b/tools/ptoas/VPTOFatobjEmission.h new file mode 100644 index 000000000..fd161ba60 --- /dev/null +++ b/tools/ptoas/VPTOFatobjEmission.h @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef PTOAS_VPTO_FATOBJ_EMISSION_H +#define PTOAS_VPTO_FATOBJ_EMISSION_H + +#include "llvm/ADT/StringRef.h" +#include "mlir/Support/LogicalResult.h" + +namespace llvm { +class ToolOutputFile; +class Module; +class raw_ostream; +} + +namespace mlir::pto { + +LogicalResult emitVPTOFatobj(llvm::Module *cubeModule, + llvm::Module *vectorModule, + llvm::StringRef stubSource, + llvm::ToolOutputFile &outputFile, + llvm::raw_ostream &diagOS); + +} // namespace mlir::pto + +#endif diff --git a/tools/ptoas/VPTOHostStubEmission.cpp b/tools/ptoas/VPTOHostStubEmission.cpp new file mode 100644 index 000000000..ae56d85f8 --- /dev/null +++ b/tools/ptoas/VPTOHostStubEmission.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "VPTOHostStubEmission.h" + +#include "PTO/IR/PTO.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/raw_ostream.h" + +#include + +using namespace mlir; + +namespace { + +static bool hasVPTOKernelAttr(Operation *op) { + return op->hasAttr("pto.kernel") || op->hasAttr("pto.aicore"); +} + +struct VPTOKernelStubDecl { + std::string logicalName; + SmallVector argTypes; +}; + +static std::string getLogicalKernelName(llvm::StringRef symbol) { + if (symbol.ends_with("_mix_aiv")) + return symbol.drop_back(strlen("_mix_aiv")).str(); + if (symbol.ends_with("_mix_aic")) + return symbol.drop_back(strlen("_mix_aic")).str(); + return symbol.str(); +} + +static std::string getStubScalarCType(Type type) { + if (isa(type)) + return "long long"; + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 1: + case 8: + return "signed char"; + case 16: + return "short"; + case 32: + return "int"; + case 64: + return "long long"; + default: + return "long long"; + } + } + if (auto floatType = dyn_cast(type)) { + if (floatType.isF32()) + return "float"; + if (floatType.isF64()) + return "double"; + return "short"; + } + return "long long"; +} + +static std::string getStubCType(Type type) { + if (isa(type)) + return "__gm__ void *"; + return getStubScalarCType(type); +} + +} // namespace + +static LogicalResult collectVPTOKernelStubDecls( + ModuleOp module, SmallVectorImpl &decls, + llvm::raw_ostream &diagOS) { + bool hadError = false; + llvm::StringMap logicalNameToIndex; + + module.walk([&](func::FuncOp func) { + if (func.isExternal() || !hasVPTOKernelAttr(func)) + return; + + std::string logicalName = getLogicalKernelName(func.getSymName()); + SmallVector argTypes; + argTypes.reserve(func.getNumArguments()); + for (Type type : func.getArgumentTypes()) + argTypes.push_back(getStubCType(type)); + + auto [it, inserted] = + logicalNameToIndex.try_emplace(logicalName, decls.size()); + if (inserted) { + decls.push_back(VPTOKernelStubDecl{logicalName, std::move(argTypes)}); + return; + } + + VPTOKernelStubDecl &existing = decls[it->second]; + if (existing.argTypes != argTypes) { + diagOS << "Error: mixed kernel variants disagree on host stub signature " + << "for '" << logicalName << "'.\n"; + hadError = true; + } + }); + + return hadError ? failure() : success(); +} + +LogicalResult mlir::pto::emitVPTOHostStubSource(ModuleOp module, + std::string &stubSource, + llvm::raw_ostream &diagOS) { + SmallVector stubDecls; + if (failed(collectVPTOKernelStubDecls(module, stubDecls, diagOS))) + return failure(); + + if (stubDecls.empty()) { + diagOS << "Error: no pto.kernel functions found for host stub emission.\n"; + return failure(); + } + + stubSource.clear(); + llvm::raw_string_ostream os(stubSource); + os << "#ifndef __global__\n#define __global__\n#endif\n\n"; + os << "#ifndef __gm__\n#define __gm__\n#endif\n\n"; + for (const VPTOKernelStubDecl &decl : stubDecls) { + os << "extern \"C\" __global__ [aicore] void " << decl.logicalName << "("; + for (size_t i = 0; i < decl.argTypes.size(); ++i) { + if (i) + os << ", "; + os << decl.argTypes[i] << " arg" << i; + } + os << ") {}\n"; + } + os.flush(); + return success(); +} diff --git a/tools/ptoas/VPTOHostStubEmission.h b/tools/ptoas/VPTOHostStubEmission.h new file mode 100644 index 000000000..7a091a503 --- /dev/null +++ b/tools/ptoas/VPTOHostStubEmission.h @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef PTOAS_VPTO_HOST_STUB_EMISSION_H +#define PTOAS_VPTO_HOST_STUB_EMISSION_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +#include + +namespace llvm { +class raw_ostream; +} + +namespace mlir::pto { + +LogicalResult emitVPTOHostStubSource(ModuleOp module, std::string &stubSource, + llvm::raw_ostream &diagOS); + +} // namespace mlir::pto + +#endif diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 1d40e1c43..360f44443 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -7,9 +7,13 @@ // See LICENSE in the root of the software repository for the full text of the License. #include "PTO/IR/PTO.h" +#include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" +#include "VPTOFatobjEmission.h" +#include "VPTOHostStubEmission.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -27,6 +31,7 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/FileSystem.h" // [Fix] Required for OF_None +#include "llvm/Support/Path.h" #include "ptobc/ptobc_decode.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -42,6 +47,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/MemoryBuffer.h" #include #include @@ -69,10 +75,77 @@ using StringRefVector = } // namespace +int main(int argc, char **argv); + static void printPTOASVersion(llvm::raw_ostream &os) { os << "ptoas " << PTOAS_RELEASE_VERSION << "\n"; } +static std::string getParentDir(llvm::StringRef path) { + llvm::SmallString<256> parent(path); + llvm::sys::path::remove_filename(parent); + llvm::sys::path::remove_dots(parent, true); + return std::string(parent); +} + +static bool pathExists(llvm::StringRef path) { + return !path.empty() && llvm::sys::fs::exists(path); +} + +static std::string joinPath(llvm::StringRef lhs, llvm::StringRef rhs) { + llvm::SmallString<256> joined(lhs); + llvm::sys::path::append(joined, rhs); + llvm::sys::path::remove_dots(joined, true); + return std::string(joined); +} + +static std::string detectInstalledTilelangPath(const char *argv0) { + std::string exePath = llvm::sys::fs::getMainExecutable(argv0, (void *)&main); + if (exePath.empty()) + return {}; + + const std::string exeDir = getParentDir(exePath); + const std::string prefixDir = getParentDir(exeDir); + const std::string installedTileOps = joinPath(prefixDir, "share/ptoas/TileOps"); + if (pathExists(installedTileOps)) + return installedTileOps; + return {}; +} + +static std::string detectInstalledTilelangPkgPath(const char *argv0) { + std::string exePath = llvm::sys::fs::getMainExecutable(argv0, (void *)&main); + if (exePath.empty()) + return {}; + + const std::string exeDir = getParentDir(exePath); + const std::string prefixDir = getParentDir(exeDir); + const std::string installedPkgRoot = prefixDir; + const std::string installedPkg = joinPath(installedPkgRoot, "tilelang_dsl"); + if (pathExists(installedPkg)) + return installedPkgRoot; + return {}; +} + +static bool hasCLIOption(int argc, char **argv, llvm::StringRef option) { + const std::string optionWithValue = (option + "=").str(); + for (int i = 1; i < argc; ++i) { + llvm::StringRef arg(argv[i]); + if (arg == option || arg.starts_with(optionWithValue)) + return true; + } + return false; +} + +static LogicalResult applyConfiguredPassManagerCLOptions( + PassManager &pm, llvm::StringRef pipelineName, + llvm::raw_ostream &diagOS = llvm::errs()) { + if (succeeded(mlir::applyPassManagerCLOptions(pm))) + return success(); + diagOS << "Error: failed to apply MLIR pass manager command-line options for " + << pipelineName << ".\n"; + return failure(); +} + static LogicalResult reorderEmitCFunctions(ModuleOp module) { SmallVector declarations; SmallVector definitions; @@ -222,6 +295,53 @@ static llvm::cl::opt graphSyncSolverEventIdMax( "Lower values exercise the PIPE_ALL coloring fallback sooner."), llvm::cl::init(kDefaultGraphSyncSolverEventIdMax)); +static llvm::cl::opt enableTileOpExpand( + "enable-tile-op-expand", + llvm::cl::desc( + "Deprecated compatibility flag. TileOp expansion is controlled by " + "--pto-backend=vpto."), + llvm::cl::init(false)); + +#ifndef PTOAS_DEFAULT_TILELANG_PATH +#define PTOAS_DEFAULT_TILELANG_PATH "" +#endif +#ifndef PTOAS_DEFAULT_TILELANG_PKG_PATH +#define PTOAS_DEFAULT_TILELANG_PKG_PATH "" +#endif + +static llvm::cl::opt tilelangPath( + "tilelang-path", + llvm::cl::desc("Path to directory of .py tilelang DSL template files " + "(default: /lib/TileOps, baked in at build time)"), + llvm::cl::init(PTOAS_DEFAULT_TILELANG_PATH)); + +static llvm::cl::opt tilelangPkgPath( + "tilelang-pkg-path", + llvm::cl::desc("PYTHONPATH for tilelang_dsl package " + "(default: /tilelang-dsl/python, baked in at build time)"), + llvm::cl::init(PTOAS_DEFAULT_TILELANG_PKG_PATH)); + +static pto::ExpandTileOpOptions resolveExpandTileOpOptions(int argc, + char **argv) { + pto::ExpandTileOpOptions expandOpts; + expandOpts.tilelangPath = tilelangPath; + expandOpts.tilelangPkgPath = tilelangPkgPath; + + if (!hasCLIOption(argc, argv, "--tilelang-path")) { + std::string detectedTilelangPath = detectInstalledTilelangPath(argv[0]); + if (!detectedTilelangPath.empty()) + expandOpts.tilelangPath = detectedTilelangPath; + } + + if (!hasCLIOption(argc, argv, "--tilelang-pkg-path")) { + std::string detectedTilelangPkgPath = detectInstalledTilelangPkgPath(argv[0]); + if (!detectedTilelangPkgPath.empty()) + expandOpts.tilelangPkgPath = detectedTilelangPkgPath; + } + + return expandOpts; +} + static llvm::cl::opt disableInferLayout( "disable-infer-layout", llvm::cl::desc("Disable PTO layout inference pass (static-only)"), @@ -249,12 +369,54 @@ static llvm::cl::opt ptoBuildLevel( llvm::cl::value_desc("level1|level2|level3"), llvm::cl::init("level2")); +static llvm::cl::opt ptoBackend( + "pto-backend", + llvm::cl::desc("Final PTOAS backend: emitc or vpto (default: emitc)"), + llvm::cl::value_desc("emitc|vpto"), llvm::cl::init("emitc")); + +static llvm::cl::opt emitVPTO( + "emit-vpto", + llvm::cl::desc("Write final post-pass VPTO IR to -o"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoPrintIR( + "vpto-print-ir", + llvm::cl::desc("Print post-pass VPTO backend IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt vptoLoweringStrategy( + "vpto-lowering-strategy", + llvm::cl::desc("VPTO vector lowering strategy: post-update or no-post-update"), + llvm::cl::value_desc("post-update|no-post-update"), + llvm::cl::init("post-update")); + +static llvm::cl::opt dumpVPTOIR( + "dump-vpto-ir", + llvm::cl::desc("Print post-pass VPTO backend IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt ptoPrintSeamIR( + "pto-print-seam-ir", + llvm::cl::desc("Print shared pre-backend seam IR to stderr"), + llvm::cl::init(false)); + +static llvm::cl::opt ptoSeamIRFile( + "pto-seam-ir-file", + llvm::cl::desc("Write shared pre-backend seam IR to a file"), + llvm::cl::value_desc("path"), + llvm::cl::init("")); + enum class PTOBuildLevel { Level1, Level2, Level3, }; +enum class PTOBackend { + EmitC, + VPTO, +}; + static PTOBuildLevel defaultBuildLevel() { return PTOBuildLevel::Level2; } @@ -300,6 +462,69 @@ static bool parseAutoSyncTailHint(llvm::StringRef hintStr, std::string &normaliz return false; } +static bool parseBackend(llvm::StringRef backendStr, PTOBackend &out) { + std::string s = backendStr.str(); + for (char &c : s) + c = static_cast(std::tolower(static_cast(c))); + if (s == "emitc") { + out = PTOBackend::EmitC; + return true; + } + if (s == "vpto") { + out = PTOBackend::VPTO; + return true; + } + return false; +} + +static LogicalResult emitSharedPreBackendSeamIR(ModuleOp module, + llvm::StringRef outputPath) { + if (outputPath.empty()) + return success(); + + if (outputPath == "-") { + module->print(llvm::outs()); + llvm::outs() << "\n"; + llvm::outs().flush(); + return success(); + } + + std::error_code ec; + llvm::ToolOutputFile outputFile(outputPath, ec, llvm::sys::fs::OF_None); + if (ec) { + llvm::errs() << "Error: failed to open seam IR file '" << outputPath + << "': " << ec.message() << "\n"; + return failure(); + } + + module->print(outputFile.os()); + outputFile.os() << "\n"; + outputFile.keep(); + return success(); +} + +static bool hasUnexpandedTileOps(ModuleOp module) { + bool found = false; + module.walk([&](Operation *op) { + if (found) + return; + if (isa(op)) + found = true; + }); + return found; +} + +static bool hasTilelangInlineHelpers(ModuleOp module) { + bool found = false; + module.walk([&](func::FuncOp func) { + if (found) + return; + if (func->hasAttr("pto.tilelang.inline_proc")) + found = true; + }); + return found; +} + // -------------------------------------------------------------------------- // Post-process C++ output: rewrite marker calls into Tile member calls. // We emit marker calls in EmitC IR because EmitC currently does not provide a @@ -957,6 +1182,117 @@ static bool shouldDeclareVariablesAtTop(ModuleOp module) { llvm::any_of(module.getOps(), hasMultiBlockFunc); } +static void prepareVPTOForEmission(PassManager &pm) { + auto &kernelModulePM = pm.nest(); + kernelModulePM.addPass(createCanonicalizerPass()); + kernelModulePM.addPass(createCSEPass()); + kernelModulePM.addPass(pto::createVPTOPtrNormalizePass()); + kernelModulePM.addPass(pto::createVPTOPtrCastCleanupPass()); + kernelModulePM.addPass(createReconcileUnrealizedCastsPass()); + kernelModulePM.addNestedPass( + createVPTOExpandWrapperOpsPass()); + kernelModulePM.addPass(createCSEPass()); + kernelModulePM.addNestedPass( + pto::createPTOInferVPTOVecScopePass()); + kernelModulePM.addPass(createCanonicalizerPass()); + kernelModulePM.addPass(createCSEPass()); + kernelModulePM.addPass(pto::createPTOValidateVPTOEmissionIRPass()); +} + +static void lowerPTOToVPTOBackend(PassManager &pm, int argc, char **argv) { + // TileOp Expand path: + // 1. ExpandTileOp: instantiate TileLang DSL templates, replace tile ops + // with func.call to template functions (tile_buf params) + // 2. InlineLibCall: inline template function bodies + // 3. FoldTileBufIntrinsics: fold tile_buf_addr / tile_valid_rows / + // tile_valid_cols to concrete memref/constant values + auto &kernelModulePM = pm.nest(); + pto::ExpandTileOpOptions expandOpts = resolveExpandTileOpOptions(argc, argv); + kernelModulePM.addPass(pto::createExpandTileOpPass(expandOpts)); + + kernelModulePM.addPass(pto::createPTOInlineLibCallPass()); + kernelModulePM.addNestedPass( + pto::createFoldTileBufIntrinsicsPass()); + // FoldTileBufIntrinsics materializes many constant branch conditions. + // Clean them up immediately on the TileOp expansion path before the + // authoring-stage VPTO verifier and let the existing CSE passes remove the + // resulting dead values later in the pipeline. + kernelModulePM.addPass(mlir::createSCCPPass()); + kernelModulePM.addPass(mlir::createCanonicalizerPass()); +} + +static void inlineTilelangHelpersOnVPTOInput(PassManager &pm) { + auto &kernelModulePM = pm.nest(); + kernelModulePM.addPass(pto::createPTOInlineLibCallPass()); + kernelModulePM.addPass(mlir::createSCCPPass()); + kernelModulePM.addPass(mlir::createCanonicalizerPass()); +} + +static pto::VPTOEmissionOptions buildVPTOEmissionOptions() { + pto::VPTOEmissionOptions options; + options.dumpVPTOIR = false; + options.targetTriple = "hiipu64-hisilicon-cce"; + return options; +} + +static int emitVPTOBackendResult(ModuleOp module, + llvm::ToolOutputFile &outputFile) { + if (emitVPTO) { + module.print(outputFile.os()); + outputFile.os() << "\n"; + outputFile.keep(); + return 0; + } + + pto::VPTOEmissionOptions options = buildVPTOEmissionOptions(); + std::string stubSource; + if (failed(pto::emitVPTOHostStubSource(module, stubSource, llvm::errs()))) { + llvm::errs() << "Error: Failed to emit VPTO host stub source.\n"; + return 1; + } + + pto::EmittedLLVMModule cubeModule; + pto::EmittedLLVMModule vectorModule; + if (failed( + pto::lowerVPTOModuleToLLVMModules(module, options, cubeModule, + vectorModule, llvm::errs()))) { + llvm::errs() << "Error: Failed to lower VPTO to LLVM modules.\n"; + return 1; + } + + if (failed(pto::emitVPTOFatobj(cubeModule.module.get(), + vectorModule.module.get(), stubSource, + outputFile, llvm::errs()))) { + llvm::errs() << "Error: Failed to emit VPTO fatobj.\n"; + return 1; + } + outputFile.keep(); + return 0; +} + +static LogicalResult runVPTOBackendPipeline(OwningOpRef &module, + int argc, char **argv, + bool hasTileOpsToExpand, + bool hasTilelangHelpers) { + PassManager pm(module->getContext()); + pm.enableVerifier(); + pm.addPass(pto::createVPTOSplitCVModulePass()); + pm.addPass(pto::createVPTONormalizeContainerPass()); + if (!hasTileOpsToExpand && hasTilelangHelpers) + inlineTilelangHelpersOnVPTOInput(pm); + if (hasTileOpsToExpand) + lowerPTOToVPTOBackend(pm, argc, argv); + prepareVPTOForEmission(pm); + if (failed(applyConfiguredPassManagerCLOptions( + pm, "VPTO unified emission pipeline"))) + return failure(); + if (failed(pm.run(module.get()))) { + llvm::errs() << "Error: VPTO emission pipeline failed.\n"; + return failure(); + } + return success(); +} + int main(int argc, char **argv) { DialectRegistry registry; registry.insert(); @@ -975,27 +1311,40 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); + mlir::registerAllPasses(); + ::registerPTOPasses(); + mlir::pto::registerPTOViewToMemrefPass(); + ::registerPTOInlineLibCall(); + ::registerFoldTileBufIntrinsics(); + ::registerExpandTileOp(); + mlir::registerPassManagerCLOptions(); llvm::cl::SetVersionPrinter(printPTOASVersion); bool cliArchSpecified = false; for (int i = 1; i < argc; ++i) { llvm::StringRef arg(argv[i]); - if (arg == "--pto-arch" || arg.starts_with("--pto-arch=")) { + if (arg == "--pto-arch" || arg.starts_with("--pto-arch=")) cliArchSpecified = true; - break; - } } - // Register all passes so that --mlir-print-ir-after/before can resolve - // pass names like 'cse' at option-parse time. - mlir::registerAllPasses(); - registerPTOPasses(); - // Parse command line options - mlir::registerPassManagerCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "PTO Assembler (ptoas)\n"); + PTOBackend effectiveBackend = PTOBackend::EmitC; + if (!parseBackend(ptoBackend, effectiveBackend)) { + llvm::errs() << "Error: invalid --pto-backend='" << ptoBackend + << "'. Expected 'emitc' or 'vpto'.\n"; + return 1; + } + + if (effectiveBackend != PTOBackend::VPTO && + (emitVPTO || ptoPrintSeamIR || !ptoSeamIRFile.empty())) { + llvm::errs() << "Error: VPTO-specific flags require " + "--pto-backend=vpto.\n"; + return 1; + } + // Read whole input first (so we can auto-detect .ptobc by magic). auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); if (!fileOrErr) { @@ -1020,6 +1369,7 @@ int main(int argc, char **argv) { OwningOpRef module; llvm::StringRef buf = (*fileOrErr)->getBuffer(); const bool isPTOBC = (buf.size() >= 6 && std::memcmp(buf.data(), "PTOBC\0", 6) == 0); + auto normalizeArch = [](llvm::StringRef archValue) { std::string normalized = archValue.str(); for (char &c : normalized) @@ -1175,6 +1525,29 @@ int main(int argc, char **argv) { return 1; } + // [Fix] ToolOutputFile Usage + std::error_code ec; + llvm::ToolOutputFile outputFile(outputFilename, ec, llvm::sys::fs::OF_None); + if (ec) { + llvm::errs() << ec.message() << "\n"; + return 1; + } + + const bool hasTileOpsToExpand = hasUnexpandedTileOps(*module); + const bool hasTilelangHelpers = hasTilelangInlineHelpers(*module); + + if (effectiveBackend == PTOBackend::VPTO && !hasTileOpsToExpand) { + if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { + llvm::errs() << "Error: shared pre-backend seam IR is unavailable when " + "skipping the shared PTO-to-VPTO lowering pipeline.\n"; + return 1; + } + if (failed(runVPTOBackendPipeline(module, argc, argv, hasTileOpsToExpand, + hasTilelangHelpers))) + return 1; + return emitVPTOBackendResult(module.get(), outputFile); + } + // Main PassManager PassManager pm(&context); @@ -1219,34 +1592,45 @@ int main(int argc, char **argv) { pto::createPTOGraphSyncSolverPass(graphSyncOpts)); } - - - std::unique_ptr outputFile; - llvm::raw_ostream *outputOS = &llvm::outs(); - if (outputFilename != "-") { - std::error_code ec; - outputFile = std::make_unique( - outputFilename, ec, llvm::sys::fs::OF_None); - if (ec) { - llvm::errs() << ec.message() << "\n"; - return 1; - } - outputOS = &outputFile->os(); - } - if (emitMlirIR) { if (failed(pm.run(*module))) { llvm::errs() << "Error: Pass execution failed.\n"; return 1; } - module->print(*outputOS); - if (outputFile) - outputFile->keep(); + module->print(outputFile.os()); + outputFile.keep(); return 0; } + // Reintroduce tile-native handles once on the shared mainline so both + // backends consume the same post-planning seam IR. pm.addPass(pto::createPTOMaterializeTileHandlesPass()); pm.addPass(createCSEPass()); + if (failed(applyConfiguredPassManagerCLOptions(pm, "main PTOAS pipeline"))) + return 1; + + module->getOperation()->setAttr("pto.target_arch", + mlir::StringAttr::get(&context, arch)); + + if (effectiveBackend == PTOBackend::VPTO) { + if (failed(pm.run(*module))) { + llvm::errs() << "Error: Pass execution failed.\n"; + return 1; + } + + if (ptoPrintSeamIR) { + module->print(llvm::errs()); + llvm::errs() << "\n"; + } + if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) + return 1; + + if (failed(runVPTOBackendPipeline(module, argc, argv, hasTileOpsToExpand, + hasTilelangHelpers))) + return 1; + return emitVPTOBackendResult(module.get(), outputFile); + } + if (arch == "a3") { pm.addPass(pto::createEmitPTOManualPass(pto::PTOArch::A3)); } else { @@ -1288,11 +1672,10 @@ int main(int argc, char **argv) { rewriteScalarConstantDecls(cppOutput); rewriteHoistedGlobalTensorDecls(cppOutput); - *outputOS << cppOutput; - outputOS->flush(); + outputFile.os() << cppOutput; + outputFile.os().flush(); - if (outputFile) - outputFile->keep(); // Success, keep the file + outputFile.keep(); // Success, keep the file return 0; }